├── README.md
├── config
│   └── framework_da.json
├── core
│   ├── __pycache__
│   │   ├── logger.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   └── wandb_logger.cpython-38.pyc
│   ├── calc_indicator.py
│   ├── logger.py
│   ├── metrics.py
│   └── wandb_logger.py
├── data
│   ├── FDA.py
│   ├── HazeAug.py
│   ├── LRHR_dataset.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── FDA.cpython-38.pyc
│   │   ├── HazeAug.cpython-38.pyc
│   │   ├── LRHR_dataset.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── util.cpython-38.pyc
│   └── util.py
├── infer.py
├── misc
│   ├── RTTS.jpg
│   └── framework-v3.jpg
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── base_model.cpython-38.pyc
│   │   ├── model.cpython-38.pyc
│   │   └── networks.cpython-38.pyc
│   ├── base_model.py
│   ├── dehaze_with_z_v2_modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── diffusion.cpython-38.pyc
│   │   │   └── unet.cpython-38.pyc
│   │   ├── diffusion.py
│   │   └── unet.py
│   ├── model.py
│   └── networks.py
├── requirement.txt
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Frequency Compensated Diffusion Model for Real-scene Dehazing
2 |
3 | This is an official implementation of **Frequency Compensated Diffusion Model for Real-scene Dehazing** in **PyTorch**.
4 |
5 | ![framework](misc/framework-v3.jpg)
20 |
21 | ## News
22 | - 2025.03: We released a more powerful dehazing diffusion model, [ProDehaze](https://github.com/TianwenZhou/ProDehaze), built on SD-2.1.
23 |
24 | ## Getting started
25 | ### Installation
26 | * This repo is a modification of the [**SR3 repo**](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement).
27 |
28 | * Install third-party libraries.
29 |
30 | ```shell
31 | pip install -r requirement.txt
32 | ```
33 |
34 | ### Data Preparation
35 |
36 | Download train/eval data from the following links:
37 |
38 | Training: [*RESIDE*](https://sites.google.com/view/reside-dehaze-datasets/reside-v0)
39 |
40 | Testing:
41 | [*I-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire18/i-haze/) /
42 | [*O-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire18/o-haze/) /
43 | [*Dense-Haze*](https://arxiv.org/abs/1904.02904) /
44 | [*Nh-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire20/nh-haze/) /
45 | [*RTTS*](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0)
46 |
47 | ```shell
48 | mkdir dataset
49 | ```
50 |
51 | Re-organize the train/val images into the following file structure:
52 |
53 |
54 | ```shell
55 | # Training data file structure
56 | dataset/RESIDE/
57 | ├── HR # ground-truth clear images.
58 | ├── HR_hazy_src # hazy images.
59 | └── HR_depth # depth images (Generated by MonoDepth (github.com/OniroAI/MonoDepth-PyTorch)).
60 |
61 | # Testing data (e.g. DenseHaze) file structure
62 | dataset/{name}/
63 | ├── HR # ground-truth images.
64 | └── HR_hazy # hazy images.
65 | ```
66 |
67 | Then make sure the data paths ("dataroot", "hr_path", "depth_img_path") in config/framework_da.json point to these directories.
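For reference, the corresponding fields in config/framework_da.json (as shipped in this repo) are:

```json
"datasets": {
    "train": {
        "dataroot": "dataset/RESIDE/HR_hazy_src",
        "hr_path": "dataset/RESIDE/HR",
        "depth_img_path": "dataset/RESIDE/HR_depth/",
        ...
    }
}
```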
68 |
69 | ## Pretrained Model
70 |
71 | We provide a pretrained model:
72 |
73 | | Type | Weights |
74 | | ----------------------------------------------------------- | ------------------------------------------------------------ |
75 | | Generator | [OneDrive](https://1drv.ms/u/s!AsqtTP8eWS-penA8AqrU8c_I4jU) |
76 |
77 | ## Evaluation
78 |
79 | Download the test set (e.g. O-Haze), put the test images under the "dataroot" directory, and set the corresponding path in config/framework_da.json;
80 |
81 | Download the pretrained model and point "resume_state" in config/framework_da.json to it:
82 |
83 | ```json
84 | "path": {
85 |     "log": "logs",
86 |     "tb_logger": "tb_logger",
87 |     "results": "results",
88 |     "checkpoint": "checkpoint",
89 |     "resume_state": "./ddpm_fcb_230221_121802"
90 | },
91 | "datasets": {
92 |     "val": {
93 |         "name": "dehaze_val",
94 |         "mode": "LRHR",
95 |         "dataroot": "dataset/O-HAZE-PROCESS",
96 |         ...
97 |     }
98 | }
99 | ```
100 | ```shell
101 | # infer
102 | python infer.py -c [config file]
103 | ```
104 |
105 | The default config file is config/framework_da.json. The output images are written to /data/diffusion/results; this output path can be changed in core/logger.py.
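Concretely, the output root is hard-coded in `parse()` of core/logger.py (snippet below, copied from that file); edit this join to relocate all logs and results:

```python
# core/logger.py, inside parse()
experiments_root = os.path.join(
    '/data/diffusion/results', '{}_{}'.format(opt['name'], get_timestamp()))
```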
106 |
107 | ## Train
108 |
109 | Prepare the training dataset and set the correct paths under "datasets" in config/framework_da.json;
110 |
111 | If training from scratch, make sure "resume_state" is null in config/framework_da.json.
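That is, in the "path" section of config/framework_da.json:

```json
"path": {
    ...
    "resume_state": null
}
```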
112 |
113 | ```shell
114 | # train
115 | python train.py -c [config file]
116 | ```
117 |
118 | ## Results
119 | Quantitative comparison on real-world hazy data (RTTS). Bold and underline indicate the best and the second-best, respectively.
120 |
121 | ![Quantitative comparison on RTTS](misc/RTTS.jpg)
122 |
123 |
124 | ## Todo
125 |
126 |
127 | - [x] Upload configs and pretrained models
128 |
129 | - [x] Upload evaluation scripts
130 |
131 | - [x] Upload train scripts
132 |
--------------------------------------------------------------------------------
/config/framework_da.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "framework_da",
3 | "phase": "train",
4 | // train or val
5 | "gpu_ids": [
6 | 0
7 | ],
8 | "change_sizes": {
9 | "0.0": 128,
10 | "0.3": 128,
11 | "0.9": 128,
12 | "1.01": 128
13 | },
14 | "path": {
15 | //set the path
16 | "log": "logs",
17 | "tb_logger": "tb_logger",
18 | "results": "results",
19 | "checkpoint": "checkpoint",
20 | "resume_state": "./ddpm_fcb_230221_121802"
21 | // "resume_state": null
22 | },
23 | "datasets": {
24 | "train": {
25 | "name": "RESIDE_train_syntheic",
26 | "mode": "HR",
27 | "dataroot": "dataset/RESIDE/HR_hazy_src",
28 | "hr_path": "dataset/RESIDE/HR",
29 | "datatype": "RESIDE_img_syntheic",
30 | "l_resolution": 128,
31 | "r_resolution": 128,
32 | "batch_size": 3,
33 | "num_workers": 12,
34 | "use_shuffle": true,
35 | "HazeAug": true,
36 | "rt_da_ref": [
37 | "dataset/RESIDE/HR_hazy_src"
38 | ],
39 | "depth_img_path": "dataset/RESIDE/HR_depth/",
40 | "data_len": -1 // -1 represents all data used in train
41 | },
42 | "val": {
43 | "name": "dehaze_val",
44 | "mode": "LRHR",
45 | // "dataroot": "dataset/I-HAZE-PROCESS",
46 | // "dataroot": "dataset/RTTS-PROCESS",
47 | "dataroot": "dataset/O-HAZE-PROCESS",
48 | // "dataroot": "dataset/DenseHaze",
49 | //"dataroot": "dataset/NhHaze",
50 | "datatype": "haze_img",
51 |
52 | "l_resolution": 512,
53 | "r_resolution": 512,
54 | "data_len": 5000
55 | }
56 | },
57 | "model": {
58 | "which_model_G": "dehaze_with_z_v2",
59 | "finetune_norm": false,
60 | "FCB": true,
61 | "unet": {
62 | "in_channel": 6,
63 | "out_channel": 3,
64 | "inner_channel": 64,
65 | "norm_groups": 16,
66 | "channel_multiplier": [
67 | 1,
68 | 2,
69 | 4,
70 | 8,
71 | 16
72 | ],
73 | "attn_res": [
74 | // 16
75 | ],
76 | "res_blocks": 1,
77 | "dropout": 0.2
78 | },
79 | "beta_schedule": {
80 |             // use manual beta_schedule for acceleration
81 | "train": {
82 | "schedule": "linear",
83 | "n_timestep": 2000,
84 | "linear_start": 1e-6,
85 | "linear_end": 1e-2
86 | },
87 | "val": {
88 | "schedule": "linear",
89 | "n_timestep": 2000,
90 | "linear_start": 1e-6,
91 | "linear_end": 1e-2
92 | }
93 | },
94 | "diffusion": {
95 | "image_size": 128,
96 | "channels": 3,
97 | //sample channel
98 | "conditional": true,
99 |             // conditional generation (super_resolution) or unconditional generation
100 | "start_step": 1000
101 | }
102 | },
103 | "train": {
104 | "n_iter": 2000000,
105 | "save_checkpoint_freq": 1e4,
106 | "print_freq": 50,
107 | "optimizer": {
108 | "type": "adam",
109 | "lr": 1e-4
110 | },
111 | "ema_scheduler": {
112 | // not used now
113 | "step_start_ema": 5000,
114 | "update_ema_every": 1,
115 | "ema_decay": 0.9999
116 | }
117 | },
118 | "wandb": {
119 | "project": "dehaze_with_z_v2"
120 | }
121 | }
123 |
124 |
--------------------------------------------------------------------------------
/core/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/wandb_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/wandb_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/core/calc_indicator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 |
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 | '''
10 | Converts a torch Tensor into an image Numpy array
11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 | '''
14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
15 | tensor = (tensor - min_max[0]) / \
16 | (min_max[1] - min_max[0]) # to range [0,1]
17 | n_dim = tensor.dim()
18 | if n_dim == 4:
19 | n_img = len(tensor)
20 | img_np = make_grid(tensor, nrow=int(
21 | math.sqrt(n_img)), normalize=False).numpy()
22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
23 | elif n_dim == 3:
24 | img_np = tensor.numpy()
25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
26 | elif n_dim == 2:
27 | img_np = tensor.numpy()
28 | else:
29 | raise TypeError(
30 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
31 | if out_type == np.uint8:
32 | img_np = (img_np * 255.0).round()
33 |     # Important. Unlike MATLAB, numpy.uint8() will not round by default.
34 | return img_np.astype(out_type)
35 |
36 |
37 | def save_img(img, img_path, mode='RGB'):
38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
39 | # cv2.imwrite(img_path, img)
40 |
41 |
42 | def calculate_psnr(img1, img2):
43 | # img1 and img2 have range [0, 255]
44 | img1 = img1.astype(np.float64)
45 | img2 = img2.astype(np.float64)
46 | mse = np.mean((img1 - img2) ** 2)
47 | if mse == 0:
48 | return float('inf')
49 | return 20 * math.log10(255.0 / math.sqrt(mse))
50 |
51 |
52 | def ssim(img1, img2):
53 | C1 = (0.01 * 255) ** 2
54 | C2 = (0.03 * 255) ** 2
55 |
56 | img1 = img1.astype(np.float64)
57 | img2 = img2.astype(np.float64)
58 | kernel = cv2.getGaussianKernel(11, 1.5)
59 | window = np.outer(kernel, kernel.transpose())
60 |
61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
63 | mu1_sq = mu1 ** 2
64 | mu2_sq = mu2 ** 2
65 | mu1_mu2 = mu1 * mu2
66 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
67 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
69 |
70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
71 | (sigma1_sq + sigma2_sq + C2))
72 | return ssim_map.mean()
73 |
74 |
75 | def calculate_ssim(img1, img2):
76 | '''calculate SSIM
77 | the same outputs as MATLAB's
78 | img1, img2: [0, 255]
79 | '''
80 | if not img1.shape == img2.shape:
81 | raise ValueError('Input images must have the same dimensions.')
82 | if img1.ndim == 2:
83 | return ssim(img1, img2)
84 | elif img1.ndim == 3:
85 | if img1.shape[2] == 3:
86 |             ssims = []
87 |             for i in range(3):
88 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
89 | return np.array(ssims).mean()
90 | elif img1.shape[2] == 1:
91 | return ssim(np.squeeze(img1), np.squeeze(img2))
92 | else:
93 | raise ValueError('Wrong input image dimensions.')
94 |
95 |
96 | if __name__ == "__main__":
97 | path1 = "/data/ImageDehazing/tmp/DenseHaze/GCA/"
98 | path2 = "/data/ImageDehazing/DenseHaze/HR"
99 |
100 | img1s = sorted(os.listdir(path1))
101 | img2s = sorted(os.listdir(path2))
102 |
103 | ave_p, ave_s = 0, 0
104 | for idx, (img1, img2) in enumerate(zip(img1s, img2s)):
105 | im1 = cv2.imread(os.path.join(path1, img1))
106 | im2 = cv2.imread(os.path.join(path2, img2))
107 |
108 | im1 = cv2.resize(im1, (512, 512))
109 | im2 = cv2.resize(im2, (512, 512))
110 |
111 | s = calculate_ssim(im1, im2)
112 | p = calculate_psnr(im1, im2)
113 | print(img1, img2, "psnr:{}".format(p), " ssim:{}".format(s))
114 |
115 | ave_p += p
116 | ave_s += s
117 |
118 |     print("ave ssim: {}, ave psnr:{}".format(ave_s / (idx + 1), ave_p / (idx + 1)))  # idx is zero-based
119 |
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | from collections import OrderedDict
5 | import json
6 | from datetime import datetime
7 |
8 |
9 | def mkdirs(paths):
10 | if isinstance(paths, str):
11 | os.makedirs(paths, exist_ok=True)
12 | else:
13 | for path in paths:
14 | os.makedirs(path, exist_ok=True)
15 |
16 |
17 | def get_timestamp():
18 | return datetime.now().strftime('%y%m%d_%H%M%S')
19 |
20 |
21 | def parse(args):
22 | phase = args.phase
23 | opt_path = args.config
24 | gpu_ids = args.gpu_ids
25 | enable_wandb = args.enable_wandb
26 | # remove comments starting with '//'
27 | json_str = ''
28 | with open(opt_path, 'r') as f:
29 | for line in f:
30 | line = line.split('//')[0] + '\n'
31 | json_str += line
32 | opt = json.loads(json_str, object_pairs_hook=OrderedDict)
33 |
34 | # set log directory
35 | if args.debug:
36 | opt['name'] = 'debug_{}'.format(opt['name'])
37 |     experiments_root = os.path.join(
38 |         '/data/diffusion/results', '{}_{}'.format(opt['name'], get_timestamp()))  # default output root; edit to relocate results
39 | opt['path']['experiments_root'] = experiments_root
40 | for key, path in opt['path'].items():
41 | if 'resume' not in key and 'experiments' not in key:
42 | opt['path'][key] = os.path.join(experiments_root, path)
43 | mkdirs(opt['path'][key])
44 |
45 | # change dataset length limit
46 | opt['phase'] = phase
47 |
48 | # export CUDA_VISIBLE_DEVICES
49 | if gpu_ids is not None:
50 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')]
51 | gpu_list = gpu_ids
52 | else:
53 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
54 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
55 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
56 | if len(gpu_list) > 1:
57 | opt['distributed'] = True
58 | else:
59 | opt['distributed'] = False
60 |
61 | # debug
62 | if 'debug' in opt['name']:
63 | opt['train']['val_freq'] = 2
64 | opt['train']['print_freq'] = 2
65 | opt['train']['save_checkpoint_freq'] = 3
66 | opt['datasets']['train']['batch_size'] = 2
67 | opt['model']['beta_schedule']['train']['n_timestep'] = 10
68 | opt['model']['beta_schedule']['val']['n_timestep'] = 10
69 | opt['datasets']['train']['data_len'] = 6
70 | opt['datasets']['val']['data_len'] = 3
71 |
72 | # validation in train phase
73 | if phase == 'train':
74 | opt['datasets']['val']['data_len'] = 20
75 |
76 | # W&B Logging
77 |     try:
78 |         log_wandb_ckpt = args.log_wandb_ckpt
79 |         opt['log_wandb_ckpt'] = log_wandb_ckpt
80 |     except AttributeError:
81 |         pass
82 |     try:
83 |         log_eval = args.log_eval
84 |         opt['log_eval'] = log_eval
85 |     except AttributeError:
86 |         pass
87 |     try:
88 |         log_infer = args.log_infer
89 |         opt['log_infer'] = log_infer
90 |     except AttributeError:
91 |         pass
92 | opt['enable_wandb'] = enable_wandb
93 |
94 | return opt
95 |
96 |
97 | class NoneDict(dict):
98 | def __missing__(self, key):
99 | return None
100 |
101 |
102 | # convert to NoneDict, which return None for missing key.
103 | def dict_to_nonedict(opt):
104 | if isinstance(opt, dict):
105 | new_opt = dict()
106 | for key, sub_opt in opt.items():
107 | new_opt[key] = dict_to_nonedict(sub_opt)
108 | return NoneDict(**new_opt)
109 | elif isinstance(opt, list):
110 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
111 | else:
112 | return opt
113 |
114 |
115 | def dict2str(opt, indent_l=1):
116 | '''dict to string for logger'''
117 | msg = ''
118 | for k, v in opt.items():
119 | if isinstance(v, dict):
120 | msg += ' ' * (indent_l * 2) + k + ':[\n'
121 | msg += dict2str(v, indent_l + 1)
122 | msg += ' ' * (indent_l * 2) + ']\n'
123 | else:
124 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
125 | return msg
126 |
127 |
128 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
129 | '''set up logger'''
130 | l = logging.getLogger(logger_name)
131 | formatter = logging.Formatter(
132 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
133 | log_file = os.path.join(root, '{}.log'.format(phase))
134 | fh = logging.FileHandler(log_file, mode='w')
135 | fh.setFormatter(formatter)
136 | l.setLevel(level)
137 | l.addHandler(fh)
138 | if screen:
139 | sh = logging.StreamHandler()
140 | sh.setFormatter(formatter)
141 | l.addHandler(sh)
142 |
143 | return fh
144 |
--------------------------------------------------------------------------------
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 |
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 | '''
10 | Converts a torch Tensor into an image Numpy array
11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 | '''
14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
15 | tensor = (tensor - min_max[0]) / \
16 | (min_max[1] - min_max[0]) # to range [0,1]
17 | n_dim = tensor.dim()
18 | if n_dim == 4:
19 | n_img = len(tensor)
20 | img_np = make_grid(tensor, nrow=int(
21 | math.sqrt(n_img)), normalize=False).numpy()
22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
23 | elif n_dim == 3:
24 | img_np = tensor.numpy()
25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
26 | elif n_dim == 2:
27 | img_np = tensor.numpy()
28 | else:
29 | raise TypeError(
30 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
31 | if out_type == np.uint8:
32 | img_np = (img_np * 255.0).round()
33 |     # Important. Unlike MATLAB, numpy.uint8() will not round by default.
34 | return img_np.astype(out_type)
35 |
36 |
37 | def save_img(img, img_path, mode='RGB'):
38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
39 | # cv2.imwrite(img_path, img)
40 |
41 |
42 | def calculate_psnr(img1, img2):
43 | # img1 and img2 have range [0, 255]
44 | img1 = img1.astype(np.float64)
45 | img2 = img2.astype(np.float64)
46 | mse = np.mean((img1 - img2) ** 2)
47 | if mse == 0:
48 | return float('inf')
49 | return 20 * math.log10(255.0 / math.sqrt(mse))
50 |
51 |
52 | def ssim(img1, img2):
53 | C1 = (0.01 * 255) ** 2
54 | C2 = (0.03 * 255) ** 2
55 |
56 | img1 = img1.astype(np.float64)
57 | img2 = img2.astype(np.float64)
58 | kernel = cv2.getGaussianKernel(11, 1.5)
59 | window = np.outer(kernel, kernel.transpose())
60 |
61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
63 | mu1_sq = mu1 ** 2
64 | mu2_sq = mu2 ** 2
65 | mu1_mu2 = mu1 * mu2
66 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
67 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
69 |
70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
71 | (sigma1_sq + sigma2_sq + C2))
72 | return ssim_map.mean()
73 |
74 |
75 | def calculate_ssim(img1, img2):
76 | '''calculate SSIM
77 | the same outputs as MATLAB's
78 | img1, img2: [0, 255]
79 | '''
80 | if not img1.shape == img2.shape:
81 | raise ValueError('Input images must have the same dimensions.')
82 | if img1.ndim == 2:
83 | return ssim(img1, img2)
84 | elif img1.ndim == 3:
85 | if img1.shape[2] == 3:
86 |             ssims = []
87 |             for i in range(3):
88 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
89 | return np.array(ssims).mean()
90 | elif img1.shape[2] == 1:
91 | return ssim(np.squeeze(img1), np.squeeze(img2))
92 | else:
93 | raise ValueError('Wrong input image dimensions.')
94 |
--------------------------------------------------------------------------------
/core/wandb_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class WandbLogger:
4 | """
5 | Log using `Weights and Biases`.
6 | """
7 | def __init__(self, opt):
8 | try:
9 | import wandb
10 | except ImportError:
11 | raise ImportError(
12 | "To use the Weights and Biases Logger please install wandb."
13 | "Run `pip install wandb` to install it."
14 | )
15 |
16 | self._wandb = wandb
17 |
18 | # Initialize a W&B run
19 | if self._wandb.run is None:
20 | self._wandb.init(
21 | project=opt['wandb']['project'],
22 | config=opt,
23 | dir='./experiments'
24 | )
25 |
26 | self.config = self._wandb.config
27 |
28 | if self.config.get('log_eval', None):
29 | self.eval_table = self._wandb.Table(columns=['fake_image',
30 | 'sr_image',
31 | 'hr_image',
32 | 'psnr',
33 | 'ssim'])
34 | else:
35 | self.eval_table = None
36 |
37 | if self.config.get('log_infer', None):
38 | self.infer_table = self._wandb.Table(columns=['fake_image',
39 | 'sr_image',
40 | 'hr_image'])
41 | else:
42 | self.infer_table = None
43 |
44 | def log_metrics(self, metrics, commit=True):
45 | """
46 | Log train/validation metrics onto W&B.
47 |
48 | metrics: dictionary of metrics to be logged
49 | """
50 | self._wandb.log(metrics, commit=commit)
51 |
52 | def log_image(self, key_name, image_array):
53 | """
54 | Log image array onto W&B.
55 |
56 | key_name: name of the key
57 | image_array: numpy array of image.
58 | """
59 | self._wandb.log({key_name: self._wandb.Image(image_array)})
60 |
61 | def log_images(self, key_name, list_images):
62 | """
63 | Log list of image array onto W&B
64 |
65 | key_name: name of the key
66 | list_images: list of numpy image arrays
67 | """
68 | self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})
69 |
70 | def log_checkpoint(self, current_epoch, current_step):
71 | """
72 | Log the model checkpoint as W&B artifacts
73 |
74 | current_epoch: the current epoch
75 | current_step: the current batch step
76 | """
77 | model_artifact = self._wandb.Artifact(
78 | self._wandb.run.id + "_model", type="model"
79 | )
80 |
81 | gen_path = os.path.join(
82 | self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch))
83 | opt_path = os.path.join(
84 | self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch))
85 |
86 | model_artifact.add_file(gen_path)
87 | model_artifact.add_file(opt_path)
88 | self._wandb.log_artifact(model_artifact, aliases=["latest"])
89 |
90 | def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
91 | """
92 | Add data row-wise to the initialized table.
93 | """
94 | if psnr is not None and ssim is not None:
95 | self.eval_table.add_data(
96 | self._wandb.Image(fake_img),
97 | self._wandb.Image(sr_img),
98 | self._wandb.Image(hr_img),
99 | psnr,
100 | ssim
101 | )
102 | else:
103 | self.infer_table.add_data(
104 | self._wandb.Image(fake_img),
105 | self._wandb.Image(sr_img),
106 | self._wandb.Image(hr_img)
107 | )
108 |
109 | def log_eval_table(self, commit=False):
110 | """
111 | Log the table
112 | """
113 | if self.eval_table:
114 | self._wandb.log({'eval_data': self.eval_table}, commit=commit)
115 | elif self.infer_table:
116 | self._wandb.log({'infer_data': self.infer_table}, commit=commit)
117 |
--------------------------------------------------------------------------------
/data/FDA.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
4 | import numpy as np
5 | from PIL import Image
6 | # from FDA_utils import FDA_source_to_target_np
7 | import cv2
8 | # from matplotlib import image
9 | import torch
12 |
13 |
14 | def extract_ampl_phase(fft_im):
15 | # fft_im: size should be bx3xhxwx2
16 | fft_amp = fft_im[:, :, :, :, 0] ** 2 + fft_im[:, :, :, :, 1] ** 2
17 | fft_amp = torch.sqrt(fft_amp)
18 | fft_pha = torch.atan2(fft_im[:, :, :, :, 1], fft_im[:, :, :, :, 0])
19 | return fft_amp, fft_pha
20 |
21 |
22 | def low_freq_mutate(amp_src, amp_trg, L=0.1):
23 | _, _, h, w = amp_src.size()
24 | b = (np.floor(np.amin((h, w)) * L)).astype(int) # get b
25 | amp_src[:, :, 0:b, 0:b] = amp_trg[:, :, 0:b, 0:b] # top left
26 | amp_src[:, :, 0:b, w - b:w] = amp_trg[:, :, 0:b, w - b:w] # top right
27 | amp_src[:, :, h - b:h, 0:b] = amp_trg[:, :, h - b:h, 0:b] # bottom left
28 | amp_src[:, :, h - b:h, w - b:w] = amp_trg[:, :, h - b:h, w - b:w] # bottom right
29 | return amp_src
30 |
31 |
32 | def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
33 | a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
34 | a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
35 |
36 | _, h, w = a_src.shape
37 | b = (np.floor(np.amin((h, w)) * L)).astype(int)
38 | c_h = np.floor(h / 2.0).astype(int)
39 | c_w = np.floor(w / 2.0).astype(int)
40 |
41 | h1 = c_h - b
42 | h2 = c_h + b + 1
43 | w1 = c_w - b
44 | w2 = c_w + b + 1
45 |
46 | a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2]
47 | a_src = np.fft.ifftshift(a_src, axes=(-2, -1))
48 | return a_src
49 |
50 |
51 | def FDA_source_to_target(src_img, trg_img, L=0.1):
52 | # exchange magnitude
53 | # input: src_img, trg_img
54 |     # NOTE: torch.rfft/irfft were removed in PyTorch 1.8+; this torch path needs an older PyTorch (HazeAug uses the numpy version below).
55 | # get fft of both source and target
56 | fft_src = torch.rfft(src_img.clone(), signal_ndim=2, onesided=False)
57 | fft_trg = torch.rfft(trg_img.clone(), signal_ndim=2, onesided=False)
58 |
59 | # extract amplitude and phase of both ffts
60 | amp_src, pha_src = extract_ampl_phase(fft_src.clone())
61 | amp_trg, pha_trg = extract_ampl_phase(fft_trg.clone())
62 |
63 | # replace the low frequency amplitude part of source with that from target
64 | amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L)
65 |
66 | # recompose fft of source
67 | fft_src_ = torch.zeros(fft_src.size(), dtype=torch.float)
68 | fft_src_[:, :, :, :, 0] = torch.cos(pha_src.clone()) * amp_src_.clone()
69 | fft_src_[:, :, :, :, 1] = torch.sin(pha_src.clone()) * amp_src_.clone()
70 |
71 | # get the recomposed image: source content, target style
72 | _, _, imgH, imgW = src_img.size()
73 | src_in_trg = torch.irfft(fft_src_, signal_ndim=2, onesided=False, signal_sizes=[imgH, imgW])
74 |
75 | return src_in_trg
76 |
77 |
78 | def FDA_source_to_target_np(src_img, trg_img, L=0.1):
79 | # exchange magnitude
80 | # input: src_img, trg_img
81 |
82 | src_img_np = src_img # .cpu().numpy()
83 | trg_img_np = trg_img # .cpu().numpy()
84 |
85 | # get fft of both source and target
86 | fft_src_np = np.fft.fft2(src_img_np, axes=(-2, -1))
87 | fft_trg_np = np.fft.fft2(trg_img_np, axes=(-2, -1))
88 |
89 | # extract amplitude and phase of both ffts
90 | amp_src, pha_src = np.abs(fft_src_np), np.angle(fft_src_np)
91 | amp_trg, pha_trg = np.abs(fft_trg_np), np.angle(fft_trg_np)
92 |
93 | # mutate the amplitude part of source with target
94 | amp_src_ = low_freq_mutate_np(amp_src, amp_trg, L=L)
95 |
96 | # mutated fft of source
97 | fft_src_ = amp_src_ * np.exp(1j * pha_src)
98 |
99 | # get the mutated image
100 | src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1))
101 | src_in_trg = np.real(src_in_trg)
102 |
103 | return src_in_trg
104 |
105 |
106 | def trans_image_by_ref(in_path, ref_path, value=0.002):
107 | # im_src = Image.open(in_path).convert('RGB')
108 |     im_src = in_path  # despite its name, in_path is a PIL image here (see rt_haze_enhancement)
109 | im_trg = Image.open(ref_path).convert('RGB')
110 | src_h, src_w, src_c = np.shape(im_src)
111 |
112 | im_src = im_src.resize((1024, 512), Image.BICUBIC)
113 | im_trg = im_trg.resize((1024, 512), Image.BICUBIC)
114 |
115 | im_src = np.asarray(im_src, np.float32)
116 | im_trg = np.asarray(im_trg, np.float32)
117 |
118 | im_src = im_src.transpose((2, 0, 1))
119 | im_trg = im_trg.transpose((2, 0, 1))
120 |
121 | src_in_trg = FDA_source_to_target_np(im_src, im_trg, L=value)
122 |
123 | src_in_trg = src_in_trg.transpose((1, 2, 0))
124 |
125 | # recover to src size
126 | src_in_trg = cv2.resize(src_in_trg, (src_w, src_h))
127 |
128 | src_in_trg = (src_in_trg - np.min(src_in_trg)) / (np.max(src_in_trg) - np.min(src_in_trg)) * 255
129 |
130 |     # scipy.misc.toimage(src_in_trg, cmin=0.0, cmax=255.0).save('src_in_tar.png')
131 |     # image.imsave('src_in_tar.png', src_in_trg)  # cmap changes the rendering style, e.g. grayscale 'gray' or 'viridis'
132 |
133 |
134 |     img = Image.fromarray(np.uint8(src_in_trg))  # .convert('RGB')
135 |
136 | return img
137 |
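138 |
139 | if __name__ == "__main__":
140 |     # Minimal usage sketch with assumed file names ("hazy.png", "ref.png"):
141 |     # graft the reference's low-frequency amplitude spectrum onto the source,
142 |     # keeping the source content while adopting the reference's color/haze style.
143 |     src = Image.open("hazy.png").convert("RGB")
144 |     out = trans_image_by_ref(src, "ref.png", value=0.002)
145 |     out.save("hazy_in_ref_style.png")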
--------------------------------------------------------------------------------
/data/HazeAug.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | from PIL import Image, ImageFilter
5 | from data.FDA import trans_image_by_ref
6 |
7 |
8 |
9 | depth_argu = False
10 |
11 |
12 | def depth_change(depth):
13 | depth_strategy = np.random.uniform(0, 1)
14 |
15 | if 0.4 <= depth_strategy < 0.7:
16 | strategy = 'gamma'
17 | elif 0.7 <= depth_strategy < 1.0:
18 | strategy = 'normalize'
19 | else:
20 | strategy = 'identity'
21 |
22 | if strategy == "gamma":
23 | factor = np.random.uniform(0.2, 1.8)
24 |
25 | depth = np.array(depth ** factor)
26 |
27 | elif strategy == "normalize":
28 | # normalize float versions
29 | factor_alpha = np.random.uniform(0, 0.4)
30 | factor_beta = np.random.uniform(0, 2)
31 | depth = cv2.normalize(depth, None, alpha=factor_alpha, beta=factor_beta, norm_type=cv2.NORM_MINMAX,
32 | dtype=cv2.CV_32F)
33 |
34 | return depth
35 |
36 |
37 | class MyGaussianBlur(ImageFilter.Filter):
38 | name = "GaussianBlur"
39 |
40 | def __init__(self, radius=1, bounds=None):
41 | self.radius = radius
42 | self.bounds = bounds
43 |
44 | def filter(self, image):
45 | if self.bounds:
46 | clips = image.crop(self.bounds).gaussian_blur(self.radius)
47 | image.paste(clips, self.bounds)
48 | return image
49 | else:
50 | return image.gaussian_blur(self.radius)
51 |
52 |
53 | def rt_haze_enhancement(pil_img, depth_path, ref_path):
54 |     # sample airlight A in [0.5, 1.8) and scattering coefficient beta in [0.8, 2.8)
55 |     A = np.random.rand() * 1.3 + 0.5
56 |     beta = 2 * np.random.rand() + 0.8
57 | color_strategy = np.random.rand()
58 | if color_strategy <= 0.5:
59 | strategy = 'colour_cast'
60 | # elif 0.3 < color_strategy <= 0.6:
61 | # strategy = 'luminance'
62 | else:
63 | strategy = 'add_hazy'
64 |
65 | img = cv2.imread(pil_img)
66 | depth = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE) / 256.0 # + 1e-7
67 |
68 | if depth_argu == False:
69 | depth = depth_change(depth)
70 |
71 |     img_f = img / 255.0  # normalize to [0, 1]
72 |
73 |     td_bk = np.exp(- np.array(depth) * beta)  # transmission t(x) = exp(-beta * d(x))
74 |     td_bk = np.expand_dims(td_bk, axis=-1).repeat(3, axis=-1)
75 |     img_bk = np.array(img_f) * td_bk + A * (1 - td_bk)  # atmospheric scattering model: I = J*t + A*(1-t)
76 |
77 | img_bk = img_bk / np.max(img_bk) * 255
78 | img_bk = img_bk[:, :, ::-1]
79 |
80 | if strategy == 'colour_cast':
81 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB')
82 | img_bk = trans_image_by_ref(
83 | in_path=img_bk,
84 | ref_path=ref_path,
85 | value=np.random.rand() * 0.002 + 0.0001
86 | )
87 |
88 |     elif strategy == 'luminance':
89 |         img_bk = np.power(img_bk, 0.95)  # gamma-adjust the pixel values
90 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB')
91 |
92 | else:
93 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB')
94 |
95 | img_bk = img_bk.filter(ImageFilter.SMOOTH_MORE)
96 |
97 | return img_bk
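98 |
99 |
100 | if __name__ == "__main__":
101 |     # Minimal usage sketch with hypothetical file paths: synthesize a hazy image
102 |     # from a RESIDE source image, its MonoDepth depth map, and a style reference.
103 |     hazy = rt_haze_enhancement(
104 |         "dataset/RESIDE/HR_hazy_src/0001_1.png",  # hypothetical source image
105 |         "dataset/RESIDE/HR_depth/0001.png",       # hypothetical depth map
106 |         ref_path="dataset/RESIDE/HR_hazy_src/0002_1.png",  # hypothetical reference
107 |     )
108 |     hazy.save("hazeaug_example.png")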
98 |
--------------------------------------------------------------------------------
/data/LRHR_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 | import data.util as Util
8 | from data.HazeAug import rt_haze_enhancement
13 |
14 |
15 | def neibor_16_mul(num, size=32):  # round num to the nearest multiple of size
16 | a = num // size
17 | b = num % size
18 | if b >= 0.5 * size:
19 | return size * (a + 1)
20 | else:
21 | return size * a
22 |
23 |
24 | class LRHRDataset(Dataset):
25 | def __init__(self, dataroot, datatype, l_resolution=16, r_resolution=128, split='train', data_len=-1, need_LR=False,
26 | other_params=None):
27 | self.datatype = datatype
28 | self.l_res = l_resolution
29 | self.r_res = r_resolution
30 | self.data_len = data_len
31 | self.need_LR = need_LR
32 | self.split = split
33 |
34 | self.down_sample = other_params['down_sample'] if "down_sample" in other_params.keys() else None
35 | self.real_hr_path = other_params['hr_path'] if "hr_path" in other_params.keys() else None
36 |
37 | # rt daRESIDE_img_syntheic
38 | self.rt_da = other_params['HazeAug'] if "HazeAug" in other_params.keys() else None
39 | if self.rt_da:
40 | self.rt_da_ref = other_params['rt_da_ref']
41 | self.ref_imgs = []
42 | for dir in self.rt_da_ref:
43 | self.ref_imgs += [os.path.join(dir, i) for i in os.listdir(dir)]
44 | self.depth_path = other_params['depth_img_path']
45 |
46 | if datatype in ["haze_img"]:
47 | self.sr_path = Util.get_paths_from_images("{}/HR_hazy".format(dataroot))
48 |
49 | self.hr_path = Util.get_paths_from_images("{}/HR".format(dataroot))
50 | self.dataset_len = len(self.hr_path)
51 |
52 | self.dis_prefix = other_params['distanse_prefix'] if "distanse_prefix" in other_params.keys() else None
53 | if self.data_len <= 0:
54 | self.data_len = self.dataset_len
55 | else:
56 | self.data_len = min(self.data_len, self.dataset_len)
57 |
58 | elif datatype in ["RESIDE_img_syntheic"]:
59 |
60 | self.sr_path = Util.get_paths_from_images(dataroot)
61 | self.hr_path = self.sr_path
62 | self.dataset_len = len(self.hr_path)
63 | if self.data_len <= 0:
64 | self.data_len = self.dataset_len
65 | else:
66 | self.data_len = min(self.data_len, self.dataset_len)
67 |
68 | else:
69 | raise NotImplementedError(
70 | 'data_type [{:s}] is not recognized.'.format(datatype))
71 |
72 | def __len__(self):
73 | return self.data_len
74 |
75 | def __getitem__(self, index):
76 | img_HR = None
77 | img_LR = None
78 |
79 | if self.datatype in ["RESIDE_img_syntheic"]:
80 |
81 | if self.rt_da:
82 |
83 | img_SR = rt_haze_enhancement(
84 | self.sr_path[index],
85 | os.path.join(self.depth_path, "{}.png".format(self.sr_path[index].split("/")[-1].split("_")[0])),
86 | ref_path=np.random.choice(self.ref_imgs)
87 | )
88 | else:
89 | img_SR = Image.open(self.sr_path[index]).convert("RGB")
90 |
91 | img_SR = img_SR.resize((self.r_res, self.r_res))
92 |
93 | # hr_path
94 | hr_path = "{}/{}.png".format(
95 | self.real_hr_path,
96 | self.sr_path[index].split("/")[-1].split("_")[0]
97 | )
98 | img_HR = Image.open(hr_path).convert("RGB")
99 | img_HR = img_HR.resize((self.r_res, self.r_res))
100 |
101 | if self.need_LR:
102 | img_LR = img_SR
103 |
104 | else:
105 | img_HR = Image.open(self.hr_path[index]).convert("RGB")
106 | img_SR = Image.open(self.sr_path[index]).convert("RGB")
107 | if self.need_LR:
108 | img_LR = Image.open(self.sr_path[index]).convert("RGB")
109 |
110 | if self.down_sample is not None:
111 | img_HR = self.resize(img_HR)
112 | img_SR = self.resize(img_SR)
113 | img_LR = self.resize(img_LR)
114 |
115 |             if self.dis_prefix is not None:
116 |                 img_depth = self.resize(img_depth)  # NOTE: img_depth is never assigned above; dis_prefix must stay unset
117 |
118 | else:
119 | img_HR = self.resize_to_resolution(img_HR)
120 | img_SR = self.resize_to_resolution(img_SR)
121 | img_LR = self.resize_to_resolution(img_LR)
122 |
123 | if self.need_LR:
124 | [img_LR, img_SR, img_HR] = Util.transform_augment(
125 | [img_LR, img_SR, img_HR], split=self.split, min_max=(-1, 1))
126 |
127 | return {'LR': img_LR, 'HR': img_HR, 'SR': img_SR, 'Index': index}
128 | else:
129 | [img_SR, img_HR] = Util.transform_augment(
130 | [img_SR, img_HR], split=self.split, min_max=(-1, 1))
131 |
132 | return {'HR': img_HR, 'SR': img_SR, 'Index': index}
133 |
134 | def resize(self, input_image):
135 | H, W = np.shape(input_image)[:2]
136 | resize_H, resize_W = neibor_16_mul(int(H / self.down_sample)), neibor_16_mul(int(W / self.down_sample))
137 | out_image = input_image.resize((resize_W, resize_H))
138 | return out_image
139 |
140 | def resize_to_resolution(self, input_image):
141 | out_image = input_image.resize((self.r_res, self.r_res))
142 | return out_image
143 |
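144 |
145 | if __name__ == "__main__":
146 |     # Minimal usage sketch mirroring the "val" options in config/framework_da.json
147 |     # (assumes dataset/O-HAZE-PROCESS/{HR, HR_hazy} exist on disk).
148 |     ds = LRHRDataset(
149 |         dataroot="dataset/O-HAZE-PROCESS",
150 |         datatype="haze_img",
151 |         l_resolution=512,
152 |         r_resolution=512,
153 |         split="val",
154 |         data_len=5,
155 |         need_LR=True,
156 |         other_params={},
157 |     )
158 |     sample = ds[0]
159 |     print(sample["HR"].shape, sample["SR"].shape)  # tensors scaled to [-1, 1]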
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 |
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, phase):
8 | '''create dataloader '''
9 | if phase == 'train':
10 | return torch.utils.data.DataLoader(
11 | dataset,
12 | batch_size=dataset_opt['batch_size'],
13 | shuffle=dataset_opt['use_shuffle'],
14 | num_workers=dataset_opt['num_workers'],
15 | pin_memory=True)
16 | elif phase == 'val':
17 | return torch.utils.data.DataLoader(
18 | dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
19 | else:
20 | raise NotImplementedError(
21 | 'Dataloader [{:s}] is not found.'.format(phase))
22 |
23 |
24 | def create_dataset(dataset_opt, phase):
25 | '''create dataset'''
26 | mode = dataset_opt['mode']
27 | from data.LRHR_dataset import LRHRDataset as D
28 | dataset = D(dataroot=dataset_opt['dataroot'],
29 | datatype=dataset_opt['datatype'],
30 | l_resolution=dataset_opt['l_resolution'],
31 | r_resolution=dataset_opt['r_resolution'],
32 | split=phase,
33 | data_len=dataset_opt['data_len'],
34 | need_LR=(mode == 'LRHR'),
35 | other_params=dataset_opt
36 | )
37 | logger = logging.getLogger('base')
38 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
39 | dataset_opt['name']))
40 | return dataset
41 |
--------------------------------------------------------------------------------
/data/__pycache__/FDA.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/FDA.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/HazeAug.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/HazeAug.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/LRHR_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/LRHR_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import random
5 | import numpy as np
6 |
7 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
9 |
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 |
15 | def get_paths_from_images(path):
16 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
17 | images = []
18 | for dirpath, _, fnames in sorted(os.walk(path)):
19 | for fname in sorted(fnames):
20 | if is_image_file(fname):
21 | img_path = os.path.join(dirpath, fname)
22 | images.append(img_path)
23 | assert images, '{:s} has no valid image file'.format(path)
24 | return sorted(images)
25 |
26 |
27 | def augment(img_list, hflip=True, rot=True, split='val'):
28 |     # random horizontal flip, vertical flip, and 90-degree rotation (train split only)
29 | hflip = hflip and (split == 'train' and random.random() < 0.5)
30 | vflip = rot and (split == 'train' and random.random() < 0.5)
31 | rot90 = rot and (split == 'train' and random.random() < 0.5)
32 |
33 | def _augment(img):
34 | if hflip:
35 | img = img[:, ::-1, :]
36 | if vflip:
37 | img = img[::-1, :, :]
38 | if rot90:
39 | img = img.transpose(1, 0, 2)
40 | return img
41 |
42 | return [_augment(img) for img in img_list]
43 |
44 |
45 | def transform2numpy(img):
46 | img = np.array(img)
47 | img = img.astype(np.float32) / 255.
48 | if img.ndim == 2:
49 | img = np.expand_dims(img, axis=2)
50 | # some images have 4 channels
51 | if img.shape[2] > 3:
52 | img = img[:, :, :3]
53 | return img
54 |
55 |
56 | def transform2tensor(img, min_max=(0, 1)):
57 | # HWC to CHW
58 | img = torch.from_numpy(np.ascontiguousarray(
59 | np.transpose(img, (2, 0, 1)))).float()
60 | # to range min_max
61 | img = img * (min_max[1] - min_max[0]) + min_max[0]
62 | return img
63 |
64 |
65 | totensor = torchvision.transforms.ToTensor()
66 | hflip = torchvision.transforms.RandomHorizontalFlip()
67 |
68 |
69 | def transform_augment(img_list, split='val', min_max=(0, 1)):
70 | imgs = [totensor(img) for img in img_list]
71 | if split == 'train':
72 | imgs = torch.stack(imgs, 0)
73 | imgs = hflip(imgs)
74 | imgs = torch.unbind(imgs, dim=0)
75 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs]
76 | return ret_img
77 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import numpy as np
12 | # from brisque import BRISQUE
13 | import cv2
14 | import random
15 |
16 | seed = 6666
17 | print('Random seed: {}'.format(seed))
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 |
23 |
24 | def calc_mean_rgb(img):
25 | H, W, C = np.shape(img)
26 | img = np.reshape(img, (H * W, C))
27 | return np.mean(img, axis=0)
28 |
29 |
30 | def fix_img(img, img_ref):  # shift each RGB channel of img so its mean matches img_ref (simple color correction)
31 | sr_R, sr_G, sr_B = calc_mean_rgb(img)
32 | hr_R, hr_G, hr_B = calc_mean_rgb(img_ref)
33 |
34 | R, G, B = sr_R - hr_R, sr_G - hr_G, sr_B - hr_B
35 | R = np.array(img[:, :, 0]) - R
36 | G = np.array(img[:, :, 1]) - G
37 | B = np.array(img[:, :, 2]) - B
38 |
39 | R = np.expand_dims(R, axis=-1)
40 | G = np.expand_dims(G, axis=-1)
41 | B = np.expand_dims(B, axis=-1)
42 |
43 | return np.array(np.concatenate((R, G, B), axis=-1), dtype=np.uint8)
44 |
45 |
46 | if __name__ == "__main__":
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('-c', '--config', type=str, default='config/framework_da.json',
49 | help='JSON file for configuration')
50 | parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
51 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
52 | parser.add_argument('-debug', '-d', action='store_true')
53 | parser.add_argument('-enable_wandb', action='store_true')
54 | parser.add_argument('-log_infer', action='store_true')
55 |     parser.add_argument('-color_fix', action='store_true')
56 |
57 | # parse configs
58 | args = parser.parse_args()
59 | print(args)
60 | opt = Logger.parse(args)
61 | # Convert to NoneDict, which return None for missing key.
62 | opt = Logger.dict_to_nonedict(opt)
63 |
64 | # logging
65 | torch.backends.cudnn.enabled = True
66 | torch.backends.cudnn.benchmark = True
67 |
68 | Logger.setup_logger(None, opt['path']['log'],
69 | 'train', level=logging.INFO, screen=True)
70 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
71 | logger = logging.getLogger('base')
72 | logger.info(Logger.dict2str(opt))
73 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
74 |
75 | # Initialize WandbLogger
76 | if opt['enable_wandb']:
77 | wandb_logger = WandbLogger(opt)
78 | else:
79 | wandb_logger = None
80 |
81 | # dataset
82 | for phase, dataset_opt in opt['datasets'].items():
83 | if phase == 'val':
84 | val_set = Data.create_dataset(dataset_opt, phase)
85 | val_loader = Data.create_dataloader(
86 | val_set, dataset_opt, phase)
87 | logger.info('Initial Dataset Finished')
88 |
89 | # model
90 | diffusion = Model.create_model(opt)
91 | logger.info('Initial Model Finished')
92 |
93 | diffusion.set_new_noise_schedule(
94 | opt['model']['beta_schedule']['val'], schedule_phase='val')
95 |
96 | logger.info('Begin Model Inference.')
97 | current_step = 0
98 | current_epoch = 0
99 | idx = 0
100 | avg_psnr = 0.0
101 | avg_ssim = 0.0
102 |
103 | result_path = '{}'.format(opt['path']['results'])
104 | os.makedirs(result_path, exist_ok=True)
105 | for _, val_data in enumerate(val_loader):
106 |
107 | idx += 1
108 |
109 | diffusion.feed_data(val_data)
110 | diffusion.test(continous=True)
111 | visuals = diffusion.get_current_visuals(need_LR=False)
112 |
113 | visuals['SR'] = torch.cat([visuals['SR'], visuals['HR']], dim=0)
114 |
115 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8
116 | fake_img = Metrics.tensor2img(visuals['INF']) # uint8
117 |
118 | sr_img_mode = 'grid'
119 | if sr_img_mode == 'single':
120 | # single img series
121 | sr_img = visuals['SR'] # uint8
122 | sample_num = sr_img.shape[0]
123 | for iter in range(0, sample_num):
124 | Metrics.save_img(
125 | Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter))
126 | else:
127 | # grid img
128 | sr_img = Metrics.tensor2img(visuals['SR']) # uint8
129 |
130 | h, w, c = np.shape(hr_img)
131 |
132 | # try:
133 | # sr_img[-h-2:-2, -w-2:-2, :] = hr_img
134 | # except:
135 | # pass
136 |
137 | Metrics.save_img(
138 | sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx))
139 | Metrics.save_img(
140 | Metrics.tensor2img(visuals['SR'][-2]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
141 |
142 | Metrics.save_img(
143 | hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
144 | Metrics.save_img(
145 | fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
146 |
147 | sr_img = Metrics.tensor2img(visuals['SR'][-2])
148 | if args.color_fix:
149 | # print(sr_img)
150 | # print(fake_img)
151 | # print(sr_img.shape)
152 | # print(fake_img.shape)
153 |
154 | sr_img = fix_img(sr_img, fake_img)
155 | # cv2.imwrite('{}/{}_{}_sr.png'.format(result_path, current_step, idx), sr_img)
156 |
157 | psnr = Metrics.calculate_psnr(sr_img, hr_img)
158 | ssim = Metrics.calculate_ssim(sr_img, hr_img)
159 | # brisque = BRISQUE('{}/{}_{}_sr.png'.format(result_path, current_step, idx)).score()
160 | brisque = 0
161 |
162 | avg_psnr += psnr
163 | avg_ssim += ssim
164 | print(f"psnr: {psnr}, ssim:{ssim}, save to {'{}/{}_{}_sr_process.png'.format(result_path, current_step, idx)}")
165 |
166 | if wandb_logger and opt['log_infer']:
167 | wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img)
168 |
169 | avg_psnr = avg_psnr / idx
170 | avg_ssim = avg_ssim / idx
171 |
172 | print(f"avg_psnr: {avg_psnr}, avg_ssim:{avg_ssim}")
173 |
174 | if wandb_logger and opt['log_infer']:
175 | wandb_logger.log_eval_table(commit=True)
176 |
--------------------------------------------------------------------------------
/misc/RTTS.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/misc/RTTS.jpg
--------------------------------------------------------------------------------
/misc/framework-v3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/misc/framework-v3.jpg
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | logger = logging.getLogger('base')
4 |
5 |
6 | def create_model(opt):
7 | from .model import DDPM as M
8 | m = M(opt)
9 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
10 | # m = nn.DataParallel(m)
11 |
12 | return m
13 |
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class BaseModel():
7 | def __init__(self, opt):
8 | self.opt = opt
9 | self.device = torch.device(
10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu')
11 | self.begin_step = 0
12 | self.begin_epoch = 0
13 |
14 | def feed_data(self, data):
15 | pass
16 |
17 | def optimize_parameters(self, current_step):
18 | pass
19 |
20 | def get_current_visuals(self):
21 | pass
22 |
23 | def get_current_losses(self):
24 | pass
25 |
26 | def print_network(self):
27 | pass
28 |
29 | def set_device(self, x):
30 | if isinstance(x, dict):
31 | for key, item in x.items():
32 | if item is not None:
33 | x[key] = item.to(self.device)
34 |         elif isinstance(x, list):
35 |             # rebinding the loop variable would not update the list, so rebuild it
36 |             x = [item.to(self.device) if item is not None else item for item in x]
38 | else:
39 | x = x.to(self.device)
40 | return x
41 |
42 | def get_network_description(self, network):
43 | '''Get the string and total parameters of the network'''
44 | if isinstance(network, nn.DataParallel):
45 | network = network.module
46 | s = str(network)
47 | n = sum(map(lambda x: x.numel(), network.parameters()))
48 | return s, n
49 |
--------------------------------------------------------------------------------
/model/dehaze_with_z_v2_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__init__.py
--------------------------------------------------------------------------------
/model/dehaze_with_z_v2_modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/dehaze_with_z_v2_modules/__pycache__/diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/model/dehaze_with_z_v2_modules/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/dehaze_with_z_v2_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import device, nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 | from torchvision.transforms import Resize
10 |
11 | import copy
12 |
13 |
14 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
15 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
16 | warmup_time = int(n_timestep * warmup_frac)
17 | betas[:warmup_time] = np.linspace(
18 | linear_start, linear_end, warmup_time, dtype=np.float64)
19 | return betas
20 |
21 |
22 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
23 | if schedule == 'quad':
24 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
25 | n_timestep, dtype=np.float64) ** 2
26 | elif schedule == 'linear':
27 | betas = np.linspace(linear_start, linear_end,
28 | n_timestep, dtype=np.float64)
29 | elif schedule == 'warmup10':
30 | betas = _warmup_beta(linear_start, linear_end,
31 | n_timestep, 0.1)
32 | elif schedule == 'warmup50':
33 | betas = _warmup_beta(linear_start, linear_end,
34 | n_timestep, 0.5)
35 | elif schedule == 'const':
36 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
37 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
38 | betas = 1. / np.linspace(n_timestep,
39 | 1, n_timestep, dtype=np.float64)
40 | elif schedule == "cosine":
41 | timesteps = (
42 | torch.arange(n_timestep + 1, dtype=torch.float64) /
43 | n_timestep + cosine_s
44 | )
45 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
46 | alphas = torch.cos(alphas).pow(2)
47 | alphas = alphas / alphas[0]
48 | betas = 1 - alphas[1:] / alphas[:-1]
49 | betas = betas.clamp(max=0.999)
50 | else:
51 | raise NotImplementedError(schedule)
52 | return betas
53 |
54 |
55 | # gaussian diffusion trainer class
56 |
57 | def exists(x):
58 | return x is not None
59 |
60 |
61 | def default(val, d):
62 | if exists(val):
63 | return val
64 | return d() if isfunction(d) else d
65 |
66 |
67 | class GaussianDiffusion(nn.Module):
68 | def __init__(
69 | self,
70 | denoise_fn,
71 | image_size,
72 | channels=3,
73 | loss_type='l1',
74 | conditional=True,
75 | schedule_opt=None,
76 | start_step=1000
77 | ):
78 | super().__init__()
79 | self.channels = channels
80 | self.image_size = image_size
81 | self.denoise_fn = denoise_fn
82 | self.loss_type = loss_type
83 | self.conditional = conditional
84 | if schedule_opt is not None:
85 | pass
86 | # self.set_new_noise_schedule(schedule_opt)
87 |
88 | def set_loss(self, device):
89 | if self.loss_type == 'l1':
90 | self.loss_func = nn.L1Loss(reduction='sum').to(device)
91 | elif self.loss_type == 'l2':
92 | self.loss_func = nn.MSELoss(reduction='sum').to(device)
93 | else:
94 | raise NotImplementedError()
95 | self.optim_loss = nn.MSELoss(reduction='sum').to(device)
96 |
97 | def set_new_noise_schedule(self, schedule_opt, device):
98 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
99 |
100 | betas = make_beta_schedule(
101 | schedule=schedule_opt['schedule'],
102 | n_timestep=schedule_opt['n_timestep'],
103 | linear_start=schedule_opt['linear_start'],
104 | linear_end=schedule_opt['linear_end'])
105 | betas = betas.detach().cpu().numpy() if isinstance(
106 | betas, torch.Tensor) else betas
107 | alphas = 1. - betas
108 | alphas_cumprod = np.cumprod(alphas, axis=0)
109 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
110 | self.sqrt_alphas_cumprod_prev = np.sqrt(
111 | np.append(1., alphas_cumprod))
112 |
113 | timesteps, = betas.shape
114 | self.num_timesteps = int(timesteps)
115 | self.register_buffer('betas', to_torch(betas))
116 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
117 | self.register_buffer('alphas_cumprod_prev',
118 | to_torch(alphas_cumprod_prev))
119 |
120 | # calculations for diffusion q(x_t | x_{t-1}) and others
121 | self.register_buffer('sqrt_alphas_cumprod',
122 | to_torch(np.sqrt(alphas_cumprod)))
123 | self.register_buffer('sqrt_one_minus_alphas_cumprod',
124 | to_torch(np.sqrt(1. - alphas_cumprod)))
125 | self.register_buffer('log_one_minus_alphas_cumprod',
126 | to_torch(np.log(1. - alphas_cumprod)))
127 | self.register_buffer('sqrt_recip_alphas_cumprod',
128 | to_torch(np.sqrt(1. / alphas_cumprod)))
129 | self.register_buffer('sqrt_recipm1_alphas_cumprod',
130 | to_torch(np.sqrt(1. / alphas_cumprod - 1)))
131 |
132 | # calculations for posterior q(x_{t-1} | x_t, x_0)
133 | posterior_variance = betas * \
134 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
135 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
136 | self.register_buffer('posterior_variance',
137 | to_torch(posterior_variance))
138 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
139 | self.register_buffer('posterior_log_variance_clipped', to_torch(
140 | np.log(np.maximum(posterior_variance, 1e-20))))
141 | self.register_buffer('posterior_mean_coef1', to_torch(
142 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
143 | self.register_buffer('posterior_mean_coef2', to_torch(
144 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
145 |
146 | def predict_start_from_noise(self, x_t, t, noise):
147 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \
148 | self.sqrt_recipm1_alphas_cumprod[t] * noise
149 |
150 | def q_posterior(self, x_start, x_t, t):
151 | posterior_mean = self.posterior_mean_coef1[t] * \
152 | x_start + self.posterior_mean_coef2[t] * x_t
153 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
154 | return posterior_mean, posterior_log_variance_clipped
155 |
156 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
157 | batch_size = x.shape[0]
158 | noise_level = torch.FloatTensor(
159 | [self.sqrt_alphas_cumprod_prev[t + 1]]).repeat(batch_size, 1).to(x.device)
160 |
161 | if condition_x is not None:
162 | x_recon = self.predict_start_from_noise(
163 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level))
164 | else:
165 | x_recon = self.predict_start_from_noise(
166 | x, t=t, noise=self.denoise_fn(x, noise_level))
167 |
168 | if clip_denoised:
169 | x_recon.clamp_(-1., 1.)
170 |
171 | model_mean, posterior_log_variance = self.q_posterior(
172 | x_start=x_recon, x_t=x, t=t)
173 | return model_mean, posterior_log_variance
174 |
175 | @torch.no_grad()
176 | def p_sample(self, x, t, clip_denoised=True, condition_x=None):
177 | model_mean, model_log_variance = self.p_mean_variance(
178 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
179 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
180 | return model_mean + noise * (0.5 * model_log_variance).exp()
181 |
182 |     # compute the DDIM cumulative alpha (cumprod of 1 - beta) at step t
183 | def compute_alpha(self, beta, t):
184 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
185 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
186 | return a
187 |
188 | def slerp(self, z1, z2, alpha):
189 | theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
190 | return (
191 | torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
192 | + torch.sin(alpha * theta) / torch.sin(theta) * z2
193 | )
194 |
195 |     def neibor_16_mul(self, num, size=16):  # round num to the nearest multiple of size
196 | a = num // size
197 | b = num % size
198 | if b >= 0.5 * size:
199 | return size * (a + 1)
200 | else:
201 | return size * a
202 |
203 | @torch.no_grad()
204 | def p_sample_loop(self, x_in, continous=False):
205 | device = self.betas.device
206 |
207 |         condition_ddim = True  # hard-coded: sample with conditional DDIM
208 | if condition_ddim:
209 | timesteps = 20
210 | ddim_eta = 1
211 | alpha = 0.5
212 |
213 | sample_inter = (1 | (timesteps // 10))
214 |
215 | x = copy.deepcopy(x_in)
216 | batch_size, C, H, W = x.shape
217 |
218 | ret_img = x_in
219 |
220 | skip = self.num_timesteps // timesteps
221 | seq = range(0, self.num_timesteps, skip)
222 | seq_next = [-1] + list(seq[:-1])
223 |
224 |             # initialize the starting noise as a slerp of two Gaussian samples
225 | shape = x.shape
226 | z1 = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
227 | z2 = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
228 | x = self.slerp(z1, z2, alpha)
229 |
230 | # reshape strategy
231 | reshape = False
232 | reshape_stage = 3
233 | h_gap, w_gap = H // reshape_stage, W // reshape_stage
234 | hs = [self.neibor_16_mul(h) for h in range(h_gap, H, h_gap)] + [H]
235 | ws = [self.neibor_16_mul(w) for w in range(w_gap, W, w_gap)] + [W]
236 |
237 | len_seq = len(seq)
238 | for idx, (i, j) in tqdm(enumerate(zip(reversed(seq), reversed(seq_next))), desc='sampling loop time step',
239 | total=len_seq):
240 | t = (torch.ones(batch_size) * i).to(x.device)
241 | next_t = (torch.ones(batch_size) * j).to(x.device)
242 |
243 | at = self.compute_alpha(self.betas, t.long())
244 | at_next = self.compute_alpha(self.betas, next_t.long())
245 |
246 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[i + 1]]).repeat(batch_size, 1).to(
247 | x.device)
248 |
249 | if reshape:
250 | cur_idx = int(idx / int(len_seq / reshape_stage))
251 | cur_idx = cur_idx if cur_idx < reshape_stage else reshape_stage - 1
252 |
253 | h, w = hs[cur_idx], ws[cur_idx]
254 | im_resize = Resize([h, w])
255 |
256 | x_in_tmp = im_resize(x_in)
257 | x = im_resize(x)
258 |
259 | et = self.denoise_fn(torch.cat([x_in_tmp, x], dim=1), noise_level)
260 | else:
261 |
262 | et = self.denoise_fn(torch.cat([x_in, x], dim=1), noise_level)
263 |
264 | x0_t = (x - et * (1 - at).sqrt()) / at.sqrt()
265 |
266 | c1 = (
267 | ddim_eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
268 | )
269 | c2 = ((1 - at_next) - c1 ** 2).sqrt()
270 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
271 |
272 | x = xt_next
273 |
274 |             if i % sample_inter == 0 or idx == len_seq - 1:
275 |
276 | if x.shape[-1] != W:
277 | im_resize = Resize([H, W])
278 | x_ = im_resize(x)
279 | else:
280 | x_ = x
281 |
282 | ret_img = torch.cat([ret_img, x_], dim=0)
283 |
284 |
285 | else:
286 | sample_inter = (1 | (self.num_timesteps // 10))
287 | if not self.conditional:
288 | shape = x_in.shape
289 | img = torch.randn(shape, device=device)
290 | ret_img = img
291 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step',
292 | total=self.num_timesteps):
293 | img = self.p_sample(img, i)
294 | if i % sample_inter == 0:
295 | ret_img = torch.cat([ret_img, img], dim=0)
296 | else:
297 |
298 |             # inversion via dcp-based defogging (external data_analyse module, not part of this repo)
299 | from data_analyse.dcp import Defog
300 |
301 | x_in_numpy = x_in[0].permute(1, 2, 0).cpu().numpy()
302 | x_in_numpy = (x_in_numpy - np.min(x_in_numpy)) / (np.max(x_in_numpy) - np.min(x_in_numpy))
303 |
304 | Mask_img, A = Defog(x_in_numpy, r=81, eps=0.001, w=0.95, maxV1=0.80)
305 | Mask_img = torch.from_numpy(Mask_img).unsqueeze(dim=0).unsqueeze(dim=1).expand_as(x_in).to(x_in.device)
306 |
307 | mean_Mask_img = torch.mean(Mask_img)
308 | Mask_img = Mask_img - mean_Mask_img
309 | print(torch.max(Mask_img), torch.min(Mask_img))
310 |
311 | ret_img = x_in
312 | x = torch.cat([ret_img], dim=1)
313 |
314 | shape = x.shape
315 | img = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
316 |
317 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step',
318 | total=self.num_timesteps):
319 |
320 | img = self.p_sample(img, i, condition_x=x)
321 |
322 | if i % sample_inter == 0:
323 | ret_img = torch.cat([ret_img, img], dim=0)
324 |
325 | if continous:
326 | return ret_img
327 | else:
328 | return ret_img[-1]
329 |
330 | @torch.no_grad()
331 | def sample(self, batch_size=1, continous=False):
332 | image_size = self.image_size
333 | channels = self.channels
334 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
335 |
336 | @torch.no_grad()
337 | def super_resolution(self, x_in, continous=False):
338 | return self.p_sample_loop(x_in, continous)
339 |
340 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
341 | noise = default(noise, lambda: torch.randn_like(x_start))
342 |
343 |         # random gamma
344 | return (
345 | continuous_sqrt_alpha_cumprod * x_start +
346 | (1 - continuous_sqrt_alpha_cumprod ** 2).sqrt() * noise
347 | )
348 |
349 | def p_losses(self, x_in, noise=None):
350 | x_start = x_in['HR']
351 |
352 | x_sr = x_in['SR']
353 |
354 | [b, c, h, w] = x_start.shape
355 | t = np.random.randint(1, self.num_timesteps + 1)
356 | continuous_sqrt_alpha_cumprod = torch.FloatTensor(
357 | np.random.uniform(
358 | self.sqrt_alphas_cumprod_prev[t - 1],
359 | self.sqrt_alphas_cumprod_prev[t],
360 | size=b
361 | )
362 | ).to(x_start.device)
363 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(
364 | b, -1)
365 |
366 | noise = default(noise, lambda: torch.randn_like(x_start))
367 | x_noisy = self.q_sample(
368 | x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
369 |
370 | if not self.conditional:
371 | x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
372 | loss = self.loss_func(noise, x_recon)
373 |
374 | else:
375 | x_recon = self.denoise_fn(
376 | torch.cat([x_sr, x_noisy], dim=1), continuous_sqrt_alpha_cumprod)
377 | loss = self.loss_func(noise, x_recon)
378 |
379 | return loss
380 |
381 | def calc_RGB(self, tensor):
382 | b, c = tensor.shape[:2]
383 | RGB_mean = torch.mean(tensor.view(b, c, -1), -1)
384 | return RGB_mean
385 |
386 | def forward(self, x, *args, **kwargs):
387 | return self.p_losses(x, *args, **kwargs)
388 |
--------------------------------------------------------------------------------
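
For reference, the conditional branch of `p_sample_loop` above performs a standard DDIM update: predict the noise `et`, estimate the clean image `x0_t`, then step to `xt_next`. The sketch below restates that single step in isolation; the tensor shapes and alpha values are illustrative assumptions, not values taken from the repo.

```python
# Hedged sketch of one DDIM reverse step as implemented in p_sample_loop:
#   x0_t    = (x - sqrt(1 - at) * et) / sqrt(at)
#   c1      = eta * sqrt((1 - at / at_next) * (1 - at_next) / (1 - at))
#   c2      = sqrt((1 - at_next) - c1 ** 2)
#   xt_next = sqrt(at_next) * x0_t + c1 * z + c2 * et
import torch

def ddim_step(x, et, at, at_next, eta=1.0):
    x0_t = (x - et * (1 - at).sqrt()) / at.sqrt()  # predicted clean image
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
    c2 = ((1 - at_next) - c1 ** 2).sqrt()
    return at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et

x = torch.randn(1, 3, 16, 16)   # dummy sample
et = torch.randn_like(x)        # dummy predicted noise
xt_next = ddim_step(x, et, at=torch.tensor(0.5), at_next=torch.tensor(0.6))
print(xt_next.shape)  # torch.Size([1, 3, 16, 16])
```

With `ddim_eta = 1` (the value hard-coded above) each step injects fresh noise; `eta = 0` would give the deterministic DDIM sampler.

--------------------------------------------------------------------------------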
/model/dehaze_with_z_v2_modules/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from kornia.filters import gaussian_blur2d
7 |
8 |
9 | def exists(x):
10 | return x is not None
11 |
12 |
13 | def default(val, d):
14 | if exists(val):
15 | return val
16 | return d() if isfunction(d) else d
17 |
18 |
19 | class PositionalEncoding(nn.Module):
20 | def __init__(self, dim):
21 | super().__init__()
22 | self.dim = dim
23 |
24 | def forward(self, noise_level):
25 | count = self.dim // 2
26 | step = torch.arange(count, dtype=noise_level.dtype,
27 | device=noise_level.device) / count
28 | encoding = noise_level.unsqueeze(
29 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
30 | encoding = torch.cat(
31 | [torch.sin(encoding), torch.cos(encoding)], dim=-1)
32 | return encoding
33 |
34 |
35 | class FeatureWiseAffine(nn.Module):
36 |
37 | def __init__(self, in_channels, out_channels, use_affine_level=False):
38 | super(FeatureWiseAffine, self).__init__()
39 | self.use_affine_level = use_affine_level
40 | self.noise_func = nn.Sequential(
41 | nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))
42 | )
43 |
44 | def forward(self, x, noise_embed):
45 | batch = x.shape[0]
46 | if self.use_affine_level:
47 | gamma, beta = self.noise_func(noise_embed).view(
48 | batch, -1, 1, 1).chunk(2, dim=1)
49 | x = (1 + gamma) * x + beta
50 | else:
51 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
52 | return x
53 |
54 |
55 | class Swish(nn.Module):
56 | def forward(self, x):
57 | return x * torch.sigmoid(x)
58 |
59 |
60 | class Upsample(nn.Module):
61 | def __init__(self, dim):
62 | super().__init__()
63 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
64 | self.conv = nn.Conv2d(dim, dim, 3, padding=1)
65 |
66 | def forward(self, x):
67 | return self.conv(self.up(x))
68 |
69 |
70 | class Downsample(nn.Module):
71 | def __init__(self, dim):
72 | super().__init__()
73 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
74 |
75 | def forward(self, x):
76 | return self.conv(x)
77 |
78 |
79 | class Block(nn.Module):
80 | def __init__(self, dim, dim_out, groups=32, dropout=0):
81 | super().__init__()
82 | self.block = nn.Sequential(
83 | nn.GroupNorm(groups, dim),
84 | Swish(),
85 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
86 | nn.Conv2d(dim, dim_out, 3, padding=1)
87 | )
88 |
89 | def forward(self, x):
90 | return self.block(x)
91 |
92 |
93 | class ResnetBlock(nn.Module):
94 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
95 | super().__init__()
96 | self.noise_func = FeatureWiseAffine(
97 | noise_level_emb_dim, dim_out, use_affine_level)
98 |
99 | self.block1 = Block(dim, dim_out, groups=norm_groups)
100 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
101 | self.res_conv = nn.Conv2d(
102 | dim, dim_out, 1) if dim != dim_out else nn.Identity()
103 |
104 | def forward(self, x, time_emb):
105 |         # `h` is the running feature map
106 |         h = self.block1(x)
107 | h = self.noise_func(h, time_emb)
108 | h = self.block2(h)
109 | return h + self.res_conv(x)
110 |
111 |
112 | class SelfAttention(nn.Module):
113 | def __init__(self, in_channel, n_head=1, norm_groups=32):
114 | super().__init__()
115 |
116 | self.n_head = n_head
117 |
118 | self.norm = nn.GroupNorm(norm_groups, in_channel)
119 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
120 | self.out = nn.Conv2d(in_channel, in_channel, 1)
121 |
122 | def forward(self, input):
123 | batch, channel, height, width = input.shape
124 | n_head = self.n_head
125 | head_dim = channel // n_head
126 |
127 | norm = self.norm(input)
128 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
129 | query, key, value = qkv.chunk(3, dim=2) # bhdyx
130 |
131 | attn = torch.einsum(
132 | "bnchw, bncyx -> bnhwyx", query, key
133 | ).contiguous() / math.sqrt(channel)
134 | attn = attn.view(batch, n_head, height, width, -1)
135 | attn = torch.softmax(attn, -1)
136 | attn = attn.view(batch, n_head, height, width, height, width)
137 |
138 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
139 | out = self.out(out.view(batch, channel, height, width))
140 |
141 | return out + input
142 |
143 |
144 | class ResnetBlocWithAttn(nn.Module):
145 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
146 | super().__init__()
147 | self.with_attn = with_attn
148 | self.res_block = ResnetBlock(
149 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
150 | if with_attn:
151 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
152 |
153 | def forward(self, x, time_emb):
154 | x = self.res_block(x, time_emb)
155 | if (self.with_attn):
156 | x = self.attn(x)
157 | return x
158 |
159 |
160 | class FCB(nn.Module):
161 | def __init__(self, channel, kernel_size=3):
162 | super().__init__()
163 | self.ks = kernel_size
164 | self.sigma_rate = 1
165 |
166 | params = torch.ones((4, 1), requires_grad=True)
167 | self.params = nn.Parameter(params)
168 |
169 | def forward(self, x):
170 |         # difference-of-Gaussian pyramid: band-pass residuals at sigma 1, 2 and 4
171 | x1 = gaussian_blur2d(x, (self.ks, self.ks), (1 * self.sigma_rate, 1 * self.sigma_rate))
172 | R1 = x - x1
173 |
174 | x2 = gaussian_blur2d(x, (self.ks * 2 - 1, self.ks * 2 - 1), (2 * self.sigma_rate, 2 * self.sigma_rate))
175 | x3 = gaussian_blur2d(x, (self.ks * 4 - 1, self.ks * 4 - 1), (4 * self.sigma_rate, 4 * self.sigma_rate))
176 | R2 = x1 - x2
177 | R3 = x2 - x3
178 |
179 | R1 = R1.unsqueeze(dim=-1)
180 | R2 = R2.unsqueeze(dim=-1)
181 | R3 = R3.unsqueeze(dim=-1)
182 | R_cat = torch.cat([R1, R2, R3, x.unsqueeze(dim=-1)], dim=-1)
183 |
184 | sum_ = torch.matmul(R_cat, self.params).squeeze(dim=-1)
185 |
186 | return sum_
187 |
188 |
189 | class UNet(nn.Module):
190 | def __init__(
191 | self,
192 | in_channel=6,
193 | out_channel=3,
194 | inner_channel=32,
195 | norm_groups=32,
196 | channel_mults=(1, 2, 4, 8, 8),
197 | attn_res=[8],
198 | res_blocks=3,
199 | dropout=0,
200 | with_noise_level_emb=True,
201 | image_size=128,
202 | fcb=True
203 | ):
204 | super().__init__()
205 |
206 | self.fcb = fcb
207 |
208 | if with_noise_level_emb:
209 | noise_level_channel = inner_channel
210 | self.noise_level_mlp = nn.Sequential(
211 | PositionalEncoding(inner_channel),
212 | nn.Linear(inner_channel, inner_channel * 4),
213 | Swish(),
214 | nn.Linear(inner_channel * 4, inner_channel)
215 | )
216 | else:
217 | noise_level_channel = None
218 | self.noise_level_mlp = None
219 |
220 | num_mults = len(channel_mults)
221 | pre_channel = inner_channel
222 | feat_channels = [pre_channel]
223 | now_res = image_size
224 | downs = [nn.Conv2d(in_channel, inner_channel,
225 | kernel_size=3, padding=1)]
226 | for ind in range(num_mults):
227 | is_last = (ind == num_mults - 1)
228 | use_attn = (now_res in attn_res)
229 | channel_mult = inner_channel * channel_mults[ind]
230 | for _ in range(0, res_blocks):
231 | downs.append(ResnetBlocWithAttn(
232 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
233 | dropout=dropout, with_attn=use_attn))
234 | feat_channels.append(channel_mult)
235 | pre_channel = channel_mult
236 | if not is_last:
237 | downs.append(Downsample(pre_channel))
238 | feat_channels.append(pre_channel)
239 | now_res = now_res // 2
240 | self.downs = nn.ModuleList(downs)
241 |
242 | self.mid = nn.ModuleList([
243 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
244 | norm_groups=norm_groups,
245 | dropout=dropout, with_attn=True),
246 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
247 | norm_groups=norm_groups,
248 | dropout=dropout, with_attn=False)
249 | ])
250 |
251 | ups = []
252 | fbs = []
253 | for ind in reversed(range(num_mults)):
254 | is_last = (ind < 1)
255 | use_attn = (now_res in attn_res)
256 | channel_mult = inner_channel * channel_mults[ind]
257 | for _ in range(0, res_blocks + 1):
258 | ups.append(ResnetBlocWithAttn(
259 | pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel,
260 | norm_groups=norm_groups,
261 | dropout=dropout, with_attn=use_attn))
262 | pre_channel = channel_mult
263 | tmp = FCB(pre_channel) if self.fcb else pre_channel
264 | fbs.append(tmp)
265 | if not is_last:
266 | ups.append(Upsample(pre_channel))
267 | tmp = FCB(pre_channel) if self.fcb else pre_channel
268 | fbs.append(tmp)
269 | now_res = now_res * 2
270 |
271 | self.ups = nn.ModuleList(ups)
272 | self.fbs = nn.ModuleList(fbs)
273 |
274 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
275 |
276 | def forward(self, x, time):
277 | t = self.noise_level_mlp(time) if exists(
278 | self.noise_level_mlp) else None
279 |
280 | feats = []
281 | for layer in self.downs:
282 | if isinstance(layer, ResnetBlocWithAttn):
283 | x = layer(x, t)
284 | else:
285 | x = layer(x)
286 | feats.append(x)
287 |
288 | for layer in self.mid:
289 | if isinstance(layer, ResnetBlocWithAttn):
290 | x = layer(x, t)
291 | else:
292 | x = layer(x)
293 |
294 | for layer, fb in zip(self.ups, self.fbs):
295 | if isinstance(layer, ResnetBlocWithAttn):
296 | tmp = feats.pop()
297 | if self.fcb:
298 | tmp = fb(tmp)
299 | x = layer(torch.cat((x, tmp), dim=1), t)
300 | else:
301 | x = layer(x)
302 |
303 | tmp = self.final_conv(x)
304 |
305 | return tmp
306 |
--------------------------------------------------------------------------------
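
`FCB` is the frequency-compensation block: it builds three difference-of-Gaussian band-pass residuals (`R1`-`R3`) plus the identity path, and mixes the four with a learned 4x1 weight vector, so the output keeps the input's shape. A minimal shape check, assuming the repo root is on `PYTHONPATH` and `kornia` is installed (the sizes below are illustrative assumptions):

```python
# Hedged sanity check for FCB; input sizes are illustrative assumptions.
import torch
from model.dehaze_with_z_v2_modules.unet import FCB

fcb = FCB(channel=64, kernel_size=3)   # Gaussian blurs at sigma 1, 2 and 4
x = torch.randn(2, 64, 32, 32)
out = fcb(x)
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input
```

Inside `UNet`, one `FCB` per decoder stage (`self.fbs`) filters the popped encoder features before they are concatenated onto the skip connection.

--------------------------------------------------------------------------------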
/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | import model.networks as networks
8 | from .base_model import BaseModel
9 |
10 | logger = logging.getLogger('base')
11 |
12 |
13 | class EMA():
14 | def __init__(self, model, decay):
15 | self.model = model
16 | self.decay = decay
17 | self.shadow = {}
18 | self.backup = {}
19 |
20 | def register(self):
21 | for name, param in self.model.named_parameters():
22 | if param.requires_grad:
23 | self.shadow[name] = param.data.clone()
24 |
25 | def update(self):
26 | for name, param in self.model.named_parameters():
27 | if param.requires_grad:
28 | assert name in self.shadow
29 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
30 | self.shadow[name] = new_average.clone()
31 |
32 | def apply_shadow(self):
33 | for name, param in self.model.named_parameters():
34 | if param.requires_grad:
35 | assert name in self.shadow
36 | self.backup[name] = param.data
37 | param.data = self.shadow[name]
38 |
39 | def restore(self):
40 | for name, param in self.model.named_parameters():
41 | if param.requires_grad:
42 | assert name in self.backup
43 | param.data = self.backup[name]
44 | self.backup = {}
45 |
46 |
47 | class DDPM(BaseModel):
48 | def __init__(self, opt):
49 | super(DDPM, self).__init__(opt)
50 | # define network and load pretrained models
51 | self.netG = self.set_device(networks.define_G(opt))
52 | self.schedule_phase = None
53 |
54 | # ema
55 | self.use_ema = opt['train']['ema_scheduler']['used'] if "used" in opt['train'][
56 | 'ema_scheduler'].keys() else False
57 | if self.use_ema:
58 | self.decay = opt['train']['ema_scheduler']['ema_decay']
59 | self.ema_start = opt['train']['ema_scheduler']['step_start_ema']
60 | self.shadow = {}
61 | self.backup = {}
62 | self.register()
63 | print("using ema to training ...")
64 |
65 | # set loss and load resume state
66 | self.set_loss()
67 | self.set_new_noise_schedule(
68 | opt['model']['beta_schedule']['train'], schedule_phase='train')
69 | if self.opt['phase'] == 'train':
70 | self.netG.train()
71 | # find the parameters to optimize
72 | if opt['model']['finetune_norm']:
73 | optim_params = []
74 | for k, v in self.netG.named_parameters():
75 | v.requires_grad = False
76 | if k.find('transformer') >= 0:
77 | v.requires_grad = True
78 | v.data.zero_()
79 | optim_params.append(v)
80 | logger.info(
81 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k))
82 | else:
83 | optim_params = list(self.netG.parameters())
84 |
85 | self.optG = torch.optim.Adam(
86 | optim_params, lr=opt['train']["optimizer"]["lr"])
87 | self.log_dict = OrderedDict()
88 | self.load_network()
89 | self.print_network()
90 |
91 | def feed_data(self, data):
92 | self.data = self.set_device(data)
93 |
94 | def optimize_parameters(self, current_step):
95 | self.optG.zero_grad()
96 |
97 | l_pix = self.netG(self.data)
98 | # need to average in multi-gpu
99 | b, c, h, w = self.data['HR'].shape
100 | l_pix = l_pix.sum() / int(b * c * h * w)
101 | l_pix.backward(retain_graph=True)
102 | self.optG.step()
103 |
104 | if self.use_ema and current_step > self.ema_start:
105 | self.update()
106 |
107 | # set log
108 | self.log_dict['l_pix'] = l_pix.item()
109 |
110 | def test(self, continous=False):
111 | if self.use_ema:
112 | print("use ema to test...")
113 | self.apply_shadow()
114 |
115 | self.netG.eval()
116 | with torch.no_grad():
117 | if isinstance(self.netG, nn.DataParallel):
118 | self.SR = self.netG.module.super_resolution(self.data['SR'], continous)
119 |
120 | else:
121 | self.SR = self.netG.super_resolution(self.data['SR'], continous)
122 |
123 | if self.use_ema:
124 | self.restore()
125 |
126 | self.netG.train()
127 |
128 | def sample(self, batch_size=1, continous=False):
129 | self.netG.eval()
130 | with torch.no_grad():
131 | if isinstance(self.netG, nn.DataParallel):
132 | self.SR = self.netG.module.sample(batch_size, continous)
133 | else:
134 | self.SR = self.netG.sample(batch_size, continous)
135 | self.netG.train()
136 |
137 | def set_loss(self):
138 | if isinstance(self.netG, nn.DataParallel):
139 | self.netG.module.set_loss(self.device)
140 | else:
141 | self.netG.set_loss(self.device)
142 |
143 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
144 | if self.schedule_phase is None or self.schedule_phase != schedule_phase:
145 | self.schedule_phase = schedule_phase
146 | if isinstance(self.netG, nn.DataParallel):
147 | self.netG.module.set_new_noise_schedule(
148 | schedule_opt, self.device)
149 | else:
150 | self.netG.set_new_noise_schedule(schedule_opt, self.device)
151 |
152 | def get_current_log(self):
153 | return self.log_dict
154 |
155 | def get_current_visuals(self, need_LR=True, sample=False):
156 | out_dict = OrderedDict()
157 | if sample:
158 | out_dict['SAM'] = self.SR.detach().float().cpu()
159 | else:
160 | out_dict['SR'] = self.SR.detach().float().cpu()
161 | out_dict['INF'] = self.data['SR'].detach().float().cpu()
162 | out_dict['HR'] = self.data['HR'].detach().float().cpu()
163 | if need_LR and 'LR' in self.data:
164 | out_dict['LR'] = self.data['LR'].detach().float().cpu()
165 | else:
166 | out_dict['LR'] = out_dict['INF']
167 | return out_dict
168 |
169 | def print_network(self):
170 | s, n = self.get_network_description(self.netG)
171 | if isinstance(self.netG, nn.DataParallel):
172 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
173 | self.netG.module.__class__.__name__)
174 | else:
175 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
176 |
177 | logger.info(
178 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
179 | # logger.info(s)
180 |
181 | def save_network(self, epoch, iter_step):
182 | gen_path = os.path.join(
183 | self.opt['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch))
184 | opt_path = os.path.join(
185 | self.opt['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch))
186 | # gen
187 | network = self.netG
188 | if isinstance(self.netG, nn.DataParallel):
189 | network = network.module
190 | state_dict = network.state_dict()
191 | for key, param in state_dict.items():
192 | state_dict[key] = param.cpu()
193 | torch.save(state_dict, gen_path)
194 | # opt
195 | opt_state = {'epoch': epoch, 'iter': iter_step,
196 | 'scheduler': None, 'optimizer': None}
197 | opt_state['optimizer'] = self.optG.state_dict()
198 | torch.save(opt_state, opt_path)
199 |
200 | logger.info(
201 | 'Saved model in [{:s}] ...'.format(gen_path))
202 |
203 | def load_network(self):
204 | load_path = self.opt['path']['resume_state']
205 | if load_path is not None:
206 | logger.info(
207 | 'Loading pretrained model for G [{:s}] ...'.format(load_path))
208 | gen_path = '{}_gen.pth'.format(load_path)
209 | opt_path = '{}_opt.pth'.format(load_path)
210 | # gen
211 | network = self.netG
212 | if isinstance(self.netG, nn.DataParallel):
213 | network = network.module
214 | network.load_state_dict(torch.load(
215 | gen_path), strict=False)
216 | if self.opt['phase'] == 'train':
217 | try:
218 | # optimizer
219 | opt = torch.load(opt_path)
220 | # self.optG.load_state_dict(opt['optimizer'])
221 | self.begin_step = opt['iter']
222 | self.begin_epoch = opt['epoch']
223 |             except:  # tolerate a missing or stale optimizer checkpoint
224 | pass
225 |
226 | def register(self):
227 | for name, param in self.netG.named_parameters():
228 | if param.requires_grad:
229 | self.shadow[name] = param.data.clone()
230 |
231 | def update(self):
232 | for name, param in self.netG.named_parameters():
233 | if param.requires_grad:
234 | assert name in self.shadow
235 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
236 | self.shadow[name] = new_average.clone()
237 |
238 | def apply_shadow(self):
239 | for name, param in self.netG.named_parameters():
240 | if param.requires_grad:
241 | assert name in self.shadow
242 | self.backup[name] = param.data
243 | param.data = self.shadow[name]
244 |
245 | def restore(self):
246 | for name, param in self.netG.named_parameters():
247 | if param.requires_grad:
248 | assert name in self.backup
249 | param.data = self.backup[name]
250 | self.backup = {}
251 |
--------------------------------------------------------------------------------
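
`model.py` carries two copies of the same EMA logic: the standalone `EMA` helper and the `register`/`update`/`apply_shadow`/`restore` methods on `DDPM`, driven by the `ema_scheduler` options. A minimal usage sketch of the helper (assumes the repo root is importable; `net` is a stand-in module, not part of the repo):

```python
# Hedged usage sketch for the EMA helper; not a training recipe.
import torch.nn as nn
from model.model import EMA

net = nn.Linear(4, 4)          # stand-in for netG
ema = EMA(net, decay=0.999)
ema.register()                 # snapshot current weights as the shadow copy
# ... optimizer steps on `net` ...
ema.update()                   # shadow <- decay * shadow + (1 - decay) * param
ema.apply_shadow()             # back up raw weights, load the EMA weights
# ... evaluate with the smoothed weights ...
ema.restore()                  # put the raw training weights back
```

In `DDPM`, `optimize_parameters` calls `update()` once `current_step > step_start_ema`, and `test()` brackets inference with `apply_shadow()`/`restore()`.

--------------------------------------------------------------------------------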
/model/networks.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn import modules
7 |
8 | logger = logging.getLogger('base')
9 |
10 |
11 | ####################
12 | # initialize
13 | ####################
14 |
15 |
16 | def weights_init_normal(m, std=0.02):
17 | classname = m.__class__.__name__
18 | if classname.find('Conv') != -1:
19 | init.normal_(m.weight.data, 0.0, std)
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif classname.find('Linear') != -1:
23 | init.normal_(m.weight.data, 0.0, std)
24 | if m.bias is not None:
25 | m.bias.data.zero_()
26 | elif classname.find('BatchNorm2d') != -1:
27 |         init.normal_(m.weight.data, 1.0, std)  # BatchNorm weights initialized around 1.0
28 | init.constant_(m.bias.data, 0.0)
29 |
30 |
31 | def weights_init_kaiming(m, scale=1):
32 | classname = m.__class__.__name__
33 | if classname.find('Conv2d') != -1:
34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
35 | m.weight.data *= scale
36 | if m.bias is not None:
37 | m.bias.data.zero_()
38 | elif classname.find('Linear') != -1:
39 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
40 | m.weight.data *= scale
41 | if m.bias is not None:
42 | m.bias.data.zero_()
43 | elif classname.find('BatchNorm2d') != -1:
44 | init.constant_(m.weight.data, 1.0)
45 | init.constant_(m.bias.data, 0.0)
46 |
47 |
48 | def weights_init_orthogonal(m):
49 | classname = m.__class__.__name__
50 | if classname.find('Conv') != -1:
51 | try:
52 | init.orthogonal_(m.weight.data, gain=1)
53 | except:
54 | pass
55 | try:
56 | if m.bias is not None:
57 | m.bias.data.zero_()
58 | except:
59 | pass
60 | elif classname.find('Linear') != -1:
61 | init.orthogonal_(m.weight.data, gain=1)
62 | try:
63 | if m.bias is not None:
64 | m.bias.data.zero_()
65 | except:
66 | pass
67 | elif classname.find('BatchNorm2d') != -1:
68 | init.constant_(m.weight.data, 1.0)
69 | init.constant_(m.bias.data, 0.0)
70 |
71 |
72 | def init_weights(net, init_type='kaiming', scale=1, std=0.02):
73 | # scale for 'kaiming', std for 'normal'.
74 | logger.info('Initialization method [{:s}]'.format(init_type))
75 | if init_type == 'normal':
76 | weights_init_normal_ = functools.partial(weights_init_normal, std=std)
77 | net.apply(weights_init_normal_)
78 | elif init_type == 'kaiming':
79 | weights_init_kaiming_ = functools.partial(
80 | weights_init_kaiming, scale=scale)
81 | net.apply(weights_init_kaiming_)
82 | elif init_type == 'orthogonal':
83 | net.apply(weights_init_orthogonal)
84 | else:
85 | raise NotImplementedError(
86 | 'initialization method [{:s}] not implemented'.format(init_type))
87 |
88 |
89 | ####################
90 | # define network
91 | ####################
92 |
93 |
94 | # Generator
95 | def define_G(opt):
96 | model_opt = opt['model']
97 | if model_opt['which_model_G'] == 'ddpm':
98 | from .ddpm_modules import diffusion, unet
99 | elif model_opt['which_model_G'] == 'sr3':
100 | from .sr3_modules import diffusion, unet
101 | elif model_opt['which_model_G'] == 'asm':
102 | from .asm_modules import diffusion, unet
103 | elif model_opt['which_model_G'] == 'MSBDN':
104 | from .MSBDN import diffusion, unet
105 | elif model_opt['which_model_G'] == 'dehazy':
106 | from .dehazy_modules import diffusion
107 | from .dehazy_modules import vspga as unet
108 | elif model_opt['which_model_G'] == 'dehaze_with_z':
109 | from .dehaze_with_z_modules import diffusion, unet
110 | elif model_opt['which_model_G'] == 'dehaze_with_z_gan':
111 | from .dehaze_with_z_gan_modules import diffusion, unet
112 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1':
113 | from .dehaze_with_z_v1_modules import diffusion, unet
114 | elif model_opt['which_model_G'] == 'dehaze_with_z_bagging':
115 | from .dehaze_with_z_bagging_modules import diffusion, unet
116 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1_ssim':
117 | from .dehaze_with_z_v1_ssim_modules import diffusion, unet
118 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1_depth_lap_ssim':
119 | from .dehaze_with_z_v1_depth_lap_ssim_modules import diffusion, unet
120 | elif model_opt['which_model_G'] == 'dehaze_with_z_v2':
121 | from .dehaze_with_z_v2_modules import diffusion, unet
122 | elif model_opt['which_model_G'] == 'dehaze_with_z_v4_CA':
123 | from .dehaze_with_z_v4_CA_modules import diffusion, unet
124 | elif model_opt['which_model_G'] == 'dehaze_filter_hsv':
125 | from .dehaze_filter_hsv_modules import diffusion, unet
126 |     else: raise NotImplementedError(model_opt['which_model_G'])
127 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None:
128 | model_opt['unet']['norm_groups'] = 32
129 | model = unet.UNet(
130 | in_channel=model_opt['unet']['in_channel'],
131 | out_channel=model_opt['unet']['out_channel'],
132 | norm_groups=model_opt['unet']['norm_groups'],
133 | inner_channel=model_opt['unet']['inner_channel'],
134 | channel_mults=model_opt['unet']['channel_multiplier'],
135 | attn_res=model_opt['unet']['attn_res'],
136 | res_blocks=model_opt['unet']['res_blocks'],
137 | dropout=model_opt['unet']['dropout'],
138 | image_size=model_opt['diffusion']['image_size'],
139 |         fcb=model_opt['FCB']
140 |     )
141 | 
142 | netG = diffusion.GaussianDiffusion(
143 | model,
144 | image_size=model_opt['diffusion']['image_size'],
145 | channels=model_opt['diffusion']['channels'],
146 | loss_type='l1', # L1 or L2
147 | conditional=model_opt['diffusion']['conditional'],
148 | schedule_opt=model_opt['beta_schedule']['train'],
149 | start_step=model_opt['diffusion']['start_step'] if 'start_step' in model_opt['diffusion'].keys() else 1000
150 | )
151 | if opt['phase'] == 'train':
152 | # init_weights(netG, init_type='kaiming', scale=0.1)
153 | init_weights(netG, init_type='orthogonal')
154 | if opt['gpu_ids'] and opt['distributed']:
155 | assert torch.cuda.is_available()
156 | netG = nn.DataParallel(netG)
157 | return netG
158 |
--------------------------------------------------------------------------------
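
`define_G` assembles the `UNet` denoiser and wraps it in `GaussianDiffusion`, reading everything from the option dict. The sketch below lists the fields it consumes; the key names come from `networks.py`, while the concrete values are illustrative assumptions -- the real settings live in `config/framework_da.json`:

```python
# Hedged example option dict for networks.define_G (values are assumptions).
opt = {
    'phase': 'val',
    'gpu_ids': None,
    'distributed': False,
    'model': {
        'which_model_G': 'dehaze_with_z_v2',   # selects the module pair above
        'FCB': True,                           # enable frequency compensation
        'unet': {
            'in_channel': 6,                   # 3 condition + 3 noisy channels
            'out_channel': 3,
            'inner_channel': 64,
            'norm_groups': 32,
            'channel_multiplier': [1, 2, 4, 8, 8],
            'attn_res': [16],
            'res_blocks': 2,
            'dropout': 0.2,
        },
        'diffusion': {'image_size': 256, 'channels': 3, 'conditional': True},
        'beta_schedule': {'train': {'schedule': 'linear', 'n_timestep': 2000,
                                    'linear_start': 1e-6, 'linear_end': 1e-2}},
    },
}
# netG = networks.define_G(opt)  # returns GaussianDiffusion (DataParallel if distributed)
```

--------------------------------------------------------------------------------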
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6
2 | torchvision
3 | numpy
4 | pandas
5 | tqdm
6 | lmdb
7 | opencv-python
8 | pillow
9 | tensorboardx
10 | wandb
11 | kornia==0.6.2
12 | pyciede2000
13 | pyiqa==0.1.5
14 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import numpy as np
12 | import copy
13 | import random
14 |
15 | seed = 6666
16 | print('Random seed: {}'.format(seed))
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.deterministic = True
22 |
23 |
24 | class AverageMeter(object):
25 | """Computes and stores the average and current value"""
26 |
27 | def __init__(self):
28 | self.reset()
29 | self.cache = []
30 |
31 | def reset(self):
32 | self.val = 0
33 | self.avg = 0
34 | self.sum = 0
35 | self.count = 0
36 |
37 | def update(self, val, n=0):
38 | self.val = val
39 | self.sum += val * n
40 | self.count += n
41 |
42 | self.cache.append(self.val)
43 | if len(self.cache) >= 20: self.cache = self.cache[1:]
44 | self.avg = np.mean(self.cache)
45 |
46 | def __str__(self):
47 | """String representation for logging
48 | """
49 | # for values that should be recorded exactly e.g. iteration number
50 | if self.count == 0:
51 | return str(self.val)
52 | # for stats
53 | return '%.4f (%.4f)' % (self.val, self.avg)
54 |
55 |
56 | def adjust_learning_rate(change_idx, optimizer):
57 | """Sets the learning rate to the initial LR
58 | decayed by 10 every 30 epochs"""
59 | for param_group in optimizer.param_groups:
60 | lr = param_group['lr']
61 |
62 | lr = lr * (0.7 ** change_idx)
63 |
64 | param_group['lr'] = lr
65 |
66 | logger.info("Current lr: {}".format(optimizer.state_dict()['param_groups'][0]['lr']))
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('-c', '--config', type=str, default='config/framework_da.json',
72 | help='JSON file for configuration')
73 | parser.add_argument('-p', '--phase', type=str, choices=['train'],
74 |                         help='Run phase; this script only supports train(training)', default='train')
75 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
76 | parser.add_argument('-debug', '-d', action='store_true')
77 | parser.add_argument('-enable_wandb', action='store_true')
78 | parser.add_argument('-log_wandb_ckpt', action='store_true')
79 | parser.add_argument('-log_eval', action='store_true')
80 |
81 | # parse configs
82 | args = parser.parse_args()
83 | opt = Logger.parse(args)
84 | # Convert to NoneDict, which return None for missing key.
85 | opt = Logger.dict_to_nonedict(opt)
86 |
87 |     # cudnn settings (benchmark=True favors speed over strict determinism)
88 | torch.backends.cudnn.enabled = True
89 | torch.backends.cudnn.benchmark = True
90 |
91 | Logger.setup_logger(None, opt['path']['log'],
92 | 'train', level=logging.INFO, screen=True)
93 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
94 | logger = logging.getLogger('base')
95 | logger.info(Logger.dict2str(opt))
96 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
97 |
98 | change_sizes = opt["change_sizes"]
99 |
100 | # Initialize WandbLogger
101 | if opt['enable_wandb']:
102 | import wandb
103 |
104 | wandb_logger = WandbLogger(opt)
105 | wandb.define_metric('validation/val_step')
106 | wandb.define_metric('epoch')
107 | wandb.define_metric("validation/*", step_metric="val_step")
108 | val_step = 0
109 | else:
110 | wandb_logger = None
111 |
112 | # dataset
113 | for phase, dataset_opt in opt['datasets'].items():
114 | if phase == 'train' and args.phase != 'val':
115 | train_set = Data.create_dataset(dataset_opt, phase)
116 | train_loader = Data.create_dataloader(
117 | train_set, dataset_opt, phase)
118 | logger.info('Initial Dataset Finished')
119 |
120 | logger.info("change rate:" + "".join(["{}:{} ".format(k, v) for k, v in change_sizes.items()]))
121 |
122 | # model
123 | diffusion = Model.create_model(opt)
124 | logger.info('Initial Model Finished')
125 |
126 | # Train
127 | current_step = diffusion.begin_step
128 | current_epoch = diffusion.begin_epoch
129 | n_iter = opt['train']['n_iter']
130 |
131 | # ave
132 | ave_loss = AverageMeter()
133 |
134 | if opt['path']['resume_state']:
135 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
136 | current_epoch, current_step))
137 |
138 | diffusion.set_new_noise_schedule(
139 | opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
140 |
141 |     change_size_idx = 0
142 |     if current_step != 0:
143 |         # resuming: fast-forward past the resolution stages already completed
144 |         try:
145 |             while change_size_idx < len(list(change_sizes.keys())) and \
146 |                     current_step >= int(
147 |                         float(list(change_sizes.keys())[change_size_idx]) * n_iter):
148 |                 change_size_idx += 1
149 |         except IndexError:
150 |             pass
151 |         change_size_idx -= 1
152 | 
153 |
154 | while current_step < n_iter:
155 |
156 | # reset train_loader
157 |         if change_size_idx < len(list(change_sizes.keys())) and \
158 |                 current_step >= int(
159 |                     float(list(change_sizes.keys())[change_size_idx]) * n_iter):
160 | logger.info('reset train_loader')
161 | resize_resolu = change_sizes[list(change_sizes.keys())[change_size_idx]]
162 | train_dataset_opt = copy.deepcopy(opt['datasets']['train'])
163 |
164 | train_dataset_opt["l_resolution"], train_dataset_opt["r_resolution"] = resize_resolu, resize_resolu
165 |
166 | logger.info('reset train_loader: l_resolution:{}, r_resolution:{}, batch_size:{}'.format(
167 | train_dataset_opt["l_resolution"], train_dataset_opt["r_resolution"],
168 | train_dataset_opt["batch_size"]))
169 |
170 | train_set = Data.create_dataset(train_dataset_opt, 'train')
171 | train_loader = Data.create_dataloader(train_set, train_dataset_opt, 'train')
172 |
173 | logger.info('reset train_loader finished .')
174 |
175 | adjust_learning_rate(change_size_idx, diffusion.optG)
176 |
177 | change_size_idx += 1
178 |
179 | current_epoch += 1
180 | for _, train_data in enumerate(train_loader):
181 | current_step += 1
182 | if current_step > n_iter:
183 | break
184 |
185 | diffusion.feed_data(train_data)
186 | diffusion.optimize_parameters(current_step)
187 | # log
188 | if current_step % opt['train']['print_freq'] == 0:
189 | logs = diffusion.get_current_log()
190 |                 message = '<epoch:{:3d}, iter:{:8,d}> '.format(
191 |                     current_epoch, current_step)
192 | for k, v in logs.items():
193 | ave_loss.update(v)
194 |                     message += '{:s}: {:.4e} ({:.4e}) '.format(k, v, ave_loss.avg)
195 | tb_logger.add_scalar(k, v, current_step)
196 | logger.info(message)
197 |
198 | if wandb_logger:
199 | wandb_logger.log_metrics(logs)
200 |
201 | if current_step % opt['train']['save_checkpoint_freq'] == 0:
202 | logger.info('Saving models and training states.')
203 | diffusion.save_network(current_epoch, current_step)
204 |
205 | if wandb_logger and opt['log_wandb_ckpt']:
206 | wandb_logger.log_checkpoint(current_epoch, current_step)
207 |
208 |             if change_size_idx < len(list(change_sizes.keys())) and \
209 |                     current_step >= int(
210 |                         float(list(change_sizes.keys())[change_size_idx]) * n_iter):
211 | break
212 |
213 | if wandb_logger:
214 | wandb_logger.log_metrics({'epoch': current_epoch - 1})
215 |
216 | # save model
217 | logger.info('End of training.')
218 |
--------------------------------------------------------------------------------
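
`train.py` implements a progressive-resolution curriculum: `change_sizes` maps fractions of `n_iter` to training resolutions, and whenever `current_step` crosses a threshold the train loader is rebuilt and `adjust_learning_rate` decays the LR. A hedged sketch of the schedule arithmetic (the keys and resolutions below are assumptions; the real schedule is in `config/framework_da.json`):

```python
# Hedged illustration of the change_sizes schedule read by train.py.
change_sizes = {"0.0": 128, "0.5": 192, "0.8": 256}  # fraction of n_iter -> crop size
n_iter = 1000000

for frac, res in change_sizes.items():
    step = int(float(frac) * n_iter)                 # same arithmetic as train.py
    print(f"from step {step}: rebuild train_loader at {res}x{res}")
```

With the argparse defaults above, training is launched as `python train.py -c config/framework_da.json`.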