├── README.md
├── config
│   └── framework_da.json
├── core
│   ├── __pycache__
│   │   ├── logger.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   └── wandb_logger.cpython-38.pyc
│   ├── calc_indicator.py
│   ├── logger.py
│   ├── metrics.py
│   └── wandb_logger.py
├── data
│   ├── FDA.py
│   ├── HazeAug.py
│   ├── LRHR_dataset.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── FDA.cpython-38.pyc
│   │   ├── HazeAug.cpython-38.pyc
│   │   ├── LRHR_dataset.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── util.cpython-38.pyc
│   └── util.py
├── infer.py
├── misc
│   ├── RTTS.jpg
│   └── framework-v3.jpg
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── base_model.cpython-38.pyc
│   │   ├── model.cpython-38.pyc
│   │   └── networks.cpython-38.pyc
│   ├── base_model.py
│   ├── dehaze_with_z_v2_modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── diffusion.cpython-38.pyc
│   │   │   └── unet.cpython-38.pyc
│   │   ├── diffusion.py
│   │   └── unet.py
│   ├── model.py
│   └── networks.py
├── requirement.txt
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 

Frequency Compensated Diffusion Model for Real-scene Dehazing

3 | 
4 | 
5 | This is an official PyTorch implementation of **Frequency Compensated Diffusion Model for Real-scene Dehazing**.
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
20 | 
21 | ## News
22 | - 2025.03 We released a more powerful dehazing diffusion model, [ProHaze](https://github.com/TianwenZhou/ProDehaze), based on SD-2.1.
23 | 
24 | ## Getting started
25 | ### Installation
26 | * This repo is a modification of the [**SR3 Repo**](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement).
27 | 
28 | * Install third-party libraries:
29 | 
30 | ```shell
31 | pip install -r requirement.txt
32 | ```
33 | 
34 | ### Data Preparation
35 | 
36 | Download the train/eval data from the following links:
37 | 
38 | Training: [*RESIDE*](https://sites.google.com/view/reside-dehaze-datasets/reside-v0)
39 | 
40 | Testing:
41 | [*I-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire18//i-haze/#:~:text=To%20overcome%20this%20issue%20we%20introduce%20I-HAZE%2C%20a,real%20haze%20produced%20by%20a%20professional%20haze%20machine.) /
42 | [*O-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire18/o-haze/) /
43 | [*Dense-Haze*](https://arxiv.org/abs/1904.02904#:~:text=To%20address%20this%20limitation%2C%20we%20introduce%20Dense-Haze%20-,introducing%20real%20haze%2C%20generated%20by%20professional%20haze%20machines.) /
44 | [*Nh-Haze*](https://data.vision.ee.ethz.ch/cvl/ntire20/nh-haze/) /
45 | [*RTTS*](https://sites.google.com/view/reside-dehaze-datasets/reside-standard?authuser=0)
46 | 
47 | ```shell
48 | mkdir dataset
49 | ```
50 | 
51 | Re-organize the train/val images in the following file structure:
52 | 
53 | 
54 | ```shell
55 | # Training data file structure
56 | dataset/RESIDE/
57 | ├── HR            # ground-truth clear images
58 | ├── HR_hazy_src   # hazy images
59 | └── HR_depth      # depth images (generated by MonoDepth: github.com/OniroAI/MonoDepth-PyTorch)
60 | 
61 | # Testing data (e.g. Dense-Haze) file structure
62 | dataset/{name}/
63 | ├── HR            # ground-truth images
64 | └── HR_hazy       # hazy images
65 | ```
66 | 
67 | Then make sure the data paths ("dataroot") in config/framework_da.json are correct.
68 | 
69 | ## Pretrained Model
70 | 
71 | We provide the pretrained model at:
72 | 
73 | | Type | Weights |
74 | | --------- | ------- |
75 | | Generator | [OneDrive](https://1drv.ms/u/s!AsqtTP8eWS-penA8AqrU8c_I4jU) |
76 | 
77 | ## Evaluation
78 | 
79 | Download the test set (e.g. O-Haze), put the test images in "dataroot", and set the correct path for "dataroot" in config/framework_da.json.
80 | 
81 | Download the pretrained model and set the correct path for "resume_state" in config/framework_da.json:
82 | 
83 | ```json
84 | "path": {
85 |     "log": "logs",
86 |     "tb_logger": "tb_logger",
87 |     "results": "results",
88 |     "checkpoint": "checkpoint",
89 |     "resume_state": "./ddpm_fcb_230221_121802"
90 | },
91 | "val": {
92 |     "name": "dehaze_val",
93 |     "mode": "LRHR",
94 |     "dataroot": "dataset/O-HAZE-PROCESS",
95 |     ...
96 | }
97 | ```
98 | 
99 | 
100 | ```shell
101 | # infer
102 | python infer.py -c [config file]
103 | ```
104 | 
105 | The default config file is config/framework_da.json. The output images are located at /data/diffusion/results; the output path can be changed in core/logger.py.
106 | 
107 | ### Train
108 | 
109 | Prepare the train dataset and set the correct paths under "datasets" in config/framework_da.json.
110 | 
111 | If training from scratch, make sure "resume_state" is null in config/framework_da.json, as in the sketch below.
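For a from-scratch run, the `path` block would look like this (a minimal sketch of the shipped config with only `resume_state` changed, matching the commented-out alternative in config/framework_da.json):

```json
"path": {
    "log": "logs",
    "tb_logger": "tb_logger",
    "results": "results",
    "checkpoint": "checkpoint",
    "resume_state": null
}
```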
112 | 
113 | ```shell
114 | # train
115 | python train.py -c [config file]
116 | ```
117 | 
118 | ## Results
119 | Quantitative comparison on real-world hazy data (RTTS). Bold and underline indicate the best and the second-best results, respectively.
120 | 

121 | 
122 | 

123 | 
124 | ## Todo
125 | 
126 | 
127 | - [x] Upload configs and pretrained models
128 | 
129 | - [x] Upload evaluation scripts
130 | 
131 | - [x] Upload train scripts
132 | 
--------------------------------------------------------------------------------
/config/framework_da.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "framework_da",
3 |     "phase": "train",
4 |     // train or val
5 |     "gpu_ids": [
6 |         0
7 |     ],
8 |     "change_sizes": {
9 |         "0.0": 128,
10 |         "0.3": 128,
11 |         "0.9": 128,
12 |         "1.01": 128
13 |     },
14 |     "path": {
15 |         // set the path
16 |         "log": "logs",
17 |         "tb_logger": "tb_logger",
18 |         "results": "results",
19 |         "checkpoint": "checkpoint",
20 |         "resume_state": "./ddpm_fcb_230221_121802"
21 |         // "resume_state": null
22 |     },
23 |     "datasets": {
24 |         "train": {
25 |             "name": "RESIDE_train_syntheic",
26 |             "mode": "HR",
27 |             "dataroot": "dataset/RESIDE/HR_hazy_src",
28 |             "hr_path": "dataset/RESIDE/HR",
29 |             "datatype": "RESIDE_img_syntheic",
30 |             "l_resolution": 128,
31 |             "r_resolution": 128,
32 |             "batch_size": 3,
33 |             "num_workers": 12,
34 |             "use_shuffle": true,
35 |             "HazeAug": true,
36 |             "rt_da_ref": [
37 |                 "dataset/RESIDE/HR_hazy_src"
38 |             ],
39 |             "depth_img_path": "dataset/RESIDE/HR_depth/",
40 |             "data_len": -1 // -1 means all data is used in training
41 |         },
42 |         "val": {
43 |             "name": "dehaze_val",
44 |             "mode": "LRHR",
45 |             // "dataroot": "dataset/I-HAZE-PROCESS",
46 |             // "dataroot": "dataset/RTTS-PROCESS",
47 |             "dataroot": "dataset/O-HAZE-PROCESS",
48 |             // "dataroot": "dataset/DenseHaze",
49 |             // "dataroot": "dataset/NhHaze",
50 |             "datatype": "haze_img",
51 | 
52 |             "l_resolution": 512,
53 |             "r_resolution": 512,
54 |             "data_len": 5000
55 |         }
56 |     },
57 |     "model": {
58 |         "which_model_G": "dehaze_with_z_v2",
59 |         "finetune_norm": false,
60 |         "FCB": true,
61 |         "unet": {
62 |             "in_channel": 6,
63 |             "out_channel": 3,
64 |             "inner_channel": 64,
65 |             "norm_groups": 16,
66 |             "channel_multiplier": [
67 |                 1,
68 |                 2,
69 |                 4,
70 |                 8,
71 |                 16
72 |             ],
73 |             "attn_res": [
74 |                 // 16
75 |             ],
76 |             "res_blocks": 1,
77 |             "dropout": 0.2
78 |         },
79 |         "beta_schedule": {
80 |             // use manual beta_schedule for acceleration
81 |             "train": {
82 |                 "schedule": "linear",
83 |                 "n_timestep": 2000,
84 |                 "linear_start": 1e-6,
85 |                 "linear_end": 1e-2
86 |             },
87 |             "val": {
88 |                 "schedule": "linear",
89 |                 "n_timestep": 2000,
90 |                 "linear_start": 1e-6,
91 |                 "linear_end": 1e-2
92 |             }
93 |         },
94 |         "diffusion": {
95 |             "image_size": 128,
96 |             "channels": 3,
97 |             // sample channels
98 |             "conditional": true,
99 |             // unconditional or conditional generation (here: conditional dehazing)
100 |             "start_step": 1000
101 |         }
102 |     },
103 |     "train": {
104 |         "n_iter": 2000000,
105 |         "save_checkpoint_freq": 1e4,
106 |         "print_freq": 50,
107 |         "optimizer": {
108 |             "type": "adam",
109 |             "lr": 1e-4
110 |         },
111 |         "ema_scheduler": {
112 |             // not used for now
113 |             "step_start_ema": 5000,
114 |             "update_ema_every": 1,
115 |             "ema_decay": 0.9999
116 |         }
117 |     },
118 |     "wandb": {
119 |         "project": "dehaze_with_z_v2"
120 |     }
121 | }
122 | 
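Note that this config is not strict JSON: it carries `//` comments, which `json.loads` rejects. core/logger.py strips them line by line before parsing, so in this file a `//` is always a comment marker, never part of a value. A minimal, self-contained sketch of that loading step (the helper name `load_commented_json` is ours; the repo does the same thing inside `Logger.parse`):

```python
import json
from collections import OrderedDict


def load_commented_json(path):
    """Drop everything after '//' on each line, then parse as plain JSON."""
    json_str = ''
    with open(path, 'r') as f:
        for line in f:
            json_str += line.split('//')[0] + '\n'
    return json.loads(json_str, object_pairs_hook=OrderedDict)


# Usage:
# opt = load_commented_json('config/framework_da.json')
# opt['model']['beta_schedule']['val']['n_timestep']  # -> 2000
```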
123 | 124 | -------------------------------------------------------------------------------- /core/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/wandb_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/core/__pycache__/wandb_logger.cpython-38.pyc -------------------------------------------------------------------------------- /core/calc_indicator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import cv2 5 | from torchvision.utils import make_grid 6 | 7 | 8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 9 | ''' 10 | Converts a torch Tensor into an image Numpy array 11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 13 | ''' 14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 15 | tensor = (tensor - min_max[0]) / \ 16 | (min_max[1] - min_max[0]) # to range [0,1] 17 | n_dim = tensor.dim() 18 | if n_dim == 4: 19 | n_img = len(tensor) 20 | img_np = make_grid(tensor, nrow=int( 21 | math.sqrt(n_img)), normalize=False).numpy() 22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 23 | elif n_dim == 3: 24 | img_np = tensor.numpy() 25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 26 | elif n_dim == 2: 27 | img_np = tensor.numpy() 28 | else: 29 | raise TypeError( 30 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 31 | if out_type == np.uint8: 32 | img_np = (img_np * 255.0).round() 33 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
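# NumPy's astype(np.uint8) truncates toward zero rather than rounding
# (e.g. np.uint8(250.9) == 250), which is why the values are rounded
# explicitly before the cast on the next line.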
34 | return img_np.astype(out_type) 35 | 36 | 37 | def save_img(img, img_path, mode='RGB'): 38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 39 | # cv2.imwrite(img_path, img) 40 | 41 | 42 | def calculate_psnr(img1, img2): 43 | # img1 and img2 have range [0, 255] 44 | img1 = img1.astype(np.float64) 45 | img2 = img2.astype(np.float64) 46 | mse = np.mean((img1 - img2) ** 2) 47 | if mse == 0: 48 | return float('inf') 49 | return 20 * math.log10(255.0 / math.sqrt(mse)) 50 | 51 | 52 | def ssim(img1, img2): 53 | C1 = (0.01 * 255) ** 2 54 | C2 = (0.03 * 255) ** 2 55 | 56 | img1 = img1.astype(np.float64) 57 | img2 = img2.astype(np.float64) 58 | kernel = cv2.getGaussianKernel(11, 1.5) 59 | window = np.outer(kernel, kernel.transpose()) 60 | 61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 63 | mu1_sq = mu1 ** 2 64 | mu2_sq = mu2 ** 2 65 | mu1_mu2 = mu1 * mu2 66 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 67 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 69 | 70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 71 | (sigma1_sq + sigma2_sq + C2)) 72 | return ssim_map.mean() 73 | 74 | 75 | def calculate_ssim(img1, img2): 76 | '''calculate SSIM 77 | the same outputs as MATLAB's 78 | img1, img2: [0, 255] 79 | ''' 80 | if not img1.shape == img2.shape: 81 | raise ValueError('Input images must have the same dimensions.') 82 | if img1.ndim == 2: 83 | return ssim(img1, img2) 84 | elif img1.ndim == 3: 85 | if img1.shape[2] == 3: 86 | ssims = [] 87 | for i in range(3): 88 | ssims.append(ssim(img1, img2)) 89 | return np.array(ssims).mean() 90 | elif img1.shape[2] == 1: 91 | return ssim(np.squeeze(img1), np.squeeze(img2)) 92 | else: 93 | raise ValueError('Wrong input image dimensions.') 94 | 95 | 96 | if __name__ == "__main__": 97 | path1 = "/data/ImageDehazing/tmp/DenseHaze/GCA/" 98 | path2 = "/data/ImageDehazing/DenseHaze/HR" 99 | 100 | img1s = sorted(os.listdir(path1)) 101 | img2s = sorted(os.listdir(path2)) 102 | 103 | ave_p, ave_s = 0, 0 104 | for idx, (img1, img2) in enumerate(zip(img1s, img2s)): 105 | im1 = cv2.imread(os.path.join(path1, img1)) 106 | im2 = cv2.imread(os.path.join(path2, img2)) 107 | 108 | im1 = cv2.resize(im1, (512, 512)) 109 | im2 = cv2.resize(im2, (512, 512)) 110 | 111 | s = calculate_ssim(im1, im2) 112 | p = calculate_psnr(im1, im2) 113 | print(img1, img2, "psnr:{}".format(p), " ssim:{}".format(s)) 114 | 115 | ave_p += p 116 | ave_s += s 117 | 118 | print("ave ssim: {}, ave psnr:{}".format(ave_s / idx, ave_p / idx)) 119 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | from collections import OrderedDict 5 | import json 6 | from datetime import datetime 7 | 8 | 9 | def mkdirs(paths): 10 | if isinstance(paths, str): 11 | os.makedirs(paths, exist_ok=True) 12 | else: 13 | for path in paths: 14 | os.makedirs(path, exist_ok=True) 15 | 16 | 17 | def get_timestamp(): 18 | return datetime.now().strftime('%y%m%d_%H%M%S') 19 | 20 | 21 | def parse(args): 22 | phase = args.phase 23 | opt_path = args.config 24 | gpu_ids = args.gpu_ids 25 | enable_wandb = args.enable_wandb 26 | # remove comments starting with '//' 27 | json_str = '' 28 | with open(opt_path, 'r') as f: 29 | 
for line in f: 30 | line = line.split('//')[0] + '\n' 31 | json_str += line 32 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 33 | 34 | # set log directory 35 | if args.debug: 36 | opt['name'] = 'debug_{}'.format(opt['name']) 37 | experiments_root = os.path.join( 38 | '/data/diffusion/results', '{}_{}'.format(opt['name'], get_timestamp())) 39 | opt['path']['experiments_root'] = experiments_root 40 | for key, path in opt['path'].items(): 41 | if 'resume' not in key and 'experiments' not in key: 42 | opt['path'][key] = os.path.join(experiments_root, path) 43 | mkdirs(opt['path'][key]) 44 | 45 | # change dataset length limit 46 | opt['phase'] = phase 47 | 48 | # export CUDA_VISIBLE_DEVICES 49 | if gpu_ids is not None: 50 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')] 51 | gpu_list = gpu_ids 52 | else: 53 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 54 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 55 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 56 | if len(gpu_list) > 1: 57 | opt['distributed'] = True 58 | else: 59 | opt['distributed'] = False 60 | 61 | # debug 62 | if 'debug' in opt['name']: 63 | opt['train']['val_freq'] = 2 64 | opt['train']['print_freq'] = 2 65 | opt['train']['save_checkpoint_freq'] = 3 66 | opt['datasets']['train']['batch_size'] = 2 67 | opt['model']['beta_schedule']['train']['n_timestep'] = 10 68 | opt['model']['beta_schedule']['val']['n_timestep'] = 10 69 | opt['datasets']['train']['data_len'] = 6 70 | opt['datasets']['val']['data_len'] = 3 71 | 72 | # validation in train phase 73 | if phase == 'train': 74 | opt['datasets']['val']['data_len'] = 20 75 | 76 | # W&B Logging 77 | try: 78 | log_wandb_ckpt = args.log_wandb_ckpt 79 | opt['log_wandb_ckpt'] = log_wandb_ckpt 80 | except: 81 | pass 82 | try: 83 | log_eval = args.log_eval 84 | opt['log_eval'] = log_eval 85 | except: 86 | pass 87 | try: 88 | log_infer = args.log_infer 89 | opt['log_infer'] = log_infer 90 | except: 91 | pass 92 | opt['enable_wandb'] = enable_wandb 93 | 94 | return opt 95 | 96 | 97 | class NoneDict(dict): 98 | def __missing__(self, key): 99 | return None 100 | 101 | 102 | # convert to NoneDict, which return None for missing key. 
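# For example (hypothetical values), after conversion a missing option is
# read as None instead of raising:
#
#   opt = dict_to_nonedict({'train': {'lr': 1e-4}})
#   opt['train']['lr']    # -> 0.0001
#   opt['missing_key']    # -> None  (instead of KeyError)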
103 | def dict_to_nonedict(opt): 104 | if isinstance(opt, dict): 105 | new_opt = dict() 106 | for key, sub_opt in opt.items(): 107 | new_opt[key] = dict_to_nonedict(sub_opt) 108 | return NoneDict(**new_opt) 109 | elif isinstance(opt, list): 110 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 111 | else: 112 | return opt 113 | 114 | 115 | def dict2str(opt, indent_l=1): 116 | '''dict to string for logger''' 117 | msg = '' 118 | for k, v in opt.items(): 119 | if isinstance(v, dict): 120 | msg += ' ' * (indent_l * 2) + k + ':[\n' 121 | msg += dict2str(v, indent_l + 1) 122 | msg += ' ' * (indent_l * 2) + ']\n' 123 | else: 124 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 125 | return msg 126 | 127 | 128 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 129 | '''set up logger''' 130 | l = logging.getLogger(logger_name) 131 | formatter = logging.Formatter( 132 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 133 | log_file = os.path.join(root, '{}.log'.format(phase)) 134 | fh = logging.FileHandler(log_file, mode='w') 135 | fh.setFormatter(formatter) 136 | l.setLevel(level) 137 | l.addHandler(fh) 138 | if screen: 139 | sh = logging.StreamHandler() 140 | sh.setFormatter(formatter) 141 | l.addHandler(sh) 142 | 143 | return fh 144 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import cv2 5 | from torchvision.utils import make_grid 6 | 7 | 8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 9 | ''' 10 | Converts a torch Tensor into an image Numpy array 11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 13 | ''' 14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 15 | tensor = (tensor - min_max[0]) / \ 16 | (min_max[1] - min_max[0]) # to range [0,1] 17 | n_dim = tensor.dim() 18 | if n_dim == 4: 19 | n_img = len(tensor) 20 | img_np = make_grid(tensor, nrow=int( 21 | math.sqrt(n_img)), normalize=False).numpy() 22 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 23 | elif n_dim == 3: 24 | img_np = tensor.numpy() 25 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 26 | elif n_dim == 2: 27 | img_np = tensor.numpy() 28 | else: 29 | raise TypeError( 30 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 31 | if out_type == np.uint8: 32 | img_np = (img_np * 255.0).round() 33 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
34 | return img_np.astype(out_type) 35 | 36 | 37 | def save_img(img, img_path, mode='RGB'): 38 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 39 | # cv2.imwrite(img_path, img) 40 | 41 | 42 | def calculate_psnr(img1, img2): 43 | # img1 and img2 have range [0, 255] 44 | img1 = img1.astype(np.float64) 45 | img2 = img2.astype(np.float64) 46 | mse = np.mean((img1 - img2) ** 2) 47 | if mse == 0: 48 | return float('inf') 49 | return 20 * math.log10(255.0 / math.sqrt(mse)) 50 | 51 | 52 | def ssim(img1, img2): 53 | C1 = (0.01 * 255) ** 2 54 | C2 = (0.03 * 255) ** 2 55 | 56 | img1 = img1.astype(np.float64) 57 | img2 = img2.astype(np.float64) 58 | kernel = cv2.getGaussianKernel(11, 1.5) 59 | window = np.outer(kernel, kernel.transpose()) 60 | 61 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 62 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 63 | mu1_sq = mu1 ** 2 64 | mu2_sq = mu2 ** 2 65 | mu1_mu2 = mu1 * mu2 66 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 67 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 68 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 69 | 70 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 71 | (sigma1_sq + sigma2_sq + C2)) 72 | return ssim_map.mean() 73 | 74 | 75 | def calculate_ssim(img1, img2): 76 | '''calculate SSIM 77 | the same outputs as MATLAB's 78 | img1, img2: [0, 255] 79 | ''' 80 | if not img1.shape == img2.shape: 81 | raise ValueError('Input images must have the same dimensions.') 82 | if img1.ndim == 2: 83 | return ssim(img1, img2) 84 | elif img1.ndim == 3: 85 | if img1.shape[2] == 3: 86 | ssims = [] 87 | for i in range(3): 88 | ssims.append(ssim(img1, img2)) 89 | return np.array(ssims).mean() 90 | elif img1.shape[2] == 1: 91 | return ssim(np.squeeze(img1), np.squeeze(img2)) 92 | else: 93 | raise ValueError('Wrong input image dimensions.') 94 | -------------------------------------------------------------------------------- /core/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class WandbLogger: 4 | """ 5 | Log using `Weights and Biases`. 6 | """ 7 | def __init__(self, opt): 8 | try: 9 | import wandb 10 | except ImportError: 11 | raise ImportError( 12 | "To use the Weights and Biases Logger please install wandb." 13 | "Run `pip install wandb` to install it." 14 | ) 15 | 16 | self._wandb = wandb 17 | 18 | # Initialize a W&B run 19 | if self._wandb.run is None: 20 | self._wandb.init( 21 | project=opt['wandb']['project'], 22 | config=opt, 23 | dir='./experiments' 24 | ) 25 | 26 | self.config = self._wandb.config 27 | 28 | if self.config.get('log_eval', None): 29 | self.eval_table = self._wandb.Table(columns=['fake_image', 30 | 'sr_image', 31 | 'hr_image', 32 | 'psnr', 33 | 'ssim']) 34 | else: 35 | self.eval_table = None 36 | 37 | if self.config.get('log_infer', None): 38 | self.infer_table = self._wandb.Table(columns=['fake_image', 39 | 'sr_image', 40 | 'hr_image']) 41 | else: 42 | self.infer_table = None 43 | 44 | def log_metrics(self, metrics, commit=True): 45 | """ 46 | Log train/validation metrics onto W&B. 47 | 48 | metrics: dictionary of metrics to be logged 49 | """ 50 | self._wandb.log(metrics, commit=commit) 51 | 52 | def log_image(self, key_name, image_array): 53 | """ 54 | Log image array onto W&B. 55 | 56 | key_name: name of the key 57 | image_array: numpy array of image. 
58 | """ 59 | self._wandb.log({key_name: self._wandb.Image(image_array)}) 60 | 61 | def log_images(self, key_name, list_images): 62 | """ 63 | Log list of image array onto W&B 64 | 65 | key_name: name of the key 66 | list_images: list of numpy image arrays 67 | """ 68 | self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]}) 69 | 70 | def log_checkpoint(self, current_epoch, current_step): 71 | """ 72 | Log the model checkpoint as W&B artifacts 73 | 74 | current_epoch: the current epoch 75 | current_step: the current batch step 76 | """ 77 | model_artifact = self._wandb.Artifact( 78 | self._wandb.run.id + "_model", type="model" 79 | ) 80 | 81 | gen_path = os.path.join( 82 | self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch)) 83 | opt_path = os.path.join( 84 | self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch)) 85 | 86 | model_artifact.add_file(gen_path) 87 | model_artifact.add_file(opt_path) 88 | self._wandb.log_artifact(model_artifact, aliases=["latest"]) 89 | 90 | def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None): 91 | """ 92 | Add data row-wise to the initialized table. 93 | """ 94 | if psnr is not None and ssim is not None: 95 | self.eval_table.add_data( 96 | self._wandb.Image(fake_img), 97 | self._wandb.Image(sr_img), 98 | self._wandb.Image(hr_img), 99 | psnr, 100 | ssim 101 | ) 102 | else: 103 | self.infer_table.add_data( 104 | self._wandb.Image(fake_img), 105 | self._wandb.Image(sr_img), 106 | self._wandb.Image(hr_img) 107 | ) 108 | 109 | def log_eval_table(self, commit=False): 110 | """ 111 | Log the table 112 | """ 113 | if self.eval_table: 114 | self._wandb.log({'eval_data': self.eval_table}, commit=commit) 115 | elif self.infer_table: 116 | self._wandb.log({'infer_data': self.infer_table}, commit=commit) 117 | -------------------------------------------------------------------------------- /data/FDA.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 4 | import numpy as np 5 | from PIL import Image 6 | # from FDA_utils import FDA_source_to_target_np 7 | import cv2 8 | import scipy.misc 9 | # from matplotlib import image 10 | import torch 11 | import numpy as np 12 | 13 | 14 | def extract_ampl_phase(fft_im): 15 | # fft_im: size should be bx3xhxwx2 16 | fft_amp = fft_im[:, :, :, :, 0] ** 2 + fft_im[:, :, :, :, 1] ** 2 17 | fft_amp = torch.sqrt(fft_amp) 18 | fft_pha = torch.atan2(fft_im[:, :, :, :, 1], fft_im[:, :, :, :, 0]) 19 | return fft_amp, fft_pha 20 | 21 | 22 | def low_freq_mutate(amp_src, amp_trg, L=0.1): 23 | _, _, h, w = amp_src.size() 24 | b = (np.floor(np.amin((h, w)) * L)).astype(int) # get b 25 | amp_src[:, :, 0:b, 0:b] = amp_trg[:, :, 0:b, 0:b] # top left 26 | amp_src[:, :, 0:b, w - b:w] = amp_trg[:, :, 0:b, w - b:w] # top right 27 | amp_src[:, :, h - b:h, 0:b] = amp_trg[:, :, h - b:h, 0:b] # bottom left 28 | amp_src[:, :, h - b:h, w - b:w] = amp_trg[:, :, h - b:h, w - b:w] # bottom right 29 | return amp_src 30 | 31 | 32 | def low_freq_mutate_np(amp_src, amp_trg, L=0.1): 33 | a_src = np.fft.fftshift(amp_src, axes=(-2, -1)) 34 | a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1)) 35 | 36 | _, h, w = a_src.shape 37 | b = (np.floor(np.amin((h, w)) * L)).astype(int) 38 | c_h = np.floor(h / 2.0).astype(int) 39 | c_w = np.floor(w / 2.0).astype(int) 40 | 41 | h1 = c_h - b 42 | h2 = c_h + b + 1 43 | w1 = c_w - b 44 | w2 = c_w + b + 1 45 | 46 | a_src[:, h1:h2, w1:w2] = 
a_trg[:, h1:h2, w1:w2] 47 | a_src = np.fft.ifftshift(a_src, axes=(-2, -1)) 48 | return a_src 49 | 50 | 51 | def FDA_source_to_target(src_img, trg_img, L=0.1): 52 | # exchange magnitude 53 | # input: src_img, trg_img 54 | 55 | # get fft of both source and target 56 | fft_src = torch.rfft(src_img.clone(), signal_ndim=2, onesided=False) 57 | fft_trg = torch.rfft(trg_img.clone(), signal_ndim=2, onesided=False) 58 | 59 | # extract amplitude and phase of both ffts 60 | amp_src, pha_src = extract_ampl_phase(fft_src.clone()) 61 | amp_trg, pha_trg = extract_ampl_phase(fft_trg.clone()) 62 | 63 | # replace the low frequency amplitude part of source with that from target 64 | amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L) 65 | 66 | # recompose fft of source 67 | fft_src_ = torch.zeros(fft_src.size(), dtype=torch.float) 68 | fft_src_[:, :, :, :, 0] = torch.cos(pha_src.clone()) * amp_src_.clone() 69 | fft_src_[:, :, :, :, 1] = torch.sin(pha_src.clone()) * amp_src_.clone() 70 | 71 | # get the recomposed image: source content, target style 72 | _, _, imgH, imgW = src_img.size() 73 | src_in_trg = torch.irfft(fft_src_, signal_ndim=2, onesided=False, signal_sizes=[imgH, imgW]) 74 | 75 | return src_in_trg 76 | 77 | 78 | def FDA_source_to_target_np(src_img, trg_img, L=0.1): 79 | # exchange magnitude 80 | # input: src_img, trg_img 81 | 82 | src_img_np = src_img # .cpu().numpy() 83 | trg_img_np = trg_img # .cpu().numpy() 84 | 85 | # get fft of both source and target 86 | fft_src_np = np.fft.fft2(src_img_np, axes=(-2, -1)) 87 | fft_trg_np = np.fft.fft2(trg_img_np, axes=(-2, -1)) 88 | 89 | # extract amplitude and phase of both ffts 90 | amp_src, pha_src = np.abs(fft_src_np), np.angle(fft_src_np) 91 | amp_trg, pha_trg = np.abs(fft_trg_np), np.angle(fft_trg_np) 92 | 93 | # mutate the amplitude part of source with target 94 | amp_src_ = low_freq_mutate_np(amp_src, amp_trg, L=L) 95 | 96 | # mutated fft of source 97 | fft_src_ = amp_src_ * np.exp(1j * pha_src) 98 | 99 | # get the mutated image 100 | src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1)) 101 | src_in_trg = np.real(src_in_trg) 102 | 103 | return src_in_trg 104 | 105 | 106 | def trans_image_by_ref(in_path, ref_path, value=0.002): 107 | # im_src = Image.open(in_path).convert('RGB') 108 | im_src = in_path 109 | im_trg = Image.open(ref_path).convert('RGB') 110 | src_h, src_w, src_c = np.shape(im_src) 111 | 112 | im_src = im_src.resize((1024, 512), Image.BICUBIC) 113 | im_trg = im_trg.resize((1024, 512), Image.BICUBIC) 114 | 115 | im_src = np.asarray(im_src, np.float32) 116 | im_trg = np.asarray(im_trg, np.float32) 117 | 118 | im_src = im_src.transpose((2, 0, 1)) 119 | im_trg = im_trg.transpose((2, 0, 1)) 120 | 121 | src_in_trg = FDA_source_to_target_np(im_src, im_trg, L=value) 122 | 123 | src_in_trg = src_in_trg.transpose((1, 2, 0)) 124 | 125 | # recover to src size 126 | src_in_trg = cv2.resize(src_in_trg, (src_w, src_h)) 127 | 128 | src_in_trg = (src_in_trg - np.min(src_in_trg)) / (np.max(src_in_trg) - np.min(src_in_trg)) * 255 129 | 130 | # scipy.misc.toimage(src_in_trg, cmin=0.0, cmax=255.0).save('src_in_tar.png') 131 | # image.imsave('src_in_tar.png',src_in_trg) # cmap常用于改变绘制风格,如黑白gray,翠绿色virdidis 132 | 133 | # from PIL import Iamge 134 | img = Image.fromarray(np.uint8(src_in_trg)) # .covert('RGB') 135 | 136 | return img 137 | -------------------------------------------------------------------------------- /data/HazeAug.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import 
Image 3 | import numpy as np 4 | import cv2 5 | from data.FDA import trans_image_by_ref 6 | 7 | from PIL import Image, ImageFilter 8 | 9 | depth_argu = False 10 | 11 | 12 | def depth_change(depth): 13 | depth_strategy = np.random.uniform(0, 1) 14 | 15 | if 0.4 <= depth_strategy < 0.7: 16 | strategy = 'gamma' 17 | elif 0.7 <= depth_strategy < 1.0: 18 | strategy = 'normalize' 19 | else: 20 | strategy = 'identity' 21 | 22 | if strategy == "gamma": 23 | factor = np.random.uniform(0.2, 1.8) 24 | 25 | depth = np.array(depth ** factor) 26 | 27 | elif strategy == "normalize": 28 | # normalize float versions 29 | factor_alpha = np.random.uniform(0, 0.4) 30 | factor_beta = np.random.uniform(0, 2) 31 | depth = cv2.normalize(depth, None, alpha=factor_alpha, beta=factor_beta, norm_type=cv2.NORM_MINMAX, 32 | dtype=cv2.CV_32F) 33 | 34 | return depth 35 | 36 | 37 | class MyGaussianBlur(ImageFilter.Filter): 38 | name = "GaussianBlur" 39 | 40 | def __init__(self, radius=1, bounds=None): 41 | self.radius = radius 42 | self.bounds = bounds 43 | 44 | def filter(self, image): 45 | if self.bounds: 46 | clips = image.crop(self.bounds).gaussian_blur(self.radius) 47 | image.paste(clips, self.bounds) 48 | return image 49 | else: 50 | return image.gaussian_blur(self.radius) 51 | 52 | 53 | def rt_haze_enhancement(pil_img, depth_path, ref_path): 54 | # add_haze 55 | A = np.random.rand() * 1.3 + 0.5 56 | beta = 2 * np.random.rand() + 0.8 57 | color_strategy = np.random.rand() 58 | if color_strategy <= 0.5: 59 | strategy = 'colour_cast' 60 | # elif 0.3 < color_strategy <= 0.6: 61 | # strategy = 'luminance' 62 | else: 63 | strategy = 'add_hazy' 64 | 65 | img = cv2.imread(pil_img) 66 | depth = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE) / 256.0 # + 1e-7 67 | 68 | if depth_argu == False: 69 | depth = depth_change(depth) 70 | 71 | img_f = img / 255.0 # 归一化 72 | 73 | td_bk = np.exp(- np.array(depth) * beta) 74 | td_bk = np.expand_dims(td_bk, axis=-1).repeat(3, axis=-1) 75 | img_bk = np.array(img_f) * td_bk + A * (1 - td_bk) 76 | 77 | img_bk = img_bk / np.max(img_bk) * 255 78 | img_bk = img_bk[:, :, ::-1] 79 | 80 | if strategy == 'colour_cast': 81 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB') 82 | img_bk = trans_image_by_ref( 83 | in_path=img_bk, 84 | ref_path=ref_path, 85 | value=np.random.rand() * 0.002 + 0.0001 86 | ) 87 | 88 | if strategy == 'luminance': 89 | img_bk = np.power(img_bk, 0.95) # 对像素值指数变换 90 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB') 91 | 92 | else: 93 | img_bk = Image.fromarray(np.uint8(img_bk)) # .covert('RGB') 94 | 95 | img_bk = img_bk.filter(ImageFilter.SMOOTH_MORE) 96 | 97 | return img_bk 98 | -------------------------------------------------------------------------------- /data/LRHR_dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import lmdb 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import random 6 | import data.util as Util 7 | import h5py, os 8 | import numpy as np 9 | import torch 10 | import copy 11 | 12 | from data.HazeAug import rt_haze_enhancement 13 | 14 | 15 | def neibor_16_mul(num, size=32): 16 | a = num // size 17 | b = num % size 18 | if b >= 0.5 * size: 19 | return size * (a + 1) 20 | else: 21 | return size * a 22 | 23 | 24 | class LRHRDataset(Dataset): 25 | def __init__(self, dataroot, datatype, l_resolution=16, r_resolution=128, split='train', data_len=-1, need_LR=False, 26 | other_params=None): 27 | self.datatype = datatype 28 | self.l_res = 
l_resolution 29 | self.r_res = r_resolution 30 | self.data_len = data_len 31 | self.need_LR = need_LR 32 | self.split = split 33 | 34 | self.down_sample = other_params['down_sample'] if "down_sample" in other_params.keys() else None 35 | self.real_hr_path = other_params['hr_path'] if "hr_path" in other_params.keys() else None 36 | 37 | # rt daRESIDE_img_syntheic 38 | self.rt_da = other_params['HazeAug'] if "HazeAug" in other_params.keys() else None 39 | if self.rt_da: 40 | self.rt_da_ref = other_params['rt_da_ref'] 41 | self.ref_imgs = [] 42 | for dir in self.rt_da_ref: 43 | self.ref_imgs += [os.path.join(dir, i) for i in os.listdir(dir)] 44 | self.depth_path = other_params['depth_img_path'] 45 | 46 | if datatype in ["haze_img"]: 47 | self.sr_path = Util.get_paths_from_images("{}/HR_hazy".format(dataroot)) 48 | 49 | self.hr_path = Util.get_paths_from_images("{}/HR".format(dataroot)) 50 | self.dataset_len = len(self.hr_path) 51 | 52 | self.dis_prefix = other_params['distanse_prefix'] if "distanse_prefix" in other_params.keys() else None 53 | if self.data_len <= 0: 54 | self.data_len = self.dataset_len 55 | else: 56 | self.data_len = min(self.data_len, self.dataset_len) 57 | 58 | elif datatype in ["RESIDE_img_syntheic"]: 59 | 60 | self.sr_path = Util.get_paths_from_images(dataroot) 61 | self.hr_path = self.sr_path 62 | self.dataset_len = len(self.hr_path) 63 | if self.data_len <= 0: 64 | self.data_len = self.dataset_len 65 | else: 66 | self.data_len = min(self.data_len, self.dataset_len) 67 | 68 | else: 69 | raise NotImplementedError( 70 | 'data_type [{:s}] is not recognized.'.format(datatype)) 71 | 72 | def __len__(self): 73 | return self.data_len 74 | 75 | def __getitem__(self, index): 76 | img_HR = None 77 | img_LR = None 78 | 79 | if self.datatype in ["RESIDE_img_syntheic"]: 80 | 81 | if self.rt_da: 82 | 83 | img_SR = rt_haze_enhancement( 84 | self.sr_path[index], 85 | os.path.join(self.depth_path, "{}.png".format(self.sr_path[index].split("/")[-1].split("_")[0])), 86 | ref_path=np.random.choice(self.ref_imgs) 87 | ) 88 | else: 89 | img_SR = Image.open(self.sr_path[index]).convert("RGB") 90 | 91 | img_SR = img_SR.resize((self.r_res, self.r_res)) 92 | 93 | # hr_path 94 | hr_path = "{}/{}.png".format( 95 | self.real_hr_path, 96 | self.sr_path[index].split("/")[-1].split("_")[0] 97 | ) 98 | img_HR = Image.open(hr_path).convert("RGB") 99 | img_HR = img_HR.resize((self.r_res, self.r_res)) 100 | 101 | if self.need_LR: 102 | img_LR = img_SR 103 | 104 | else: 105 | img_HR = Image.open(self.hr_path[index]).convert("RGB") 106 | img_SR = Image.open(self.sr_path[index]).convert("RGB") 107 | if self.need_LR: 108 | img_LR = Image.open(self.sr_path[index]).convert("RGB") 109 | 110 | if self.down_sample is not None: 111 | img_HR = self.resize(img_HR) 112 | img_SR = self.resize(img_SR) 113 | img_LR = self.resize(img_LR) 114 | 115 | if self.dis_prefix != None: 116 | img_depth = self.resize(img_depth) 117 | 118 | else: 119 | img_HR = self.resize_to_resolution(img_HR) 120 | img_SR = self.resize_to_resolution(img_SR) 121 | img_LR = self.resize_to_resolution(img_LR) 122 | 123 | if self.need_LR: 124 | [img_LR, img_SR, img_HR] = Util.transform_augment( 125 | [img_LR, img_SR, img_HR], split=self.split, min_max=(-1, 1)) 126 | 127 | return {'LR': img_LR, 'HR': img_HR, 'SR': img_SR, 'Index': index} 128 | else: 129 | [img_SR, img_HR] = Util.transform_augment( 130 | [img_SR, img_HR], split=self.split, min_max=(-1, 1)) 131 | 132 | return {'HR': img_HR, 'SR': img_SR, 'Index': index} 133 | 134 | def resize(self, 
input_image): 135 | H, W = np.shape(input_image)[:2] 136 | resize_H, resize_W = neibor_16_mul(int(H / self.down_sample)), neibor_16_mul(int(W / self.down_sample)) 137 | out_image = input_image.resize((resize_W, resize_H)) 138 | return out_image 139 | 140 | def resize_to_resolution(self, input_image): 141 | out_image = input_image.resize((self.r_res, self.r_res)) 142 | return out_image 143 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | from re import split 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, phase): 8 | '''create dataloader ''' 9 | if phase == 'train': 10 | return torch.utils.data.DataLoader( 11 | dataset, 12 | batch_size=dataset_opt['batch_size'], 13 | shuffle=dataset_opt['use_shuffle'], 14 | num_workers=dataset_opt['num_workers'], 15 | pin_memory=True) 16 | elif phase == 'val': 17 | return torch.utils.data.DataLoader( 18 | dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 19 | else: 20 | raise NotImplementedError( 21 | 'Dataloader [{:s}] is not found.'.format(phase)) 22 | 23 | 24 | def create_dataset(dataset_opt, phase): 25 | '''create dataset''' 26 | mode = dataset_opt['mode'] 27 | from data.LRHR_dataset import LRHRDataset as D 28 | dataset = D(dataroot=dataset_opt['dataroot'], 29 | datatype=dataset_opt['datatype'], 30 | l_resolution=dataset_opt['l_resolution'], 31 | r_resolution=dataset_opt['r_resolution'], 32 | split=phase, 33 | data_len=dataset_opt['data_len'], 34 | need_LR=(mode == 'LRHR'), 35 | other_params=dataset_opt 36 | ) 37 | logger = logging.getLogger('base') 38 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 39 | dataset_opt['name'])) 40 | return dataset 41 | -------------------------------------------------------------------------------- /data/__pycache__/FDA.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/FDA.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/HazeAug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/HazeAug.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/LRHR_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/LRHR_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/util.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/data/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import random 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | 15 | def get_paths_from_images(path): 16 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 17 | images = [] 18 | for dirpath, _, fnames in sorted(os.walk(path)): 19 | for fname in sorted(fnames): 20 | if is_image_file(fname): 21 | img_path = os.path.join(dirpath, fname) 22 | images.append(img_path) 23 | assert images, '{:s} has no valid image file'.format(path) 24 | return sorted(images) 25 | 26 | 27 | def augment(img_list, hflip=True, rot=True, split='val'): 28 | # horizontal flip OR rotate 29 | hflip = hflip and (split == 'train' and random.random() < 0.5) 30 | vflip = rot and (split == 'train' and random.random() < 0.5) 31 | rot90 = rot and (split == 'train' and random.random() < 0.5) 32 | 33 | def _augment(img): 34 | if hflip: 35 | img = img[:, ::-1, :] 36 | if vflip: 37 | img = img[::-1, :, :] 38 | if rot90: 39 | img = img.transpose(1, 0, 2) 40 | return img 41 | 42 | return [_augment(img) for img in img_list] 43 | 44 | 45 | def transform2numpy(img): 46 | img = np.array(img) 47 | img = img.astype(np.float32) / 255. 
48 | if img.ndim == 2: 49 | img = np.expand_dims(img, axis=2) 50 | # some images have 4 channels 51 | if img.shape[2] > 3: 52 | img = img[:, :, :3] 53 | return img 54 | 55 | 56 | def transform2tensor(img, min_max=(0, 1)): 57 | # HWC to CHW 58 | img = torch.from_numpy(np.ascontiguousarray( 59 | np.transpose(img, (2, 0, 1)))).float() 60 | # to range min_max 61 | img = img * (min_max[1] - min_max[0]) + min_max[0] 62 | return img 63 | 64 | 65 | totensor = torchvision.transforms.ToTensor() 66 | hflip = torchvision.transforms.RandomHorizontalFlip() 67 | 68 | 69 | def transform_augment(img_list, split='val', min_max=(0, 1)): 70 | imgs = [totensor(img) for img in img_list] 71 | if split == 'train': 72 | imgs = torch.stack(imgs, 0) 73 | imgs = hflip(imgs) 74 | imgs = torch.unbind(imgs, dim=0) 75 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs] 76 | return ret_img 77 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import data as Data 3 | import model as Model 4 | import argparse 5 | import logging 6 | import core.logger as Logger 7 | import core.metrics as Metrics 8 | from core.wandb_logger import WandbLogger 9 | from tensorboardX import SummaryWriter 10 | import os 11 | import numpy as np 12 | # from brisque import BRISQUE 13 | import cv2 14 | import random 15 | 16 | seed = 6666 17 | print('Random seed: {}'.format(seed)) 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | 24 | def calc_mean_rgb(img): 25 | H, W, C = np.shape(img) 26 | img = np.reshape(img, (H * W, C)) 27 | return np.mean(img, axis=0) 28 | 29 | 30 | def fix_img(img, img_ref): 31 | sr_R, sr_G, sr_B = calc_mean_rgb(img) 32 | hr_R, hr_G, hr_B = calc_mean_rgb(img_ref) 33 | 34 | R, G, B = sr_R - hr_R, sr_G - hr_G, sr_B - hr_B 35 | R = np.array(img[:, :, 0]) - R 36 | G = np.array(img[:, :, 1]) - G 37 | B = np.array(img[:, :, 2]) - B 38 | 39 | R = np.expand_dims(R, axis=-1) 40 | G = np.expand_dims(G, axis=-1) 41 | B = np.expand_dims(B, axis=-1) 42 | 43 | return np.array(np.concatenate((R, G, B), axis=-1), dtype=np.uint8) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('-c', '--config', type=str, default='config/framework_da.json', 49 | help='JSON file for configuration') 50 | parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val') 51 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 52 | parser.add_argument('-debug', '-d', action='store_true') 53 | parser.add_argument('-enable_wandb', action='store_true') 54 | parser.add_argument('-log_infer', action='store_true') 55 | parser.add_argument('-color_fix', default=False) 56 | 57 | # parse configs 58 | args = parser.parse_args() 59 | print(args) 60 | opt = Logger.parse(args) 61 | # Convert to NoneDict, which return None for missing key. 
62 | opt = Logger.dict_to_nonedict(opt) 63 | 64 | # logging 65 | torch.backends.cudnn.enabled = True 66 | torch.backends.cudnn.benchmark = True 67 | 68 | Logger.setup_logger(None, opt['path']['log'], 69 | 'train', level=logging.INFO, screen=True) 70 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) 71 | logger = logging.getLogger('base') 72 | logger.info(Logger.dict2str(opt)) 73 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) 74 | 75 | # Initialize WandbLogger 76 | if opt['enable_wandb']: 77 | wandb_logger = WandbLogger(opt) 78 | else: 79 | wandb_logger = None 80 | 81 | # dataset 82 | for phase, dataset_opt in opt['datasets'].items(): 83 | if phase == 'val': 84 | val_set = Data.create_dataset(dataset_opt, phase) 85 | val_loader = Data.create_dataloader( 86 | val_set, dataset_opt, phase) 87 | logger.info('Initial Dataset Finished') 88 | 89 | # model 90 | diffusion = Model.create_model(opt) 91 | logger.info('Initial Model Finished') 92 | 93 | diffusion.set_new_noise_schedule( 94 | opt['model']['beta_schedule']['val'], schedule_phase='val') 95 | 96 | logger.info('Begin Model Inference.') 97 | current_step = 0 98 | current_epoch = 0 99 | idx = 0 100 | avg_psnr = 0.0 101 | avg_ssim = 0.0 102 | 103 | result_path = '{}'.format(opt['path']['results']) 104 | os.makedirs(result_path, exist_ok=True) 105 | for _, val_data in enumerate(val_loader): 106 | 107 | idx += 1 108 | 109 | diffusion.feed_data(val_data) 110 | diffusion.test(continous=True) 111 | visuals = diffusion.get_current_visuals(need_LR=False) 112 | 113 | visuals['SR'] = torch.cat([visuals['SR'], visuals['HR']], dim=0) 114 | 115 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8 116 | fake_img = Metrics.tensor2img(visuals['INF']) # uint8 117 | 118 | sr_img_mode = 'grid' 119 | if sr_img_mode == 'single': 120 | # single img series 121 | sr_img = visuals['SR'] # uint8 122 | sample_num = sr_img.shape[0] 123 | for iter in range(0, sample_num): 124 | Metrics.save_img( 125 | Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter)) 126 | else: 127 | # grid img 128 | sr_img = Metrics.tensor2img(visuals['SR']) # uint8 129 | 130 | h, w, c = np.shape(hr_img) 131 | 132 | # try: 133 | # sr_img[-h-2:-2, -w-2:-2, :] = hr_img 134 | # except: 135 | # pass 136 | 137 | Metrics.save_img( 138 | sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx)) 139 | Metrics.save_img( 140 | Metrics.tensor2img(visuals['SR'][-2]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx)) 141 | 142 | Metrics.save_img( 143 | hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx)) 144 | Metrics.save_img( 145 | fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx)) 146 | 147 | sr_img = Metrics.tensor2img(visuals['SR'][-2]) 148 | if args.color_fix: 149 | # print(sr_img) 150 | # print(fake_img) 151 | # print(sr_img.shape) 152 | # print(fake_img.shape) 153 | 154 | sr_img = fix_img(sr_img, fake_img) 155 | # cv2.imwrite('{}/{}_{}_sr.png'.format(result_path, current_step, idx), sr_img) 156 | 157 | psnr = Metrics.calculate_psnr(sr_img, hr_img) 158 | ssim = Metrics.calculate_ssim(sr_img, hr_img) 159 | # brisque = BRISQUE('{}/{}_{}_sr.png'.format(result_path, current_step, idx)).score() 160 | brisque = 0 161 | 162 | avg_psnr += psnr 163 | avg_ssim += ssim 164 | print(f"psnr: {psnr}, ssim:{ssim}, save to {'{}/{}_{}_sr_process.png'.format(result_path, current_step, idx)}") 165 | 166 | if wandb_logger and opt['log_infer']: 167 | wandb_logger.log_eval_data(fake_img, 
Metrics.tensor2img(visuals['SR'][-1]), hr_img) 168 | 169 | avg_psnr = avg_psnr / idx 170 | avg_ssim = avg_ssim / idx 171 | 172 | print(f"avg_psnr: {avg_psnr}, avg_ssim:{avg_ssim}") 173 | 174 | if wandb_logger and opt['log_infer']: 175 | wandb_logger.log_eval_table(commit=True) 176 | -------------------------------------------------------------------------------- /misc/RTTS.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/misc/RTTS.jpg -------------------------------------------------------------------------------- /misc/framework-v3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/misc/framework-v3.jpg -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | logger = logging.getLogger('base') 4 | 5 | 6 | def create_model(opt): 7 | from .model import DDPM as M 8 | m = M(opt) 9 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 10 | # m = nn.DataParallel(m) 11 | 12 | return m 13 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(): 7 | def __init__(self, opt): 8 | self.opt = opt 9 | self.device = torch.device( 10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu') 11 | self.begin_step = 0 12 | self.begin_epoch = 0 13 | 14 | def feed_data(self, data): 15 | pass 16 | 17 | def optimize_parameters(self, current_step): 18 | pass 19 | 20 | def get_current_visuals(self): 21 | pass 22 | 23 | 
def get_current_losses(self): 24 | pass 25 | 26 | def print_network(self): 27 | pass 28 | 29 | def set_device(self, x): 30 | if isinstance(x, dict): 31 | for key, item in x.items(): 32 | if item is not None: 33 | x[key] = item.to(self.device) 34 | elif isinstance(x, list): 35 | for item in x: 36 | if item is not None: 37 | item = item.to(self.device) 38 | else: 39 | x = x.to(self.device) 40 | return x 41 | 42 | def get_network_description(self, network): 43 | '''Get the string and total parameters of the network''' 44 | if isinstance(network, nn.DataParallel): 45 | network = network.module 46 | s = str(network) 47 | n = sum(map(lambda x: x.numel(), network.parameters())) 48 | return s, n 49 | -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__init__.py -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/__pycache__/diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/W-Jilly/frequency-compensated-diffusion-model-pytorch/7a23e3aeb4b0cd66ad43604ce99329795686b7f4/model/dehaze_with_z_v2_modules/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import device, nn, einsum 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | from functools import partial 7 | import numpy as np 8 | from tqdm import tqdm 9 | from torchvision.transforms import Resize 10 | 11 | import copy 12 | 13 | 14 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 15 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 16 | warmup_time = int(n_timestep * warmup_frac) 17 | betas[:warmup_time] = np.linspace( 18 | linear_start, linear_end, warmup_time, dtype=np.float64) 19 | return betas 20 | 21 | 22 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 23 | if schedule == 'quad': 24 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 25 | n_timestep, dtype=np.float64) ** 2 26 | elif schedule == 'linear': 27 | betas = np.linspace(linear_start, linear_end, 28 | n_timestep, dtype=np.float64) 29 | elif 
schedule == 'warmup10': 30 | betas = _warmup_beta(linear_start, linear_end, 31 | n_timestep, 0.1) 32 | elif schedule == 'warmup50': 33 | betas = _warmup_beta(linear_start, linear_end, 34 | n_timestep, 0.5) 35 | elif schedule == 'const': 36 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 37 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 38 | betas = 1. / np.linspace(n_timestep, 39 | 1, n_timestep, dtype=np.float64) 40 | elif schedule == "cosine": 41 | timesteps = ( 42 | torch.arange(n_timestep + 1, dtype=torch.float64) / 43 | n_timestep + cosine_s 44 | ) 45 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 46 | alphas = torch.cos(alphas).pow(2) 47 | alphas = alphas / alphas[0] 48 | betas = 1 - alphas[1:] / alphas[:-1] 49 | betas = betas.clamp(max=0.999) 50 | else: 51 | raise NotImplementedError(schedule) 52 | return betas 53 | 54 | 55 | # gaussian diffusion trainer class 56 | 57 | def exists(x): 58 | return x is not None 59 | 60 | 61 | def default(val, d): 62 | if exists(val): 63 | return val 64 | return d() if isfunction(d) else d 65 | 66 | 67 | class GaussianDiffusion(nn.Module): 68 | def __init__( 69 | self, 70 | denoise_fn, 71 | image_size, 72 | channels=3, 73 | loss_type='l1', 74 | conditional=True, 75 | schedule_opt=None, 76 | start_step=1000 77 | ): 78 | super().__init__() 79 | self.channels = channels 80 | self.image_size = image_size 81 | self.denoise_fn = denoise_fn 82 | self.loss_type = loss_type 83 | self.conditional = conditional 84 | if schedule_opt is not None: 85 | pass 86 | # self.set_new_noise_schedule(schedule_opt) 87 | 88 | def set_loss(self, device): 89 | if self.loss_type == 'l1': 90 | self.loss_func = nn.L1Loss(reduction='sum').to(device) 91 | elif self.loss_type == 'l2': 92 | self.loss_func = nn.MSELoss(reduction='sum').to(device) 93 | else: 94 | raise NotImplementedError() 95 | self.optim_loss = nn.MSELoss(reduction='sum').to(device) 96 | 97 | def set_new_noise_schedule(self, schedule_opt, device): 98 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 99 | 100 | betas = make_beta_schedule( 101 | schedule=schedule_opt['schedule'], 102 | n_timestep=schedule_opt['n_timestep'], 103 | linear_start=schedule_opt['linear_start'], 104 | linear_end=schedule_opt['linear_end']) 105 | betas = betas.detach().cpu().numpy() if isinstance( 106 | betas, torch.Tensor) else betas 107 | alphas = 1. - betas 108 | alphas_cumprod = np.cumprod(alphas, axis=0) 109 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 110 | self.sqrt_alphas_cumprod_prev = np.sqrt( 111 | np.append(1., alphas_cumprod)) 112 | 113 | timesteps, = betas.shape 114 | self.num_timesteps = int(timesteps) 115 | self.register_buffer('betas', to_torch(betas)) 116 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 117 | self.register_buffer('alphas_cumprod_prev', 118 | to_torch(alphas_cumprod_prev)) 119 | 120 | # calculations for diffusion q(x_t | x_{t-1}) and others 121 | self.register_buffer('sqrt_alphas_cumprod', 122 | to_torch(np.sqrt(alphas_cumprod))) 123 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 124 | to_torch(np.sqrt(1. - alphas_cumprod))) 125 | self.register_buffer('log_one_minus_alphas_cumprod', 126 | to_torch(np.log(1. - alphas_cumprod))) 127 | self.register_buffer('sqrt_recip_alphas_cumprod', 128 | to_torch(np.sqrt(1. / alphas_cumprod))) 129 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 130 | to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 131 | 132 | # calculations for posterior q(x_{t-1} | x_t, x_0) 133 | posterior_variance = betas * \ 134 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 135 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 136 | self.register_buffer('posterior_variance', 137 | to_torch(posterior_variance)) 138 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 139 | self.register_buffer('posterior_log_variance_clipped', to_torch( 140 | np.log(np.maximum(posterior_variance, 1e-20)))) 141 | self.register_buffer('posterior_mean_coef1', to_torch( 142 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 143 | self.register_buffer('posterior_mean_coef2', to_torch( 144 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 145 | 146 | def predict_start_from_noise(self, x_t, t, noise): 147 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \ 148 | self.sqrt_recipm1_alphas_cumprod[t] * noise 149 | 150 | def q_posterior(self, x_start, x_t, t): 151 | posterior_mean = self.posterior_mean_coef1[t] * \ 152 | x_start + self.posterior_mean_coef2[t] * x_t 153 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 154 | return posterior_mean, posterior_log_variance_clipped 155 | 156 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): 157 | batch_size = x.shape[0] 158 | noise_level = torch.FloatTensor( 159 | [self.sqrt_alphas_cumprod_prev[t + 1]]).repeat(batch_size, 1).to(x.device) 160 | 161 | if condition_x is not None: 162 | x_recon = self.predict_start_from_noise( 163 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)) 164 | else: 165 | x_recon = self.predict_start_from_noise( 166 | x, t=t, noise=self.denoise_fn(x, noise_level)) 167 | 168 | if clip_denoised: 169 | x_recon.clamp_(-1., 1.) 
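        # The two helpers above implement the standard DDPM reverse step:
        # predict_start_from_noise recovers x0_hat = sqrt(1/a_bar_t) * x_t - sqrt(1/a_bar_t - 1) * eps_hat,
        # and q_posterior then forms the posterior mean
        # posterior_mean_coef1 * x0_hat + posterior_mean_coef2 * x_t
        # from the buffers registered in set_new_noise_schedule.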
170 | 
171 |         model_mean, posterior_log_variance = self.q_posterior(
172 |             x_start=x_recon, x_t=x, t=t)
173 |         return model_mean, posterior_log_variance
174 | 
175 |     @torch.no_grad()
176 |     def p_sample(self, x, t, clip_denoised=True, condition_x=None):
177 |         model_mean, model_log_variance = self.p_mean_variance(
178 |             x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
179 |         noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
180 |         return model_mean + noise * (0.5 * model_log_variance).exp()
181 | 
182 |     # calc ddim alpha
183 |     def compute_alpha(self, beta, t):
184 |         beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
185 |         a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
186 |         return a
187 | 
188 |     def slerp(self, z1, z2, alpha):  # spherical interpolation between two noise maps
189 |         theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
190 |         return (
191 |             torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
192 |             + torch.sin(alpha * theta) / torch.sin(theta) * z2
193 |         )
194 | 
195 |     def neibor_16_mul(self, num, size=16):  # round num to the nearest multiple of `size`
196 |         a = num // size
197 |         b = num % size
198 |         if b >= 0.5 * size:
199 |             return size * (a + 1)
200 |         else:
201 |             return size * a
202 | 
203 |     @torch.no_grad()
204 |     def p_sample_loop(self, x_in, continous=False):
205 |         device = self.betas.device
206 | 
207 |         condition_ddim = True
208 |         if condition_ddim:
209 |             timesteps = 20
210 |             ddim_eta = 1
211 |             alpha = 0.5
212 | 
213 |             sample_inter = (1 | (timesteps // 10))
214 | 
215 |             x = copy.deepcopy(x_in)
216 |             batch_size, C, H, W = x.shape
217 | 
218 |             ret_img = x_in
219 | 
220 |             skip = self.num_timesteps // timesteps
221 |             seq = range(0, self.num_timesteps, skip)
222 |             seq_next = [-1] + list(seq[:-1])
223 | 
224 |             # initialize the noise
225 |             shape = x.shape
226 |             z1 = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
227 |             z2 = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
228 |             x = self.slerp(z1, z2, alpha)
229 | 
230 |             # reshape strategy
231 |             reshape = False
232 |             reshape_stage = 3
233 |             h_gap, w_gap = H // reshape_stage, W // reshape_stage
234 |             hs = [self.neibor_16_mul(h) for h in range(h_gap, H, h_gap)] + [H]
235 |             ws = [self.neibor_16_mul(w) for w in range(w_gap, W, w_gap)] + [W]
236 | 
237 |             len_seq = len(seq)
238 |             for idx, (i, j) in tqdm(enumerate(zip(reversed(seq), reversed(seq_next))), desc='sampling loop time step',
239 |                                     total=len_seq):
240 |                 t = (torch.ones(batch_size) * i).to(x.device)
241 |                 next_t = (torch.ones(batch_size) * j).to(x.device)
242 | 
243 |                 at = self.compute_alpha(self.betas, t.long())
244 |                 at_next = self.compute_alpha(self.betas, next_t.long())
245 | 
246 |                 noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[i + 1]]).repeat(batch_size, 1).to(
247 |                     x.device)
248 | 
249 |                 if reshape:
250 |                     cur_idx = int(idx / int(len_seq / reshape_stage))
251 |                     cur_idx = cur_idx if cur_idx < reshape_stage else reshape_stage - 1
252 | 
253 |                     h, w = hs[cur_idx], ws[cur_idx]
254 |                     im_resize = Resize([h, w])
255 | 
256 |                     x_in_tmp = im_resize(x_in)
257 |                     x = im_resize(x)
258 | 
259 |                     et = self.denoise_fn(torch.cat([x_in_tmp, x], dim=1), noise_level)
260 |                 else:
261 | 
262 |                     et = self.denoise_fn(torch.cat([x_in, x], dim=1), noise_level)
263 | 
264 |                 x0_t = (x - et * (1 - at).sqrt()) / at.sqrt()  # DDIM: predicted x0 from the noise estimate
265 | 
266 |                 c1 = (
267 |                     ddim_eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
268 |                 )
269 |                 c2 = ((1 - at_next) - c1 ** 2).sqrt()
270 |                 xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
271 | 
272 |                 x = xt_next
273 | 
274 |                 if i % sample_inter == 0 or (i == len(seq) - 1):
275 | 
276 |                     if x.shape[-1] != W:
277 |                         im_resize = Resize([H, W])
278 |                         x_ = im_resize(x)
279 |                     else:
280 |                         x_ = x
281 | 
282 |                     ret_img = torch.cat([ret_img, x_], dim=0)
283 | 
284 | 
285 |         else:
286 |             sample_inter = (1 | (self.num_timesteps // 10))
287 |             if not self.conditional:
288 |                 shape = x_in.shape
289 |                 img = torch.randn(shape, device=device)
290 |                 ret_img = img
291 |                 for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step',
292 |                               total=self.num_timesteps):
293 |                     img = self.p_sample(img, i)
294 |                     if i % sample_inter == 0:
295 |                         ret_img = torch.cat([ret_img, img], dim=0)
296 |             else:
297 | 
298 |                 # inversion
299 |                 from data_analyse.dcp import Defog  # NOTE: data_analyse/ is not shipped in this repo; this branch is unreachable while condition_ddim is hard-coded to True
300 | 
301 |                 x_in_numpy = x_in[0].permute(1, 2, 0).cpu().numpy()
302 |                 x_in_numpy = (x_in_numpy - np.min(x_in_numpy)) / (np.max(x_in_numpy) - np.min(x_in_numpy))
303 | 
304 |                 Mask_img, A = Defog(x_in_numpy, r=81, eps=0.001, w=0.95, maxV1=0.80)
305 |                 Mask_img = torch.from_numpy(Mask_img).unsqueeze(dim=0).unsqueeze(dim=1).expand_as(x_in).to(x_in.device)
306 | 
307 |                 mean_Mask_img = torch.mean(Mask_img)
308 |                 Mask_img = Mask_img - mean_Mask_img
309 |                 print(torch.max(Mask_img), torch.min(Mask_img))
310 | 
311 |                 ret_img = x_in
312 |                 x = torch.cat([ret_img], dim=1)
313 | 
314 |                 shape = x.shape
315 |                 img = torch.randn([shape[0], 3, shape[2], shape[3]], device=device)
316 | 
317 |                 for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step',
318 |                               total=self.num_timesteps):
319 | 
320 |                     img = self.p_sample(img, i, condition_x=x)
321 | 
322 |                     if i % sample_inter == 0:
323 |                         ret_img = torch.cat([ret_img, img], dim=0)
324 | 
325 |         if continous:
326 |             return ret_img
327 |         else:
328 |             return ret_img[-1]
329 | 
330 |     @torch.no_grad()
331 |     def sample(self, batch_size=1, continous=False):
332 |         image_size = self.image_size
333 |         channels = self.channels
334 |         return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
335 | 
336 |     @torch.no_grad()
337 |     def super_resolution(self, x_in, continous=False):
338 |         return self.p_sample_loop(x_in, continous)
339 | 
340 |     def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
341 |         noise = default(noise, lambda: torch.randn_like(x_start))
342 | 
343 |         # random gamma
344 |         return (
345 |             continuous_sqrt_alpha_cumprod * x_start +
346 |             (1 - continuous_sqrt_alpha_cumprod ** 2).sqrt() * noise
347 |         )
348 | 
349 |     def p_losses(self, x_in, noise=None):
350 |         x_start = x_in['HR']
351 | 
352 |         x_sr = x_in['SR']
353 | 
354 |         [b, c, h, w] = x_start.shape
355 |         t = np.random.randint(1, self.num_timesteps + 1)
356 |         continuous_sqrt_alpha_cumprod = torch.FloatTensor(
357 |             np.random.uniform(
358 |                 self.sqrt_alphas_cumprod_prev[t - 1],
359 |                 self.sqrt_alphas_cumprod_prev[t],
360 |                 size=b
361 |             )
362 |         ).to(x_start.device)
363 |         continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(
364 |             b, -1)
365 | 
366 |         noise = default(noise, lambda: torch.randn_like(x_start))
367 |         x_noisy = self.q_sample(
368 |             x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
369 | 
370 |         if not self.conditional:
371 |             x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
372 |             loss = self.loss_func(noise, x_recon)
373 | 
374 |         else:
375 |             x_recon = self.denoise_fn(
376 |                 torch.cat([x_sr, x_noisy], dim=1), continuous_sqrt_alpha_cumprod)
377 |             loss = self.loss_func(noise, x_recon)
378 | 
379 |         return loss
380 | 
381 |     def calc_RGB(self, tensor):
382 |         b, c = tensor.shape[:2]
383 |         RGB_mean 
= torch.mean(tensor.view(b, c, -1), -1) 384 | return RGB_mean 385 | 386 | def forward(self, x, *args, **kwargs): 387 | return self.p_losses(x, *args, **kwargs) 388 | -------------------------------------------------------------------------------- /model/dehaze_with_z_v2_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | from kornia.filters import gaussian_blur2d 7 | 8 | 9 | def exists(x): 10 | return x is not None 11 | 12 | 13 | def default(val, d): 14 | if exists(val): 15 | return val 16 | return d() if isfunction(d) else d 17 | 18 | 19 | class PositionalEncoding(nn.Module): 20 | def __init__(self, dim): 21 | super().__init__() 22 | self.dim = dim 23 | 24 | def forward(self, noise_level): 25 | count = self.dim // 2 26 | step = torch.arange(count, dtype=noise_level.dtype, 27 | device=noise_level.device) / count 28 | encoding = noise_level.unsqueeze( 29 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 30 | encoding = torch.cat( 31 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 32 | return encoding 33 | 34 | 35 | class FeatureWiseAffine(nn.Module): 36 | 37 | def __init__(self, in_channels, out_channels, use_affine_level=False): 38 | super(FeatureWiseAffine, self).__init__() 39 | self.use_affine_level = use_affine_level 40 | self.noise_func = nn.Sequential( 41 | nn.Linear(in_channels, out_channels * (1 + self.use_affine_level)) 42 | ) 43 | 44 | def forward(self, x, noise_embed): 45 | batch = x.shape[0] 46 | if self.use_affine_level: 47 | gamma, beta = self.noise_func(noise_embed).view( 48 | batch, -1, 1, 1).chunk(2, dim=1) 49 | x = (1 + gamma) * x + beta 50 | else: 51 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 52 | return x 53 | 54 | 55 | class Swish(nn.Module): 56 | def forward(self, x): 57 | return x * torch.sigmoid(x) 58 | 59 | 60 | class Upsample(nn.Module): 61 | def __init__(self, dim): 62 | super().__init__() 63 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 64 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 65 | 66 | def forward(self, x): 67 | return self.conv(self.up(x)) 68 | 69 | 70 | class Downsample(nn.Module): 71 | def __init__(self, dim): 72 | super().__init__() 73 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 74 | 75 | def forward(self, x): 76 | return self.conv(x) 77 | 78 | 79 | class Block(nn.Module): 80 | def __init__(self, dim, dim_out, groups=32, dropout=0): 81 | super().__init__() 82 | self.block = nn.Sequential( 83 | nn.GroupNorm(groups, dim), 84 | Swish(), 85 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 86 | nn.Conv2d(dim, dim_out, 3, padding=1) 87 | ) 88 | 89 | def forward(self, x): 90 | return self.block(x) 91 | 92 | 93 | class ResnetBlock(nn.Module): 94 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): 95 | super().__init__() 96 | self.noise_func = FeatureWiseAffine( 97 | noise_level_emb_dim, dim_out, use_affine_level) 98 | 99 | self.block1 = Block(dim, dim_out, groups=norm_groups) 100 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 101 | self.res_conv = nn.Conv2d( 102 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 103 | 104 | def forward(self, x, time_emb): 105 | b, c, h, w = x.shape 106 | h = self.block1(x) 107 | h = self.noise_func(h, time_emb) 108 | h = self.block2(h) 109 | return h + self.res_conv(x) 110 | 111 | 112 | class 
SelfAttention(nn.Module):
113 |     def __init__(self, in_channel, n_head=1, norm_groups=32):
114 |         super().__init__()
115 | 
116 |         self.n_head = n_head
117 | 
118 |         self.norm = nn.GroupNorm(norm_groups, in_channel)
119 |         self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
120 |         self.out = nn.Conv2d(in_channel, in_channel, 1)
121 | 
122 |     def forward(self, input):
123 |         batch, channel, height, width = input.shape
124 |         n_head = self.n_head
125 |         head_dim = channel // n_head
126 | 
127 |         norm = self.norm(input)
128 |         qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
129 |         query, key, value = qkv.chunk(3, dim=2)  # bhdyx
130 | 
131 |         attn = torch.einsum(
132 |             "bnchw, bncyx -> bnhwyx", query, key
133 |         ).contiguous() / math.sqrt(channel)
134 |         attn = attn.view(batch, n_head, height, width, -1)
135 |         attn = torch.softmax(attn, -1)
136 |         attn = attn.view(batch, n_head, height, width, height, width)
137 | 
138 |         out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
139 |         out = self.out(out.view(batch, channel, height, width))
140 | 
141 |         return out + input
142 | 
143 | 
144 | class ResnetBlocWithAttn(nn.Module):
145 |     def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
146 |         super().__init__()
147 |         self.with_attn = with_attn
148 |         self.res_block = ResnetBlock(
149 |             dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
150 |         if with_attn:
151 |             self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
152 | 
153 |     def forward(self, x, time_emb):
154 |         x = self.res_block(x, time_emb)
155 |         if self.with_attn:
156 |             x = self.attn(x)
157 |         return x
158 | 
159 | 
160 | class FCB(nn.Module):  # frequency compensation: re-weights band-pass components of a feature map
161 |     def __init__(self, channel, kernel_size=3):
162 |         super().__init__()
163 |         self.ks = kernel_size
164 |         self.sigma_rate = 1
165 | 
166 |         params = torch.ones((4, 1), requires_grad=True)  # learned weights for the 3 bands + identity
167 |         self.params = nn.Parameter(params)
168 | 
169 |     def forward(self, x):
170 |         # difference-of-Gaussians decomposition: R1/R2/R3 are progressively lower-frequency bands
171 |         x1 = gaussian_blur2d(x, (self.ks, self.ks), (1 * self.sigma_rate, 1 * self.sigma_rate))
172 |         R1 = x - x1
173 | 
174 |         x2 = gaussian_blur2d(x, (self.ks * 2 - 1, self.ks * 2 - 1), (2 * self.sigma_rate, 2 * self.sigma_rate))
175 |         x3 = gaussian_blur2d(x, (self.ks * 4 - 1, self.ks * 4 - 1), (4 * self.sigma_rate, 4 * self.sigma_rate))
176 |         R2 = x1 - x2
177 |         R3 = x2 - x3
178 | 
179 |         R1 = R1.unsqueeze(dim=-1)
180 |         R2 = R2.unsqueeze(dim=-1)
181 |         R3 = R3.unsqueeze(dim=-1)
182 |         R_cat = torch.cat([R1, R2, R3, x.unsqueeze(dim=-1)], dim=-1)
183 | 
184 |         sum_ = torch.matmul(R_cat, self.params).squeeze(dim=-1)  # weighted recombination of the bands
185 | 
186 |         return sum_
187 | 
188 | 
189 | class UNet(nn.Module):
190 |     def __init__(
191 |         self,
192 |         in_channel=6,
193 |         out_channel=3,
194 |         inner_channel=32,
195 |         norm_groups=32,
196 |         channel_mults=(1, 2, 4, 8, 8),
197 |         attn_res=[8],
198 |         res_blocks=3,
199 |         dropout=0,
200 |         with_noise_level_emb=True,
201 |         image_size=128,
202 |         fcb=True
203 |     ):
204 |         super().__init__()
205 | 
206 |         self.fcb = fcb
207 | 
208 |         if with_noise_level_emb:
209 |             noise_level_channel = inner_channel
210 |             self.noise_level_mlp = nn.Sequential(
211 |                 PositionalEncoding(inner_channel),
212 |                 nn.Linear(inner_channel, inner_channel * 4),
213 |                 Swish(),
214 |                 nn.Linear(inner_channel * 4, inner_channel)
215 |             )
216 |         else:
217 |             noise_level_channel = None
218 |             self.noise_level_mlp = None
219 | 
220 |         num_mults = len(channel_mults)
221 |         pre_channel = inner_channel
222 |         feat_channels = [pre_channel]
223 |         now_res = image_size
224 |         downs = 
[nn.Conv2d(in_channel, inner_channel, 225 | kernel_size=3, padding=1)] 226 | for ind in range(num_mults): 227 | is_last = (ind == num_mults - 1) 228 | use_attn = (now_res in attn_res) 229 | channel_mult = inner_channel * channel_mults[ind] 230 | for _ in range(0, res_blocks): 231 | downs.append(ResnetBlocWithAttn( 232 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 233 | dropout=dropout, with_attn=use_attn)) 234 | feat_channels.append(channel_mult) 235 | pre_channel = channel_mult 236 | if not is_last: 237 | downs.append(Downsample(pre_channel)) 238 | feat_channels.append(pre_channel) 239 | now_res = now_res // 2 240 | self.downs = nn.ModuleList(downs) 241 | 242 | self.mid = nn.ModuleList([ 243 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 244 | norm_groups=norm_groups, 245 | dropout=dropout, with_attn=True), 246 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 247 | norm_groups=norm_groups, 248 | dropout=dropout, with_attn=False) 249 | ]) 250 | 251 | ups = [] 252 | fbs = [] 253 | for ind in reversed(range(num_mults)): 254 | is_last = (ind < 1) 255 | use_attn = (now_res in attn_res) 256 | channel_mult = inner_channel * channel_mults[ind] 257 | for _ in range(0, res_blocks + 1): 258 | ups.append(ResnetBlocWithAttn( 259 | pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, 260 | norm_groups=norm_groups, 261 | dropout=dropout, with_attn=use_attn)) 262 | pre_channel = channel_mult 263 | tmp = FCB(pre_channel) if self.fcb else pre_channel 264 | fbs.append(tmp) 265 | if not is_last: 266 | ups.append(Upsample(pre_channel)) 267 | tmp = FCB(pre_channel) if self.fcb else pre_channel 268 | fbs.append(tmp) 269 | now_res = now_res * 2 270 | 271 | self.ups = nn.ModuleList(ups) 272 | self.fbs = nn.ModuleList(fbs) 273 | 274 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 275 | 276 | def forward(self, x, time): 277 | t = self.noise_level_mlp(time) if exists( 278 | self.noise_level_mlp) else None 279 | 280 | feats = [] 281 | for layer in self.downs: 282 | if isinstance(layer, ResnetBlocWithAttn): 283 | x = layer(x, t) 284 | else: 285 | x = layer(x) 286 | feats.append(x) 287 | 288 | for layer in self.mid: 289 | if isinstance(layer, ResnetBlocWithAttn): 290 | x = layer(x, t) 291 | else: 292 | x = layer(x) 293 | 294 | for layer, fb in zip(self.ups, self.fbs): 295 | if isinstance(layer, ResnetBlocWithAttn): 296 | tmp = feats.pop() 297 | if self.fcb: 298 | tmp = fb(tmp) 299 | x = layer(torch.cat((x, tmp), dim=1), t) 300 | else: 301 | x = layer(x) 302 | 303 | tmp = self.final_conv(x) 304 | 305 | return tmp 306 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | import model.networks as networks 8 | from .base_model import BaseModel 9 | 10 | logger = logging.getLogger('base') 11 | 12 | 13 | class EMA(): 14 | def __init__(self, model, decay): 15 | self.model = model 16 | self.decay = decay 17 | self.shadow = {} 18 | self.backup = {} 19 | 20 | def register(self): 21 | for name, param in self.model.named_parameters(): 22 | if param.requires_grad: 23 | self.shadow[name] = param.data.clone() 24 | 25 | def update(self): 26 | for name, param in 
self.model.named_parameters():
27 |             if param.requires_grad:
28 |                 assert name in self.shadow
29 |                 new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
30 |                 self.shadow[name] = new_average.clone()
31 | 
32 |     def apply_shadow(self):
33 |         for name, param in self.model.named_parameters():
34 |             if param.requires_grad:
35 |                 assert name in self.shadow
36 |                 self.backup[name] = param.data
37 |                 param.data = self.shadow[name]
38 | 
39 |     def restore(self):
40 |         for name, param in self.model.named_parameters():
41 |             if param.requires_grad:
42 |                 assert name in self.backup
43 |                 param.data = self.backup[name]
44 |         self.backup = {}
45 | 
46 | 
47 | class DDPM(BaseModel):
48 |     def __init__(self, opt):
49 |         super(DDPM, self).__init__(opt)
50 |         # define network and load pretrained models
51 |         self.netG = self.set_device(networks.define_G(opt))
52 |         self.schedule_phase = None
53 | 
54 |         # ema
55 |         self.use_ema = opt['train']['ema_scheduler']['used'] if "used" in opt['train'][
56 |             'ema_scheduler'].keys() else False
57 |         if self.use_ema:
58 |             self.decay = opt['train']['ema_scheduler']['ema_decay']
59 |             self.ema_start = opt['train']['ema_scheduler']['step_start_ema']
60 |             self.shadow = {}
61 |             self.backup = {}
62 |             self.register()
63 |             print("using EMA for training ...")
64 | 
65 |         # set loss and load resume state
66 |         self.set_loss()
67 |         self.set_new_noise_schedule(
68 |             opt['model']['beta_schedule']['train'], schedule_phase='train')
69 |         if self.opt['phase'] == 'train':
70 |             self.netG.train()
71 |             # find the parameters to optimize
72 |             if opt['model']['finetune_norm']:
73 |                 optim_params = []
74 |                 for k, v in self.netG.named_parameters():
75 |                     v.requires_grad = False
76 |                     if k.find('transformer') >= 0:
77 |                         v.requires_grad = True
78 |                         v.data.zero_()
79 |                         optim_params.append(v)
80 |                         logger.info(
81 |                             'Params [{:s}] initialized to 0 and will optimize.'.format(k))
82 |             else:
83 |                 optim_params = list(self.netG.parameters())
84 | 
85 |             self.optG = torch.optim.Adam(
86 |                 optim_params, lr=opt['train']["optimizer"]["lr"])
87 |             self.log_dict = OrderedDict()
88 |         self.load_network()
89 |         self.print_network()
90 | 
91 |     def feed_data(self, data):
92 |         self.data = self.set_device(data)
93 | 
94 |     def optimize_parameters(self, current_step):
95 |         self.optG.zero_grad()
96 | 
97 |         l_pix = self.netG(self.data)
98 |         # need to average in multi-gpu
99 |         b, c, h, w = self.data['HR'].shape
100 |         l_pix = l_pix.sum() / int(b * c * h * w)
101 |         l_pix.backward(retain_graph=True)
102 |         self.optG.step()
103 | 
104 |         if self.use_ema and current_step > self.ema_start:
105 |             self.update()
106 | 
107 |         # set log
108 |         self.log_dict['l_pix'] = l_pix.item()
109 | 
110 |     def test(self, continous=False):
111 |         if self.use_ema:
112 |             print("using EMA weights for testing ...")
113 |             self.apply_shadow()
114 | 
115 |         self.netG.eval()
116 |         with torch.no_grad():
117 |             if isinstance(self.netG, nn.DataParallel):
118 |                 self.SR = self.netG.module.super_resolution(self.data['SR'], continous)
119 | 
120 |             else:
121 |                 self.SR = self.netG.super_resolution(self.data['SR'], continous)
122 | 
123 |         if self.use_ema:
124 |             self.restore()
125 | 
126 |         self.netG.train()
127 | 
128 |     def sample(self, batch_size=1, continous=False):
129 |         self.netG.eval()
130 |         with torch.no_grad():
131 |             if isinstance(self.netG, nn.DataParallel):
132 |                 self.SR = self.netG.module.sample(batch_size, continous)
133 |             else:
134 |                 self.SR = self.netG.sample(batch_size, continous)
135 |         self.netG.train()
136 | 
137 |     def set_loss(self):
138 |         if isinstance(self.netG, nn.DataParallel):
139 | 
self.netG.module.set_loss(self.device) 140 | else: 141 | self.netG.set_loss(self.device) 142 | 143 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'): 144 | if self.schedule_phase is None or self.schedule_phase != schedule_phase: 145 | self.schedule_phase = schedule_phase 146 | if isinstance(self.netG, nn.DataParallel): 147 | self.netG.module.set_new_noise_schedule( 148 | schedule_opt, self.device) 149 | else: 150 | self.netG.set_new_noise_schedule(schedule_opt, self.device) 151 | 152 | def get_current_log(self): 153 | return self.log_dict 154 | 155 | def get_current_visuals(self, need_LR=True, sample=False): 156 | out_dict = OrderedDict() 157 | if sample: 158 | out_dict['SAM'] = self.SR.detach().float().cpu() 159 | else: 160 | out_dict['SR'] = self.SR.detach().float().cpu() 161 | out_dict['INF'] = self.data['SR'].detach().float().cpu() 162 | out_dict['HR'] = self.data['HR'].detach().float().cpu() 163 | if need_LR and 'LR' in self.data: 164 | out_dict['LR'] = self.data['LR'].detach().float().cpu() 165 | else: 166 | out_dict['LR'] = out_dict['INF'] 167 | return out_dict 168 | 169 | def print_network(self): 170 | s, n = self.get_network_description(self.netG) 171 | if isinstance(self.netG, nn.DataParallel): 172 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 173 | self.netG.module.__class__.__name__) 174 | else: 175 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 176 | 177 | logger.info( 178 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 179 | # logger.info(s) 180 | 181 | def save_network(self, epoch, iter_step): 182 | gen_path = os.path.join( 183 | self.opt['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch)) 184 | opt_path = os.path.join( 185 | self.opt['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch)) 186 | # gen 187 | network = self.netG 188 | if isinstance(self.netG, nn.DataParallel): 189 | network = network.module 190 | state_dict = network.state_dict() 191 | for key, param in state_dict.items(): 192 | state_dict[key] = param.cpu() 193 | torch.save(state_dict, gen_path) 194 | # opt 195 | opt_state = {'epoch': epoch, 'iter': iter_step, 196 | 'scheduler': None, 'optimizer': None} 197 | opt_state['optimizer'] = self.optG.state_dict() 198 | torch.save(opt_state, opt_path) 199 | 200 | logger.info( 201 | 'Saved model in [{:s}] ...'.format(gen_path)) 202 | 203 | def load_network(self): 204 | load_path = self.opt['path']['resume_state'] 205 | if load_path is not None: 206 | logger.info( 207 | 'Loading pretrained model for G [{:s}] ...'.format(load_path)) 208 | gen_path = '{}_gen.pth'.format(load_path) 209 | opt_path = '{}_opt.pth'.format(load_path) 210 | # gen 211 | network = self.netG 212 | if isinstance(self.netG, nn.DataParallel): 213 | network = network.module 214 | network.load_state_dict(torch.load( 215 | gen_path), strict=False) 216 | if self.opt['phase'] == 'train': 217 | try: 218 | # optimizer 219 | opt = torch.load(opt_path) 220 | # self.optG.load_state_dict(opt['optimizer']) 221 | self.begin_step = opt['iter'] 222 | self.begin_epoch = opt['epoch'] 223 | except: 224 | pass 225 | 226 | def register(self): 227 | for name, param in self.netG.named_parameters(): 228 | if param.requires_grad: 229 | self.shadow[name] = param.data.clone() 230 | 231 | def update(self): 232 | for name, param in self.netG.named_parameters(): 233 | if param.requires_grad: 234 | assert name in self.shadow 235 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 236 | 
self.shadow[name] = new_average.clone() 237 | 238 | def apply_shadow(self): 239 | for name, param in self.netG.named_parameters(): 240 | if param.requires_grad: 241 | assert name in self.shadow 242 | self.backup[name] = param.data 243 | param.data = self.shadow[name] 244 | 245 | def restore(self): 246 | for name, param in self.netG.named_parameters(): 247 | if param.requires_grad: 248 | assert name in self.backup 249 | param.data = self.backup[name] 250 | self.backup = {} 251 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn import modules 7 | 8 | logger = logging.getLogger('base') 9 | 10 | 11 | #################### 12 | # initialize 13 | #################### 14 | 15 | 16 | def weights_init_normal(m, std=0.02): 17 | classname = m.__class__.__name__ 18 | if classname.find('Conv') != -1: 19 | init.normal_(m.weight.data, 0.0, std) 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif classname.find('Linear') != -1: 23 | init.normal_(m.weight.data, 0.0, std) 24 | if m.bias is not None: 25 | m.bias.data.zero_() 26 | elif classname.find('BatchNorm2d') != -1: 27 | init.normal_(m.weight.data, 1.0, std) # BN also uses norm 28 | init.constant_(m.bias.data, 0.0) 29 | 30 | 31 | def weights_init_kaiming(m, scale=1): 32 | classname = m.__class__.__name__ 33 | if classname.find('Conv2d') != -1: 34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 35 | m.weight.data *= scale 36 | if m.bias is not None: 37 | m.bias.data.zero_() 38 | elif classname.find('Linear') != -1: 39 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 40 | m.weight.data *= scale 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | elif classname.find('BatchNorm2d') != -1: 44 | init.constant_(m.weight.data, 1.0) 45 | init.constant_(m.bias.data, 0.0) 46 | 47 | 48 | def weights_init_orthogonal(m): 49 | classname = m.__class__.__name__ 50 | if classname.find('Conv') != -1: 51 | try: 52 | init.orthogonal_(m.weight.data, gain=1) 53 | except: 54 | pass 55 | try: 56 | if m.bias is not None: 57 | m.bias.data.zero_() 58 | except: 59 | pass 60 | elif classname.find('Linear') != -1: 61 | init.orthogonal_(m.weight.data, gain=1) 62 | try: 63 | if m.bias is not None: 64 | m.bias.data.zero_() 65 | except: 66 | pass 67 | elif classname.find('BatchNorm2d') != -1: 68 | init.constant_(m.weight.data, 1.0) 69 | init.constant_(m.bias.data, 0.0) 70 | 71 | 72 | def init_weights(net, init_type='kaiming', scale=1, std=0.02): 73 | # scale for 'kaiming', std for 'normal'. 
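    # Note: define_G below initializes with init_weights(netG, init_type='orthogonal');
    # the 'kaiming' variant with scale=0.1 is kept only as a commented-out alternative.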
74 | logger.info('Initialization method [{:s}]'.format(init_type)) 75 | if init_type == 'normal': 76 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 77 | net.apply(weights_init_normal_) 78 | elif init_type == 'kaiming': 79 | weights_init_kaiming_ = functools.partial( 80 | weights_init_kaiming, scale=scale) 81 | net.apply(weights_init_kaiming_) 82 | elif init_type == 'orthogonal': 83 | net.apply(weights_init_orthogonal) 84 | else: 85 | raise NotImplementedError( 86 | 'initialization method [{:s}] not implemented'.format(init_type)) 87 | 88 | 89 | #################### 90 | # define network 91 | #################### 92 | 93 | 94 | # Generator 95 | def define_G(opt): 96 | model_opt = opt['model'] 97 | if model_opt['which_model_G'] == 'ddpm': 98 | from .ddpm_modules import diffusion, unet 99 | elif model_opt['which_model_G'] == 'sr3': 100 | from .sr3_modules import diffusion, unet 101 | elif model_opt['which_model_G'] == 'asm': 102 | from .asm_modules import diffusion, unet 103 | elif model_opt['which_model_G'] == 'MSBDN': 104 | from .MSBDN import diffusion, unet 105 | elif model_opt['which_model_G'] == 'dehazy': 106 | from .dehazy_modules import diffusion 107 | from .dehazy_modules import vspga as unet 108 | elif model_opt['which_model_G'] == 'dehaze_with_z': 109 | from .dehaze_with_z_modules import diffusion, unet 110 | elif model_opt['which_model_G'] == 'dehaze_with_z_gan': 111 | from .dehaze_with_z_gan_modules import diffusion, unet 112 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1': 113 | from .dehaze_with_z_v1_modules import diffusion, unet 114 | elif model_opt['which_model_G'] == 'dehaze_with_z_bagging': 115 | from .dehaze_with_z_bagging_modules import diffusion, unet 116 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1_ssim': 117 | from .dehaze_with_z_v1_ssim_modules import diffusion, unet 118 | elif model_opt['which_model_G'] == 'dehaze_with_z_v1_depth_lap_ssim': 119 | from .dehaze_with_z_v1_depth_lap_ssim_modules import diffusion, unet 120 | elif model_opt['which_model_G'] == 'dehaze_with_z_v2': 121 | from .dehaze_with_z_v2_modules import diffusion, unet 122 | elif model_opt['which_model_G'] == 'dehaze_with_z_v4_CA': 123 | from .dehaze_with_z_v4_CA_modules import diffusion, unet 124 | elif model_opt['which_model_G'] == 'dehaze_filter_hsv': 125 | from .dehaze_filter_hsv_modules import diffusion, unet 126 | 127 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None: 128 | model_opt['unet']['norm_groups'] = 32 129 | model = unet.UNet( 130 | in_channel=model_opt['unet']['in_channel'], 131 | out_channel=model_opt['unet']['out_channel'], 132 | norm_groups=model_opt['unet']['norm_groups'], 133 | inner_channel=model_opt['unet']['inner_channel'], 134 | channel_mults=model_opt['unet']['channel_multiplier'], 135 | attn_res=model_opt['unet']['attn_res'], 136 | res_blocks=model_opt['unet']['res_blocks'], 137 | dropout=model_opt['unet']['dropout'], 138 | image_size=model_opt['diffusion']['image_size'], 139 | fcb = model_opt['FCB'] 140 | 141 | ) 142 | netG = diffusion.GaussianDiffusion( 143 | model, 144 | image_size=model_opt['diffusion']['image_size'], 145 | channels=model_opt['diffusion']['channels'], 146 | loss_type='l1', # L1 or L2 147 | conditional=model_opt['diffusion']['conditional'], 148 | schedule_opt=model_opt['beta_schedule']['train'], 149 | start_step=model_opt['diffusion']['start_step'] if 'start_step' in model_opt['diffusion'].keys() else 1000 150 | ) 151 | if opt['phase'] == 'train': 152 | # init_weights(netG, 
init_type='kaiming', scale=0.1)
153 |         init_weights(netG, init_type='orthogonal')
154 |     if opt['gpu_ids'] and opt['distributed']:
155 |         assert torch.cuda.is_available()
156 |         netG = nn.DataParallel(netG)
157 |     return netG
158 | 
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6
2 | torchvision
3 | numpy
4 | pandas
5 | tqdm
6 | lmdb
7 | opencv-python
8 | pillow
9 | tensorboardx
10 | wandb
11 | kornia==0.6.2
12 | pyciede2000
13 | pyiqa==0.1.5
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import numpy as np
12 | import copy
13 | import random
14 | 
15 | seed = 6666
16 | print('Random seed: {}'.format(seed))
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.deterministic = True
22 | 
23 | 
24 | class AverageMeter(object):
25 |     """Computes and stores the average and current value"""
26 | 
27 |     def __init__(self):
28 |         self.reset()
29 |         self.cache = []
30 | 
31 |     def reset(self):
32 |         self.val = 0
33 |         self.avg = 0
34 |         self.sum = 0
35 |         self.count = 0
36 | 
37 |     def update(self, val, n=0):
38 |         self.val = val
39 |         self.sum += val * n
40 |         self.count += n
41 | 
42 |         self.cache.append(self.val)
43 |         if len(self.cache) >= 20: self.cache = self.cache[1:]  # average over a sliding window of recent values
44 |         self.avg = np.mean(self.cache)
45 | 
46 |     def __str__(self):
47 |         """String representation for logging
48 |         """
49 |         # for values that should be recorded exactly e.g. iteration number
50 |         if self.count == 0:
51 |             return str(self.val)
52 |         # for stats
53 |         return '%.4f (%.4f)' % (self.val, self.avg)
54 | 
55 | 
56 | def adjust_learning_rate(change_idx, optimizer):
57 |     """Decays the learning rate by a factor of 0.7
58 |     at each resolution-change stage"""
59 |     for param_group in optimizer.param_groups:
60 |         lr = param_group['lr']
61 | 
62 |         lr = lr * (0.7 ** change_idx)
63 | 
64 |         param_group['lr'] = lr
65 | 
66 |     logger.info("Current lr: {}".format(optimizer.state_dict()['param_groups'][0]['lr']))
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     parser = argparse.ArgumentParser()
71 |     parser.add_argument('-c', '--config', type=str, default='config/framework_da.json',
72 |                         help='JSON file for configuration')
73 |     parser.add_argument('-p', '--phase', type=str, choices=['train'],
74 |                         help='Run train(training); use infer.py for generation', default='train')
75 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
76 |     parser.add_argument('-debug', '-d', action='store_true')
77 |     parser.add_argument('-enable_wandb', action='store_true')
78 |     parser.add_argument('-log_wandb_ckpt', action='store_true')
79 |     parser.add_argument('-log_eval', action='store_true')
80 | 
81 |     # parse configs
82 |     args = parser.parse_args()
83 |     opt = Logger.parse(args)
84 |     # Convert to NoneDict, which returns None for missing keys.
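    # NoneDict lets optional config entries be probed without try/except; e.g. a
    # hypothetical opt['train']['some_optional_key'] evaluates to None rather than raising KeyError.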
85 | opt = Logger.dict_to_nonedict(opt) 86 | 87 | # logging 88 | torch.backends.cudnn.enabled = True 89 | torch.backends.cudnn.benchmark = True 90 | 91 | Logger.setup_logger(None, opt['path']['log'], 92 | 'train', level=logging.INFO, screen=True) 93 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) 94 | logger = logging.getLogger('base') 95 | logger.info(Logger.dict2str(opt)) 96 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) 97 | 98 | change_sizes = opt["change_sizes"] 99 | 100 | # Initialize WandbLogger 101 | if opt['enable_wandb']: 102 | import wandb 103 | 104 | wandb_logger = WandbLogger(opt) 105 | wandb.define_metric('validation/val_step') 106 | wandb.define_metric('epoch') 107 | wandb.define_metric("validation/*", step_metric="val_step") 108 | val_step = 0 109 | else: 110 | wandb_logger = None 111 | 112 | # dataset 113 | for phase, dataset_opt in opt['datasets'].items(): 114 | if phase == 'train' and args.phase != 'val': 115 | train_set = Data.create_dataset(dataset_opt, phase) 116 | train_loader = Data.create_dataloader( 117 | train_set, dataset_opt, phase) 118 | logger.info('Initial Dataset Finished') 119 | 120 | logger.info("change rate:" + "".join(["{}:{} ".format(k, v) for k, v in change_sizes.items()])) 121 | 122 | # model 123 | diffusion = Model.create_model(opt) 124 | logger.info('Initial Model Finished') 125 | 126 | # Train 127 | current_step = diffusion.begin_step 128 | current_epoch = diffusion.begin_epoch 129 | n_iter = opt['train']['n_iter'] 130 | 131 | # ave 132 | ave_loss = AverageMeter() 133 | 134 | if opt['path']['resume_state']: 135 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 136 | current_epoch, current_step)) 137 | 138 | diffusion.set_new_noise_schedule( 139 | opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase']) 140 | 141 | if current_step == 0: 142 | change_size_idx = 0 143 | else: 144 | change_size_idx = 0 145 | try: 146 | while current_step >= int( 147 | float(list(change_sizes.keys())[change_size_idx]) * n_iter) and change_size_idx < len( 148 | list(change_sizes.keys())): 149 | change_size_idx += 1 150 | except: 151 | pass 152 | change_size_idx -= 1 153 | 154 | while current_step < n_iter: 155 | 156 | # reset train_loader 157 | if current_step >= int( 158 | float(list(change_sizes.keys())[change_size_idx]) * n_iter) and change_size_idx < len( 159 | list(change_sizes.keys())): 160 | logger.info('reset train_loader') 161 | resize_resolu = change_sizes[list(change_sizes.keys())[change_size_idx]] 162 | train_dataset_opt = copy.deepcopy(opt['datasets']['train']) 163 | 164 | train_dataset_opt["l_resolution"], train_dataset_opt["r_resolution"] = resize_resolu, resize_resolu 165 | 166 | logger.info('reset train_loader: l_resolution:{}, r_resolution:{}, batch_size:{}'.format( 167 | train_dataset_opt["l_resolution"], train_dataset_opt["r_resolution"], 168 | train_dataset_opt["batch_size"])) 169 | 170 | train_set = Data.create_dataset(train_dataset_opt, 'train') 171 | train_loader = Data.create_dataloader(train_set, train_dataset_opt, 'train') 172 | 173 | logger.info('reset train_loader finished .') 174 | 175 | adjust_learning_rate(change_size_idx, diffusion.optG) 176 | 177 | change_size_idx += 1 178 | 179 | current_epoch += 1 180 | for _, train_data in enumerate(train_loader): 181 | current_step += 1 182 | if current_step > n_iter: 183 | break 184 | 185 | diffusion.feed_data(train_data) 186 | diffusion.optimize_parameters(current_step) 187 | # log 188 | if current_step % 
opt['train']['print_freq'] == 0:
189 |                 logs = diffusion.get_current_log()
190 |                 message = '<epoch:{:3d}, iter:{:8,d}> '.format(
191 |                     current_epoch, current_step)
192 |                 for k, v in logs.items():
193 |                     ave_loss.update(v)
194 |                     message += '{:s}: {:.4e} ({:.4e})'.format(k, v, ave_loss.avg)
195 |                     tb_logger.add_scalar(k, v, current_step)
196 |                 logger.info(message)
197 | 
198 |                 if wandb_logger:
199 |                     wandb_logger.log_metrics(logs)
200 | 
201 |             if current_step % opt['train']['save_checkpoint_freq'] == 0:
202 |                 logger.info('Saving models and training states.')
203 |                 diffusion.save_network(current_epoch, current_step)
204 | 
205 |                 if wandb_logger and opt['log_wandb_ckpt']:
206 |                     wandb_logger.log_checkpoint(current_epoch, current_step)
207 | 
208 |             if current_step >= int(
209 |                     float(list(change_sizes.keys())[change_size_idx]) * n_iter) and change_size_idx < len(
210 |                 list(change_sizes.keys())):
211 |                 break
212 | 
213 |         if wandb_logger:
214 |             wandb_logger.log_metrics({'epoch': current_epoch - 1})
215 | 
216 |     # save model
217 |     logger.info('End of training.')
--------------------------------------------------------------------------------
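For orientation, below is a minimal sketch of how the pieces above fit together: it wires the `dehaze_with_z_v2` UNet into `GaussianDiffusion` and evaluates one training loss, mirroring what `networks.define_G` and `DDPM.optimize_parameters` do. The schedule numbers, tensor shapes, and reduced model width are illustrative assumptions (the real values live in `config/framework_da.json`), and the repo root is assumed to be on `PYTHONPATH`.

```python
# Minimal sketch: one conditional training-loss evaluation with the v2 modules.
import torch
from model.dehaze_with_z_v2_modules import diffusion, unet

device = torch.device('cpu')
net = unet.UNet(
    in_channel=6,          # hazy condition (3) + noisy target (3), concatenated
    out_channel=3,
    inner_channel=32,
    norm_groups=16,        # reduced for a cheap CPU run; the config's value may differ
    channel_mults=(1, 2, 4, 8, 8),
    attn_res=[8],
    res_blocks=1,
    dropout=0,
    image_size=128,
    fcb=True,              # enable the frequency compensation blocks on skip connections
)
model = diffusion.GaussianDiffusion(
    net, image_size=128, channels=3, loss_type='l1', conditional=True)
model.set_loss(device)
model.set_new_noise_schedule(
    {'schedule': 'linear', 'n_timestep': 2000,      # illustrative schedule values
     'linear_start': 1e-6, 'linear_end': 1e-2}, device)

batch = {'HR': torch.randn(1, 3, 128, 128),   # stand-in for a clear image
         'SR': torch.randn(1, 3, 128, 128)}   # stand-in for the hazy condition
l_pix = model(batch)                          # forward -> p_losses: L1 between true and predicted noise
print(l_pix / (1 * 3 * 128 * 128))            # per-pixel loss, as DDPM.optimize_parameters normalizes it
```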