├── LICENSE
├── README.md
├── figs
│   ├── color_psnr.png
│   ├── deblur_1.png
│   ├── demosaic_1.png
│   ├── demosaic_2.png
│   ├── denoiser_arch.png
│   ├── grayscale_psnr.png
│   ├── sisr_1.png
│   ├── test_03_noisy_0750.png
│   └── test_03_usrnet_2355.png
├── kernels
│   ├── Levin09.mat
│   ├── kernels_12.mat
│   └── kernels_bicubicx234.mat
├── main_download_pretrained_models.py
├── main_dpir_deblocking_color.py
├── main_dpir_deblocking_grayscale.py
├── main_dpir_deblur.py
├── main_dpir_demosaick.py
├── main_dpir_denoising.py
├── main_dpir_sisr.py
├── main_dpir_sisr_real_applications.py
├── model_zoo
│   └── README.md
├── models
│   ├── basicblock.py
│   ├── network_dncnn.py
│   └── network_unet.py
├── testsets
│   ├── set12
│   │   ├── 01.png
│   │   ├── 02.png
│   │   ├── 03.png
│   │   ├── 04.png
│   │   ├── 05.png
│   │   ├── 06.png
│   │   ├── 07.png
│   │   ├── 08.png
│   │   ├── 09.png
│   │   ├── 10.png
│   │   ├── 11.png
│   │   └── 12.png
│   ├── set3c
│   │   ├── butterfly.png
│   │   ├── leaves.png
│   │   └── starfish.png
│   └── set5
│       ├── baby_GT.bmp
│       ├── bird_GT.bmp
│       ├── butterfly_GT.bmp
│       ├── head_GT.bmp
│       └── woman_GT.bmp
└── utils
    ├── test.bmp
    ├── utils_bnorm.py
    ├── utils_csmri.py
    ├── utils_deblur.py
    ├── utils_image.py
    ├── utils_inpaint.py
    ├── utils_logger.py
    ├── utils_model.py
    ├── utils_mosaic.py
    ├── utils_pnp.py
    ├── utils_sisr.py
    ├── utils_sisr_beforepytorchversion8.py
    └── utils_test.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Kai Zhang (cskaizhang@gmail.com)

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Plug-and-Play Image Restoration

![visitors](https://visitor-badge.glitch.me/badge?page_id=cszn/DPIR)

[**Kai Zhang**](https://cszn.github.io/), Yawei Li, Wangmeng Zuo, Lei Zhang, Luc Van Gool, Radu Timofte

_[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland_

[[paper arxiv](https://arxiv.org/pdf/2008.13751.pdf)] [[paper tpami](https://ieeexplore.ieee.org/abstract/document/9454311)]


Denoising results on CBSD68 and Urban100 datasets
----------
| Dataset | Noise Level | FFDNet-PSNR(RGB) | FFDNet-PSNR(Y) | **DRUNet-PSNR(RGB)** | **DRUNet-PSNR(Y)** |
|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| CBSD68 | 30 | 30.32 | 32.05 | 30.81 | 32.44 |
| CBSD68 | 50 | 27.97 | 29.65 | 28.51 | 30.09 |
| Urban100 | 30 | 30.53 | 32.72 | 31.83 | 33.93 |
| Urban100 | 50 | 28.05 | 30.09 | 29.61 | 31.57 |

```PSNR(Y) means the PSNR is calculated on the Y channel of the YCbCr color space.```


Abstract
----------
Recent works on plug-and-play image restoration have shown that a denoiser can implicitly serve as the image prior for
model-based methods when solving inverse problems. This property brings considerable advantages to plug-and-play image
restoration (e.g., it combines the flexibility of model-based methods with the effectiveness of learning-based methods) when the
denoiser is discriminatively learned via a deep convolutional neural network (CNN) with large modeling capacity. However, while
deeper and larger CNN models are rapidly gaining popularity, existing plug-and-play image restoration is hindered by the lack of a
suitable denoiser prior. To push the limits of plug-and-play image restoration, we set up a benchmark deep denoiser prior by
training a highly flexible and effective CNN denoiser. We then plug the deep denoiser prior, as a modular part, into a half
quadratic splitting based iterative algorithm to solve various image restoration problems, and we provide a thorough analysis of
parameter setting, intermediate results, and empirical convergence to better understand the working mechanism. Experimental results
on three representative image restoration tasks, including deblurring, super-resolution, and demosaicing, demonstrate that the
proposed plug-and-play image restoration with deep denoiser prior not only significantly outperforms other state-of-the-art
model-based methods but also achieves competitive or even superior performance against state-of-the-art learning-based methods.

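For readers new to plug-and-play restoration, below is a minimal sketch of the half quadratic splitting (HQS) loop described above. The names `denoiser`, `data_solution`, and `hqs_restore` are illustrative placeholders, not this repository's API; in the actual scripts DRUNet plays the denoiser role and each task has its own closed-form data step.

```python
import numpy as np

def denoiser(x, sigma):
    """Placeholder for a Gaussian denoiser prior (DRUNet in DPIR)."""
    return x  # identity stand-in

def data_solution(x, y, rho):
    """Placeholder data sub-problem; for plain denoising it reduces to a
    weighted average of the current estimate x and the observation y."""
    return (y + rho * x) / (1.0 + rho)

def hqs_restore(y, noise_level, iter_num=8, sigma_1=49/255.):
    """Alternate a data step and a denoising step while the denoiser noise
    level decays from sigma_1 towards the image noise level."""
    sigma_2 = max(noise_level, 0.255/255.)
    sigmas = np.logspace(np.log10(sigma_1), np.log10(sigma_2), iter_num)
    x = y.copy()
    for sigma in sigmas:
        rho = (noise_level / sigma) ** 2  # penalty weight grows as sigma shrinks
        x = data_solution(x, y, rho)      # step 1: enforce data fidelity
        x = denoiser(x, sigma)            # step 2: enforce the denoiser prior
    return x
```

In the actual scripts (e.g., `main_dpir_deblur.py`), the schedule comes from `utils_pnp.get_rho_sigma` and the data step is a closed-form FFT solution.
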
The DRUNet Denoiser (state-of-the-art Gaussian denoising performance!)
----------
* Network architecture

![DRUNet architecture](figs/denoiser_arch.png)

* Grayscale image denoising

![Grayscale denoising PSNR comparison](figs/grayscale_psnr.png)

* Color image denoising

![Color denoising PSNR comparison](figs/color_psnr.png)

| ![Noisy input](figs/test_03_noisy_0750.png) | ![DRUNet result](figs/test_03_usrnet_2355.png) |
|:---:|:---:|
|(a) Noisy image with noise level 200|(b) Result by the proposed DRUNet denoiser|

**Even though it was trained on the noise level range [0, 50], DRUNet still performs well at an extremely large unseen noise level of 200.**

Image Deblurring
----------
* Visual comparison

![Deblurring visual comparison](figs/deblur_1.png)


Single Image Super-Resolution
----------
* Visual comparison

![SISR visual comparison](figs/sisr_1.png)


Color Image Demosaicing
----------
* PSNR

![Demosaicing PSNR comparison](figs/demosaic_1.png)

* Visual comparison

![Demosaicing visual comparison](figs/demosaic_2.png)


Citation
----------
```BibTeX
@article{zhang2021plug,
  title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
  author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  volume={44},
  number={10},
  pages={6360--6376},
  year={2021}
}
@inproceedings{zhang2017learning,
  title={Learning Deep CNN Denoiser Prior for Image Restoration},
  author={Zhang, Kai and Zuo, Wangmeng and Gu, Shuhang and Zhang, Lei},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3929--3938},
  year={2017}
}
```
--------------------------------------------------------------------------------
/figs/color_psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/color_psnr.png
--------------------------------------------------------------------------------
/figs/deblur_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/deblur_1.png
--------------------------------------------------------------------------------
/figs/demosaic_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/demosaic_1.png
--------------------------------------------------------------------------------
/figs/demosaic_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/demosaic_2.png
--------------------------------------------------------------------------------
/figs/denoiser_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/denoiser_arch.png
--------------------------------------------------------------------------------
/figs/grayscale_psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/grayscale_psnr.png
--------------------------------------------------------------------------------
/figs/sisr_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/sisr_1.png
-------------------------------------------------------------------------------- /figs/test_03_noisy_0750.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/test_03_noisy_0750.png -------------------------------------------------------------------------------- /figs/test_03_usrnet_2355.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/test_03_usrnet_2355.png -------------------------------------------------------------------------------- /kernels/Levin09.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/Levin09.mat -------------------------------------------------------------------------------- /kernels/kernels_12.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/kernels_12.mat -------------------------------------------------------------------------------- /kernels/kernels_bicubicx234.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/kernels_bicubicx234.mat -------------------------------------------------------------------------------- /main_download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import requests 4 | import re 5 | 6 | 7 | """ 8 | How to use: 9 | 10 | download models: 11 | python main_download_pretrained_models.py --models "DPIR IRCNN" --model_dir "model_zoo" 12 | 13 | """ 14 | 15 | 16 | def download_pretrained_model(model_dir='model_zoo', model_name='dncnn3.pth'): 17 | if os.path.exists(os.path.join(model_dir, model_name)): 18 | print(f'already exists, skip downloading [{model_name}]') 19 | else: 20 | os.makedirs(model_dir, exist_ok=True) 21 | if 'SwinIR' in model_name: 22 | url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(model_name) 23 | else: 24 | url = 'https://github.com/cszn/KAIR/releases/download/v1.0/{}'.format(model_name) 25 | r = requests.get(url, allow_redirects=True) 26 | print(f'downloading [{model_dir}/{model_name}] ...') 27 | open(os.path.join(model_dir, model_name), 'wb').write(r.content) 28 | print('done!') 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--models', 34 | type=lambda s: re.split(' |, ', s), 35 | default = "dncnn3.pth", 36 | help='comma or space delimited list of characters, e.g., "DnCNN", "DnCNN BSRGAN.pth", "dncnn_15.pth dncnn_50.pth"') 37 | parser.add_argument('--model_dir', type=str, default='model_zoo', help='path of model_zoo') 38 | args = parser.parse_args() 39 | 40 | print(f'trying to download {args.models}') 41 | 42 | method_model_zoo = {'DnCNN': ['dncnn_15.pth', 'dncnn_25.pth', 'dncnn_50.pth', 'dncnn3.pth', 'dncnn_color_blind.pth', 'dncnn_gray_blind.pth'], 43 | 'SRMD': ['srmdnf_x2.pth', 'srmdnf_x3.pth', 'srmdnf_x4.pth', 'srmd_x2.pth', 'srmd_x3.pth', 'srmd_x4.pth'], 44 | 'DPSR': ['dpsr_x2.pth', 'dpsr_x3.pth', 'dpsr_x4.pth', 'dpsr_x4_gan.pth'], 45 | 'FFDNet': ['ffdnet_color.pth', 'ffdnet_gray.pth', 'ffdnet_color_clip.pth', 
'ffdnet_gray_clip.pth'], 46 | 'USRNet': ['usrgan.pth', 'usrgan_tiny.pth', 'usrnet.pth', 'usrnet_tiny.pth'], 47 | 'DPIR': ['drunet_gray.pth', 'drunet_color.pth', 'drunet_deblocking_color.pth', 'drunet_deblocking_grayscale.pth'], 48 | 'BSRGAN': ['BSRGAN.pth', 'BSRNet.pth', 'BSRGANx2.pth'], 49 | 'IRCNN': ['ircnn_color.pth', 'ircnn_gray.pth'], 50 | 'SwinIR': ['001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x3.pth', 51 | '001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth', 52 | '001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x3.pth', 53 | '001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth', 54 | '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth', 55 | '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth', 56 | '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_PSNR.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth', 57 | '004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth', 58 | '005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth', 59 | '005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth', 60 | '006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth', 61 | '006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth'], 62 | 'others': ['msrresnet_x4_psnr.pth', 'msrresnet_x4_gan.pth', 'imdn_x4.pth', 'RRDB.pth', 'ESRGAN.pth', 63 | 'FSSR_DPED.pth', 'FSSR_JPEG.pth', 'RealSR_DPED.pth', 'RealSR_JPEG.pth'] 64 | } 65 | 66 | method_zoo = list(method_model_zoo.keys()) 67 | model_zoo = [] 68 | for b in list(method_model_zoo.values()): 69 | model_zoo += b 70 | 71 | if 'all' in args.models: 72 | for method in method_zoo: 73 | for model_name in method_model_zoo[method]: 74 | download_pretrained_model(args.model_dir, model_name) 75 | else: 76 | for method_model in args.models: 77 | if method_model in method_zoo: # method, need for loop 78 | for model_name in method_model_zoo[method_model]: 79 | if 'SwinIR' in model_name: 80 | download_pretrained_model(os.path.join(args.model_dir, 'swinir'), model_name) 81 | else: 82 | download_pretrained_model(args.model_dir, model_name) 83 | elif method_model in model_zoo: # model, do not need for loop 84 | if 'SwinIR' in method_model: 85 | download_pretrained_model(os.path.join(args.model_dir, 'swinir'), method_model) 86 | else: 87 | download_pretrained_model(args.model_dir, method_model) 88 | else: 89 | print(f'Do not find {method_model} from the pre-trained model zoo!') 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /main_dpir_deblocking_color.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import logging 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import torch 7 | from utils import utils_logger 8 | from utils import utils_image as util 9 | import cv2 10 | 11 | 12 | 13 | 14 | def main(): 15 | 16 | # ---------------------------------------- 17 | # Preparation 18 | # ---------------------------------------- 19 | model_name = 'drunet_color' 20 | quality_factors = [10, 20, 30, 40] 21 | testset_name = 'LIVE1' # test set, 'LIVE1' 22 | need_degradation = True # default: True 23 | 24 | task_current = 'db' # 'dn' for deblocking 25 | n_channels = 3 # fixed 26 | 
model_pool = 'model_zoo' # fixed 27 | testsets = 'testsets' # fixed 28 | results = 'results' # fixed 29 | noise_level_img = 0 # fixed: 0, noise level for LR image 30 | result_name = testset_name + '_' + model_name + '_' + task_current 31 | border = 0 # shave boader to calculate PSNR and SSIM 32 | 33 | # ---------------------------------------- 34 | # L_path, E_path, H_path 35 | # ---------------------------------------- 36 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 37 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 38 | util.mkdir(E_path) 39 | 40 | logger_name = result_name 41 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 42 | logger = logging.getLogger(logger_name) 43 | 44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 45 | 46 | # ---------------------------------------- 47 | # load model 48 | # ---------------------------------------- 49 | model_path = os.path.join('model_zoo', 'drunet_deblocking_color.pth') 50 | from models.network_unet import UNetRes as net 51 | model = net(in_nc=4, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=False) # define network 52 | 53 | model.load_state_dict(torch.load(model_path), strict=True) 54 | model.eval() 55 | for k, v in model.named_parameters(): 56 | v.requires_grad = False 57 | 58 | model = model.to(device) 59 | logger.info('Model path: {:s}'.format(model_path)) 60 | number_parameters = sum(map(lambda x: x.numel(), model.parameters())) 61 | logger.info('Params number: {}'.format(number_parameters)) 62 | L_paths = util.get_image_paths(L_path) 63 | 64 | for quality_factor in quality_factors: 65 | 66 | test_results = OrderedDict() 67 | test_results['psnr'] = [] 68 | test_results['ssim'] = [] 69 | 70 | logger.info('model_name:{}, quality factor:{}'.format(model_name, quality_factor)) 71 | logger.info(L_path) 72 | 73 | for idx, img in enumerate(L_paths): 74 | 75 | # ------------------------------------ 76 | # (1) img_L 77 | # ------------------------------------ 78 | img_name, ext = os.path.splitext(os.path.basename(img)) 79 | logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) 80 | img_L = util.imread_uint(img, n_channels=n_channels) 81 | img_H = img_L.copy() 82 | 83 | # ------------------------------------ 84 | # Do the JPEG compression 85 | # ------------------------------------ 86 | if need_degradation: 87 | img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR) 88 | result, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) 89 | img_L = cv2.imdecode(encimg, 1) 90 | img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) 91 | 92 | img_L = util.uint2tensor4(img_L) 93 | 94 | noise_level = (100-quality_factor)/100.0 95 | noise_level = torch.FloatTensor([noise_level]) 96 | noise_level_map = torch.ones((1,1, img_L.shape[2], img_L.shape[3])).mul_(noise_level).float() 97 | img_L = torch.cat((img_L, noise_level_map), 1) 98 | 99 | img_L = img_L.to(device) 100 | 101 | # ------------------------------------ 102 | # (2) img_E 103 | # ------------------------------------ 104 | img_E = model(img_L) 105 | img_E = util.tensor2uint(img_E) 106 | 107 | if need_degradation: 108 | 109 | img_H = img_H.squeeze() 110 | # -------------------------------- 111 | # PSNR and SSIM 112 | # -------------------------------- 113 | psnr = util.calculate_psnr(img_E, img_H, border=border) 114 | ssim = util.calculate_ssim(img_E, img_H, border=border) 115 
| test_results['psnr'].append(psnr) 116 | test_results['ssim'].append(ssim) 117 | logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) 118 | 119 | # ------------------------------------ 120 | # save results 121 | # ------------------------------------ 122 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'.png')) 123 | 124 | if need_degradation: 125 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 126 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 127 | logger.info('Average PSNR/SSIM(RGB) - {} - qf{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, quality_factor, ave_psnr, ave_ssim)) 128 | 129 | 130 | if __name__ == '__main__': 131 | 132 | main() 133 | -------------------------------------------------------------------------------- /main_dpir_deblocking_grayscale.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import logging 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | from utils import utils_logger 9 | from utils import utils_image as util 10 | import cv2 11 | 12 | ''' 13 | Spyder (Python 3.7) 14 | PyTorch 1.8.1 15 | Windows 10 or Linux 16 | 17 | If you have any question, please feel free to contact with me. 18 | Kai Zhang (e-mail: cskaizhang@gmail.com) 19 | (github: https://github.com/cszn/DPIR) 20 | (github: https://github.com/cszn/KAIR) 21 | by Kai Zhang (06/June/2021) 22 | 23 | 24 | How to run to get the results in Table 3: 25 | Step 1: download 'classic5' and 'LIVE1' testing dataset from https://github.com/cszn/DnCNN/tree/master/testsets 26 | Step 2: download 'drunet_deblocking_grayscale.pth' model and 'dncnn3.pth' model, and put it into 'model_zoo' 27 | 'drunet_deblocking_grayscale.pth': https://drive.google.com/file/d/1ySemeOINvVfraFi_SZxZ93UuV4hMzk8g/view?usp=sharing 28 | 'dncnn3.pth': https://drive.google.com/file/d/1wwTFLFbS3AWowuNbe1XsEd_VCa2kof5I/view?usp=sharing 29 | ''' 30 | 31 | 32 | def main(): 33 | 34 | # ---------------------------------------- 35 | # Preparation 36 | # ---------------------------------------- 37 | model_name = 'drunet' 38 | quality_factors = [10, 20, 30, 40] 39 | testset_name = 'classic5' # test set, 'classic5' | 'LIVE1' 40 | need_degradation = True # default: True 41 | 42 | task_current = 'db' # 'db' for JPEG image deblocking 43 | 44 | model_pool = 'model_zoo' # fixed 45 | testsets = 'testsets' # fixed 46 | results = 'results' # fixed 47 | result_name = testset_name + '_' + model_name + '_' + task_current 48 | border = 0 # shave boader to calculate PSNR and SSIM 49 | 50 | # ---------------------------------------- 51 | # L_path, E_path, H_path 52 | # ---------------------------------------- 53 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 54 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 55 | util.mkdir(E_path) 56 | 57 | logger_name = result_name 58 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 59 | logger = logging.getLogger(logger_name) 60 | 61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | 63 | # ---------------------------------------- 64 | # load model 65 | # ---------------------------------------- 66 | if model_name == 'dncnn3': 67 | model_path = os.path.join(model_pool, model_name+'.pth') 68 | from models.network_dncnn import DnCNN as net 69 | model = net(in_nc=1, out_nc=1, nc=64, nb=20, 
act_mode='R') 70 | model_path = os.path.join('model_zoo', 'dncnn3.pth') 71 | else: 72 | model_name = 'drunet' 73 | model_path = os.path.join('model_zoo', 'drunet_deblocking_grayscale.pth') 74 | from models.network_unet import UNetRes as net 75 | model = net(in_nc=2, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=False) 76 | 77 | model.load_state_dict(torch.load(model_path), strict=True) 78 | model.eval() 79 | for k, v in model.named_parameters(): 80 | v.requires_grad = False 81 | 82 | model = model.to(device) 83 | logger.info('Model path: {:s}'.format(model_path)) 84 | number_parameters = sum(map(lambda x: x.numel(), model.parameters())) 85 | logger.info('Params number: {}'.format(number_parameters)) 86 | L_paths = util.get_image_paths(L_path) 87 | 88 | for quality_factor in quality_factors: 89 | 90 | test_results = OrderedDict() 91 | test_results['psnr'] = [] 92 | test_results['ssim'] = [] 93 | test_results['psnr_y'] = [] 94 | test_results['ssim_y'] = [] 95 | 96 | logger.info('model_name:{}, quality factor:{}'.format(model_name, quality_factor)) 97 | 98 | for idx, img in enumerate(L_paths): 99 | 100 | # ------------------------------------ 101 | # (1) img_L 102 | # ------------------------------------ 103 | img_name, ext = os.path.splitext(os.path.basename(img)) 104 | logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) 105 | 106 | img_L = cv2.imread(img, cv2.IMREAD_UNCHANGED) # BGR or G 107 | grayscale = True if img_L.ndim == 2 else False 108 | if not grayscale: 109 | img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) # RGB 110 | img_L_ycbcr = util.rgb2ycbcr(img_L, only_y=False) 111 | img_L = img_L_ycbcr[..., 0] # we operate on Y channel for color images 112 | 113 | img_H = img_L.copy() 114 | 115 | # ------------------------------------ 116 | # Do the JPEG compression 117 | # ------------------------------------ 118 | if need_degradation: 119 | result, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) 120 | img_L = cv2.imdecode(encimg, 0) 121 | 122 | img_L = util.uint2tensor4(img_L[..., np.newaxis]) 123 | 124 | if model_name == 'drunet': 125 | noise_level = (100-quality_factor)/100.0 126 | noise_level = torch.FloatTensor([noise_level]) 127 | noise_level_map = torch.ones((1,1, img_L.shape[2], img_L.shape[3])).mul_(noise_level).float() 128 | img_L = torch.cat((img_L, noise_level_map), 1) 129 | 130 | img_L = img_L.to(device) 131 | 132 | # ------------------------------------ 133 | # (2) img_E 134 | # ------------------------------------ 135 | img_E = model(img_L) 136 | img_E = util.tensor2uint(img_E) 137 | 138 | if need_degradation: 139 | 140 | # -------------------------------- 141 | # PSNR and SSIM 142 | # -------------------------------- 143 | 144 | psnr = util.calculate_psnr(img_E, img_H, border=border) 145 | ssim = util.calculate_ssim(img_E, img_H, border=border) 146 | test_results['psnr'].append(psnr) 147 | test_results['ssim'].append(ssim) 148 | logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) 149 | 150 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'.png')) 151 | if not grayscale: 152 | img_L_ycbcr[..., 0] = img_E 153 | img_E_rgb = util.ycbcr2rgb(img_L_ycbcr) 154 | util.imsave(img_E_rgb, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'_rgb.png')) 155 | 156 | if need_degradation: 157 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 158 | ave_ssim = 
sum(test_results['ssim']) / len(test_results['ssim']) 159 | logger.info('Average PSNR/SSIM(RGB) - {} - qf{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, quality_factor, ave_psnr, ave_ssim)) 160 | 161 | if __name__ == '__main__': 162 | 163 | main() 164 | -------------------------------------------------------------------------------- /main_dpir_deblur.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import cv2 3 | import logging 4 | 5 | import numpy as np 6 | from datetime import datetime 7 | from collections import OrderedDict 8 | import hdf5storage 9 | from scipy import ndimage 10 | 11 | import torch 12 | 13 | from utils import utils_deblur 14 | from utils import utils_logger 15 | from utils import utils_model 16 | from utils import utils_pnp as pnp 17 | from utils import utils_sisr as sr 18 | from utils import utils_image as util 19 | 20 | 21 | """ 22 | Spyder (Python 3.7) 23 | PyTorch 1.6.0 24 | Windows 10 or Linux 25 | Kai Zhang (cskaizhang@gmail.com) 26 | github: https://github.com/cszn/DPIR 27 | https://github.com/cszn/IRCNN 28 | https://github.com/cszn/KAIR 29 | @article{zhang2020plug, 30 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior}, 31 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, 32 | journal={arXiv preprint}, 33 | year={2020} 34 | } 35 | % If you have any question, please feel free to contact with me. 36 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/) 37 | by Kai Zhang (01/August/2020) 38 | 39 | # -------------------------------------------- 40 | |--model_zoo # model_zoo 41 | |--drunet_gray # model_name, for color images 42 | |--drunet_color 43 | |--testset # testsets 44 | |--results # results 45 | # -------------------------------------------- 46 | """ 47 | 48 | def main(): 49 | 50 | # ---------------------------------------- 51 | # Preparation 52 | # ---------------------------------------- 53 | 54 | noise_level_img = 7.65/255.0 # default: 0, noise level for LR image 55 | noise_level_model = noise_level_img # noise level of model, default 0 56 | model_name = 'drunet_gray' # 'drunet_gray' | 'drunet_color' | 'ircnn_gray' | 'ircnn_color' 57 | testset_name = 'Set3C' # test set, 'set5' | 'srbsd68' 58 | x8 = True # default: False, x8 to boost performance 59 | iter_num = 8 # number of iterations 60 | modelSigma1 = 49 61 | modelSigma2 = noise_level_model*255. 
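    # Editor's note (illustrative comment, not in the original script): utils_pnp.get_rho_sigma,
    # called below in step (2), builds a denoiser-noise schedule that decays log-uniformly from
    # modelSigma1 down to modelSigma2, roughly
    #     sigmas = np.logspace(np.log10(modelSigma1/255.), np.log10(modelSigma2/255.), iter_num)
    # with the penalty rho_i growing as sigma_i shrinks (rho_i proportional to
    # (noise_level_model/sigma_i)**2), so early iterations denoise aggressively and later
    # iterations preserve detail.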
62 | 63 | show_img = False # default: False 64 | save_L = True # save LR image 65 | save_E = True # save estimated image 66 | save_LEH = False # save zoomed LR, E and H images 67 | border = 0 68 | 69 | # -------------------------------- 70 | # load kernel 71 | # -------------------------------- 72 | 73 | kernels = hdf5storage.loadmat(os.path.join('kernels', 'Levin09.mat'))['kernels'] 74 | 75 | sf = 1 76 | task_current = 'deblur' # 'deblur' for deblurring 77 | n_channels = 3 if 'color' in model_name else 1 # fixed 78 | model_zoo = 'model_zoo' # fixed 79 | testsets = 'testsets' # fixed 80 | results = 'results' # fixed 81 | result_name = testset_name + '_' + task_current + '_' + model_name 82 | model_path = os.path.join(model_zoo, model_name+'.pth') 83 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 84 | torch.cuda.empty_cache() 85 | 86 | # ---------------------------------------- 87 | # L_path, E_path, H_path 88 | # ---------------------------------------- 89 | 90 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 91 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 92 | util.mkdir(E_path) 93 | 94 | logger_name = result_name 95 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 96 | logger = logging.getLogger(logger_name) 97 | 98 | # ---------------------------------------- 99 | # load model 100 | # ---------------------------------------- 101 | 102 | if 'drunet' in model_name: 103 | from models.network_unet import UNetRes as net 104 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose") 105 | model.load_state_dict(torch.load(model_path), strict=True) 106 | model.eval() 107 | for _, v in model.named_parameters(): 108 | v.requires_grad = False 109 | model = model.to(device) 110 | elif 'ircnn' in model_name: 111 | from models.network_dncnn import IRCNN as net 112 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64) 113 | model25 = torch.load(model_path) 114 | former_idx = 0 115 | 116 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model)) 117 | logger.info('Model path: {:s}'.format(model_path)) 118 | logger.info(L_path) 119 | L_paths = util.get_image_paths(L_path) 120 | 121 | test_results_ave = OrderedDict() 122 | test_results_ave['psnr'] = [] # record average PSNR for each kernel 123 | 124 | for k_index in range(kernels.shape[1]): 125 | 126 | logger.info('-------k:{:>2d} ---------'.format(k_index)) 127 | test_results = OrderedDict() 128 | test_results['psnr'] = [] 129 | k = kernels[0, k_index].astype(np.float64) 130 | util.imshow(k) if show_img else None 131 | 132 | for idx, img in enumerate(L_paths): 133 | 134 | # -------------------------------- 135 | # (1) get img_L 136 | # -------------------------------- 137 | 138 | img_name, ext = os.path.splitext(os.path.basename(img)) 139 | img_H = util.imread_uint(img, n_channels=n_channels) 140 | img_H = util.modcrop(img_H, 8) # modcrop 141 | 142 | img_L = ndimage.filters.convolve(img_H, np.expand_dims(k, axis=2), mode='wrap') 143 | util.imshow(img_L) if show_img else None 144 | img_L = util.uint2single(img_L) 145 | 146 | np.random.seed(seed=0) # for reproducibility 147 | img_L += np.random.normal(0, noise_level_img, img_L.shape) # add AWGN 148 | 149 | # -------------------------------- 150 | # (2) get rhos and sigmas 151 | # 
-------------------------------- 152 | 153 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1.0) 154 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device) 155 | 156 | # -------------------------------- 157 | # (3) initialize x, and pre-calculation 158 | # -------------------------------- 159 | 160 | x = util.single2tensor4(img_L).to(device) 161 | 162 | img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2)) 163 | [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device) 164 | FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf) 165 | 166 | # -------------------------------- 167 | # (4) main iterations 168 | # -------------------------------- 169 | 170 | for i in range(iter_num): 171 | 172 | # -------------------------------- 173 | # step 1, FFT 174 | # -------------------------------- 175 | 176 | tau = rhos[i].float().repeat(1, 1, 1, 1) 177 | x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf) 178 | 179 | if 'ircnn' in model_name: 180 | current_idx = np.int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1) 181 | 182 | if current_idx != former_idx: 183 | model.load_state_dict(model25[str(current_idx)], strict=True) 184 | model.eval() 185 | for _, v in model.named_parameters(): 186 | v.requires_grad = False 187 | model = model.to(device) 188 | former_idx = current_idx 189 | 190 | # -------------------------------- 191 | # step 2, denoiser 192 | # -------------------------------- 193 | 194 | if x8: 195 | x = util.augment_img_tensor4(x, i % 8) 196 | 197 | if 'drunet' in model_name: 198 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1) 199 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16) 200 | elif 'ircnn' in model_name: 201 | x = model(x) 202 | 203 | if x8: 204 | if i % 8 == 3 or i % 8 == 5: 205 | x = util.augment_img_tensor4(x, 8 - i % 8) 206 | else: 207 | x = util.augment_img_tensor4(x, i % 8) 208 | 209 | # -------------------------------- 210 | # (3) img_E 211 | # -------------------------------- 212 | 213 | img_E = util.tensor2uint(x) 214 | if n_channels == 1: 215 | img_H = img_H.squeeze() 216 | 217 | if save_E: 218 | util.imsave(img_E, os.path.join(E_path, img_name+'_k'+str(k_index)+'_'+model_name+'.png')) 219 | 220 | # -------------------------------- 221 | # (4) img_LEH 222 | # -------------------------------- 223 | 224 | if save_LEH: 225 | img_L = util.single2uint(img_L) 226 | k_v = k/np.max(k)*1.0 227 | k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3])) 228 | k_v = cv2.resize(k_v, (3*k_v.shape[1], 3*k_v.shape[0]), interpolation=cv2.INTER_NEAREST) 229 | img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST) 230 | img_I[:k_v.shape[0], -k_v.shape[1]:, :] = k_v 231 | img_I[:img_L.shape[0], :img_L.shape[1], :] = img_L 232 | util.imshow(np.concatenate([img_I, img_E, img_H], axis=1), title='LR / Recovered / Ground-truth') if show_img else None 233 | util.imsave(np.concatenate([img_I, img_E, img_H], axis=1), os.path.join(E_path, img_name+'_k'+str(k_index)+'_LEH.png')) 234 | 235 | if save_L: 236 | util.imsave(util.single2uint(img_L), os.path.join(E_path, img_name+'_k'+str(k_index)+'_LR.png')) 237 | 238 | psnr = util.calculate_psnr(img_E, img_H, border=border) # change with your own border 239 | test_results['psnr'].append(psnr) 240 | logger.info('{:->4d}--> {:>10s} --k:{:>2d} PSNR: 
{:.2f}dB'.format(idx+1, img_name+ext, k_index, psnr)) 241 | 242 | 243 | # -------------------------------- 244 | # Average PSNR 245 | # -------------------------------- 246 | 247 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 248 | logger.info('------> Average PSNR of ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, k_index, noise_level_model, ave_psnr)) 249 | test_results_ave['psnr'].append(ave_psnr) 250 | 251 | if __name__ == '__main__': 252 | 253 | main() 254 | -------------------------------------------------------------------------------- /main_dpir_demosaick.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import cv2 3 | import logging 4 | 5 | import numpy as np 6 | from collections import OrderedDict 7 | 8 | import torch 9 | 10 | from utils import utils_model 11 | from utils import utils_mosaic 12 | from utils import utils_logger 13 | from utils import utils_pnp as pnp 14 | from utils import utils_image as util 15 | 16 | 17 | """ 18 | Spyder (Python 3.7) 19 | PyTorch 1.6.0 20 | Windows 10 or Linux 21 | Kai Zhang (cskaizhang@gmail.com) 22 | github: https://github.com/cszn/DPIR 23 | https://github.com/cszn/IRCNN 24 | https://github.com/cszn/KAIR 25 | @article{zhang2020plug, 26 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior}, 27 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, 28 | journal={arXiv preprint}, 29 | year={2020} 30 | } 31 | % If you have any question, please feel free to contact with me. 32 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/) 33 | by Kai Zhang (01/August/2020) 34 | 35 | # -------------------------------------------- 36 | |--model_zoo # model_zoo 37 | |--drunet_gray # model_name, for color images 38 | |--drunet_color 39 | |--testset # testsets 40 | |--results # results 41 | # -------------------------------------------- 42 | 43 | How to run: 44 | step 1: download [drunet_color.pth, ircnn_color.pth] from https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D 45 | step 2: set your own testset 'testset_name' and parameter setting such as 'noise_level_model', 'iter_num'. 46 | step 3: 'python main_dpir_demosaick.py' 47 | 48 | """ 49 | 50 | def main(): 51 | 52 | # ---------------------------------------- 53 | # Preparation 54 | # ---------------------------------------- 55 | 56 | noise_level_img = 0/255.0 # set AWGN noise level for LR image, default: 0 57 | noise_level_model = noise_level_img # set noise level of model, default: 0 58 | model_name = 'ircnn_color' # set denoiser, 'drunet_color' | 'ircnn_color' 59 | testset_name = 'Set18' # set testing set, 'set18' | 'set24' 60 | x8 = True # set PGSE to boost performance, default: True 61 | iter_num = 40 # set number of iterations, default: 40 for demosaicing 62 | modelSigma1 = 49 # set sigma_1, default: 49 63 | modelSigma2 = max(0.6, noise_level_model*255.) 
# set sigma_2, default 64 | matlab_init = True 65 | 66 | show_img = False # default: False 67 | save_L = True # save LR image 68 | save_E = True # save estimated image 69 | save_LEH = False # save zoomed LR, E and H images 70 | border = 10 # default 10 for demosaicing 71 | 72 | task_current = 'dm' # 'dm' for demosaicing 73 | n_channels = 3 # fixed 74 | model_zoo = 'model_zoo' # fixed 75 | testsets = 'testsets' # fixed 76 | results = 'results' # fixed 77 | result_name = testset_name + '_' + task_current + '_' + model_name 78 | model_path = os.path.join(model_zoo, model_name+'.pth') 79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 | torch.cuda.empty_cache() 81 | 82 | # ---------------------------------------- 83 | # L_path, E_path, H_path 84 | # ---------------------------------------- 85 | 86 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 87 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 88 | util.mkdir(E_path) 89 | 90 | logger_name = result_name 91 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 92 | logger = logging.getLogger(logger_name) 93 | 94 | # ---------------------------------------- 95 | # load model 96 | # ---------------------------------------- 97 | 98 | if 'drunet' in model_name: 99 | from models.network_unet import UNetRes as net 100 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose") 101 | model.load_state_dict(torch.load(model_path), strict=True) 102 | model.eval() 103 | for _, v in model.named_parameters(): 104 | v.requires_grad = False 105 | model = model.to(device) 106 | elif 'ircnn' in model_name: 107 | from models.network_dncnn import IRCNN as net 108 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64) 109 | model25 = torch.load(model_path) 110 | former_idx = 0 111 | 112 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model)) 113 | logger.info('Model path: {:s}'.format(model_path)) 114 | logger.info(L_path) 115 | L_paths = util.get_image_paths(L_path) 116 | 117 | test_results = OrderedDict() 118 | test_results['psnr'] = [] 119 | 120 | for idx, img in enumerate(L_paths): 121 | 122 | # -------------------------------- 123 | # (1) get img_H and img_L 124 | # -------------------------------- 125 | 126 | idx += 1 127 | img_name, ext = os.path.splitext(os.path.basename(img)) 128 | img_H = util.imread_uint(img, n_channels=n_channels) 129 | CFA, CFA4, mosaic, mask = utils_mosaic.mosaic_CFA_Bayer(img_H) 130 | 131 | # -------------------------------- 132 | # (2) initialize x 133 | # -------------------------------- 134 | 135 | if matlab_init: # matlab demosaicing for initialization 136 | CFA4 = util.uint2tensor4(CFA4).to(device) 137 | x = utils_mosaic.dm_matlab(CFA4) 138 | else: 139 | x = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA) 140 | x = util.uint2tensor4(x).to(device) 141 | 142 | img_L = util.tensor2uint(x) 143 | y = util.uint2tensor4(mosaic).to(device) 144 | 145 | util.imshow(img_L) if show_img else None 146 | mask = util.single2tensor4(mask.astype(np.float32)).to(device) 147 | 148 | # -------------------------------- 149 | # (3) get rhos and sigmas 150 | # -------------------------------- 151 | 152 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_img), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1.0) 153 | rhos, 
sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device) 154 | 155 | # -------------------------------- 156 | # (4) main iterations 157 | # -------------------------------- 158 | 159 | for i in range(iter_num): 160 | 161 | # -------------------------------- 162 | # step 1, closed-form solution 163 | # -------------------------------- 164 | 165 | x = (y+rhos[i].float()*x).div(mask+rhos[i]) 166 | 167 | # -------------------------------- 168 | # step 2, denoiser 169 | # -------------------------------- 170 | 171 | if 'ircnn' in model_name: 172 | current_idx = np.int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1) 173 | if current_idx != former_idx: 174 | model.load_state_dict(model25[str(current_idx)], strict=True) 175 | model.eval() 176 | for _, v in model.named_parameters(): 177 | v.requires_grad = False 178 | model = model.to(device) 179 | former_idx = current_idx 180 | 181 | x = torch.clamp(x, 0, 1) 182 | if x8: 183 | x = util.augment_img_tensor4(x, i % 8) 184 | 185 | if 'drunet' in model_name: 186 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1) 187 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16) 188 | # x = model(x) 189 | elif 'ircnn' in model_name: 190 | x = model(x) 191 | 192 | if x8: 193 | if i % 8 == 3 or i % 8 == 5: 194 | x = util.augment_img_tensor4(x, 8 - i % 8) 195 | else: 196 | x = util.augment_img_tensor4(x, i % 8) 197 | 198 | x[mask.to(torch.bool)] = y[mask.to(torch.bool)] 199 | 200 | # -------------------------------- 201 | # (4) img_E 202 | # -------------------------------- 203 | 204 | img_E = util.tensor2uint(x) 205 | psnr = util.calculate_psnr(img_E, img_H, border=border) 206 | test_results['psnr'].append(psnr) 207 | logger.info('{:->4d}--> {:>10s} -- PSNR: {:.2f}dB'.format(idx, img_name+ext, psnr)) 208 | 209 | if save_E: 210 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'.png')) 211 | 212 | if save_L: 213 | util.imsave(img_L, os.path.join(E_path, img_name+'_L.png')) 214 | 215 | if save_LEH: 216 | util.imsave(np.concatenate([img_L, img_E, img_H], axis=1), os.path.join(E_path, img_name+model_name+'_LEH.png')) 217 | 218 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 219 | logger.info('------> Average PSNR(RGB) of ({}) is : {:.2f} dB'.format(testset_name, ave_psnr)) 220 | 221 | 222 | if __name__ == '__main__': 223 | 224 | main() 225 | -------------------------------------------------------------------------------- /main_dpir_denoising.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import logging 3 | 4 | import numpy as np 5 | from collections import OrderedDict 6 | 7 | import torch 8 | 9 | from utils import utils_logger 10 | from utils import utils_model 11 | from utils import utils_image as util 12 | 13 | 14 | """ 15 | Spyder (Python 3.7) 16 | PyTorch 1.6.0 17 | Windows 10 or Linux 18 | Kai Zhang (cskaizhang@gmail.com) 19 | github: https://github.com/cszn/DPIR 20 | https://github.com/cszn/IRCNN 21 | https://github.com/cszn/KAIR 22 | @article{zhang2020plug, 23 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior}, 24 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, 25 | journal={arXiv preprint}, 26 | year={2020} 27 | } 28 | % If you have any question, please feel free to contact with me. 
29 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/) 30 | by Kai Zhang (01/August/2020) 31 | 32 | # -------------------------------------------- 33 | |--model_zoo # model_zoo 34 | |--drunet_gray # model_name, for color images 35 | |--drunet_color 36 | |--testset # testsets 37 | |--set12 # testset_name 38 | |--bsd68 39 | |--cbsd68 40 | |--results # results 41 | |--set12_dn_drunet_gray # result_name = testset_name + '_' + 'dn' + model_name 42 | |--set12_dn_drunet_color 43 | # -------------------------------------------- 44 | """ 45 | 46 | 47 | def main(): 48 | 49 | # ---------------------------------------- 50 | # Preparation 51 | # ---------------------------------------- 52 | 53 | noise_level_img = 15 # set AWGN noise level for noisy image 54 | noise_level_model = noise_level_img # set noise level for model 55 | model_name = 'drunet_gray' # set denoiser model, 'drunet_gray' | 'drunet_color' 56 | testset_name = 'bsd68' # set test set, 'bsd68' | 'cbsd68' | 'set12' 57 | x8 = False # default: False, x8 to boost performance 58 | show_img = False # default: False 59 | border = 0 # shave boader to calculate PSNR and SSIM 60 | 61 | if 'color' in model_name: 62 | n_channels = 3 # 3 for color image 63 | else: 64 | n_channels = 1 # 1 for grayscale image 65 | 66 | model_pool = 'model_zoo' # fixed 67 | testsets = 'testsets' # fixed 68 | results = 'results' # fixed 69 | task_current = 'dn' # 'dn' for denoising 70 | result_name = testset_name + '_' + task_current + '_' + model_name 71 | 72 | model_path = os.path.join(model_pool, model_name+'.pth') 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | torch.cuda.empty_cache() 75 | 76 | # ---------------------------------------- 77 | # L_path, E_path, H_path 78 | # ---------------------------------------- 79 | 80 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 81 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 82 | util.mkdir(E_path) 83 | 84 | logger_name = result_name 85 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 86 | logger = logging.getLogger(logger_name) 87 | 88 | # ---------------------------------------- 89 | # load model 90 | # ---------------------------------------- 91 | 92 | from models.network_unet import UNetRes as net 93 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose") 94 | model.load_state_dict(torch.load(model_path), strict=True) 95 | model.eval() 96 | for k, v in model.named_parameters(): 97 | v.requires_grad = False 98 | model = model.to(device) 99 | logger.info('Model path: {:s}'.format(model_path)) 100 | number_parameters = sum(map(lambda x: x.numel(), model.parameters())) 101 | logger.info('Params number: {}'.format(number_parameters)) 102 | 103 | test_results = OrderedDict() 104 | test_results['psnr'] = [] 105 | test_results['ssim'] = [] 106 | 107 | logger.info('model_name:{}, model sigma:{}, image sigma:{}'.format(model_name, noise_level_img, noise_level_model)) 108 | logger.info(L_path) 109 | L_paths = util.get_image_paths(L_path) 110 | 111 | for idx, img in enumerate(L_paths): 112 | 113 | # ------------------------------------ 114 | # (1) img_L 115 | # ------------------------------------ 116 | 117 | img_name, ext = os.path.splitext(os.path.basename(img)) 118 | # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) 119 | img_H = util.imread_uint(img, 
n_channels=n_channels)
        img_L = util.uint2single(img_H)

        # Add noise without clipping
        np.random.seed(seed=0)  # for reproducibility
        img_L += np.random.normal(0, noise_level_img/255., img_L.shape)

        util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None

        img_L = util.single2tensor4(img_L)
        img_L = torch.cat((img_L, torch.FloatTensor([noise_level_model/255.]).repeat(1, 1, img_L.shape[2], img_L.shape[3])), dim=1)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------

        if not x8 and img_L.size(2)%8==0 and img_L.size(3)%8==0:  # spatial size divisible by 8: run the model directly
            img_E = model(img_L)
        elif not x8 and (img_L.size(2)%8!=0 or img_L.size(3)%8!=0):
            img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)  # pad/split to handle sizes not divisible by 8
        elif x8:
            img_E = utils_model.test_mode(model, img_L, mode=3)

        img_E = util.tensor2uint(img_E)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------

        if n_channels == 1:
            img_H = img_H.squeeze()
        psnr = util.calculate_psnr(img_E, img_H, border=border)
        ssim = util.calculate_ssim(img_E, img_H, border=border)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))

        # ------------------------------------
        # save results
        # ------------------------------------

        util.imsave(img_E, os.path.join(E_path, img_name+ext))

    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
    logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))


if __name__ == '__main__':

    main()
--------------------------------------------------------------------------------
/main_dpir_sisr.py:
--------------------------------------------------------------------------------
import os.path
import glob
import cv2
import logging
import time

import numpy as np
from datetime import datetime
from collections import OrderedDict
import hdf5storage

import torch

from utils import utils_deblur
from utils import utils_logger
from utils import utils_model
from utils import utils_pnp as pnp
from utils import utils_sisr as sr
from utils import utils_image as util


"""
Spyder (Python 3.7)
PyTorch 1.6.0
Windows 10 or Linux
Kai Zhang (cskaizhang@gmail.com)
github: https://github.com/cszn/DPIR
        https://github.com/cszn/IRCNN
        https://github.com/cszn/KAIR
@article{zhang2020plug,
  title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
  author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
  journal={arXiv preprint},
  year={2020}
}
% If you have any questions, please feel free to contact me.
37 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/) 38 | by Kai Zhang (01/August/2020) 39 | 40 | # -------------------------------------------- 41 | |--model_zoo # model_zoo 42 | |--drunet_color # model_name, for color images 43 | |--drunet_gray 44 | |--testset # testsets 45 | |--results # results 46 | # -------------------------------------------- 47 | 48 | 49 | How to run: 50 | step 1: download [drunet_gray.pth, drunet_color.pth, ircnn_gray.pth, ircnn_color.pth] from https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D 51 | step 2: set your own testset 'testset_name' and parameter setting such as 'noise_level_img', 'iter_num'. 52 | step 3: 'python main_dpir_sisr.py' 53 | 54 | """ 55 | 56 | def main(): 57 | 58 | # ---------------------------------------- 59 | # Preparation 60 | # ---------------------------------------- 61 | 62 | noise_level_img = 0/255.0 # set AWGN noise level for LR image, default: 0, 63 | noise_level_model = noise_level_img # setnoise level of model, default 0 64 | model_name = 'drunet_color' # set denoiser, | 'drunet_color' | 'ircnn_gray' | 'drunet_gray' | 'ircnn_color' 65 | testset_name = 'srbsd68' # set test set, 'set5' | 'srbsd68' 66 | x8 = True # default: False, x8 to boost performance 67 | test_sf = [2] # set scale factor, default: [2, 3, 4], [2], [3], [4] 68 | iter_num = 24 # set number of iterations, default: 24 for SISR 69 | modelSigma1 = 49 # set sigma_1, default: 49 70 | classical_degradation = True # set classical degradation or bicubic degradation 71 | 72 | show_img = False # default: False 73 | save_L = True # save LR image 74 | save_E = True # save estimated image 75 | save_LEH = False # save zoomed LR, E and H images 76 | 77 | task_current = 'sr' # 'sr' for super-resolution 78 | n_channels = 1 if 'gray' in model_name else 3 # fixed 79 | model_zoo = 'model_zoo' # fixed 80 | testsets = 'testsets' # fixed 81 | results = 'results' # fixed 82 | result_name = testset_name + '_' + task_current + '_' + model_name 83 | model_path = os.path.join(model_zoo, model_name+'.pth') 84 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 85 | torch.cuda.empty_cache() 86 | 87 | # ---------------------------------------- 88 | # L_path, E_path, H_path 89 | # ---------------------------------------- 90 | 91 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 92 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 93 | util.mkdir(E_path) 94 | 95 | logger_name = result_name 96 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 97 | logger = logging.getLogger(logger_name) 98 | 99 | # ---------------------------------------- 100 | # load model 101 | # ---------------------------------------- 102 | 103 | if 'drunet' in model_name: 104 | from models.network_unet import UNetRes as net 105 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose") 106 | model.load_state_dict(torch.load(model_path), strict=True) 107 | model.eval() 108 | for _, v in model.named_parameters(): 109 | v.requires_grad = False 110 | model = model.to(device) 111 | elif 'ircnn' in model_name: 112 | from models.network_dncnn import IRCNN as net 113 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64) 114 | model25 = torch.load(model_path) 115 | former_idx = 0 116 | 117 | logger.info('model_name:{}, image sigma:{:.3f}, model 
sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
    logger.info('Model path: {:s}'.format(model_path))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    # --------------------------------
    # load kernel
    # --------------------------------

    # kernels = hdf5storage.loadmat(os.path.join('kernels', 'Levin09.mat'))['kernels']
    if classical_degradation:
        kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels_12.mat'))['kernels']
    else:
        kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels_bicubicx234.mat'))['kernels']  # matches the file shipped in kernels/

    test_results_ave = OrderedDict()
    test_results_ave['psnr_sf_k'] = []
    test_results_ave['psnr_y_sf_k'] = []

    for sf in test_sf:
        border = sf
        modelSigma2 = max(sf, noise_level_model*255.)
        k_num = 8 if classical_degradation else 1

        for k_index in range(k_num):
            logger.info('--------- sf:{:>1d} --k:{:>2d} ---------'.format(sf, k_index))
            test_results = OrderedDict()
            test_results['psnr'] = []
            test_results['psnr_y'] = []

            if not classical_degradation:  # for bicubic degradation
                k_index = sf-2
            k = kernels[0, k_index].astype(np.float64)

            util.surf(k) if show_img else None

            for idx, img in enumerate(L_paths):

                # --------------------------------
                # (1) get img_L
                # --------------------------------

                img_name, ext = os.path.splitext(os.path.basename(img))
                img_H = util.imread_uint(img, n_channels=n_channels)
                img_H = util.modcrop(img_H, sf)  # modcrop

                if classical_degradation:
                    img_L = sr.classical_degradation(img_H, k, sf)
                    util.imshow(img_L) if show_img else None
                    img_L = util.uint2single(img_L)
                else:
                    img_L = util.imresize_np(util.uint2single(img_H), 1/sf)

                np.random.seed(seed=0)  # for reproducibility
                img_L += np.random.normal(0, noise_level_img, img_L.shape)  # add AWGN

                # --------------------------------
                # (2) get rhos and sigmas
                # --------------------------------

                rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
                rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)

                # --------------------------------
                # (3) initialize x, and pre-calculation
                # --------------------------------

                x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC)
                if np.ndim(x)==2:
                    x = x[..., None]

                if classical_degradation:
                    x = sr.shift_pixel(x, sf)
                x = util.single2tensor4(x).to(device)

                img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2))
                [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device)
                FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)

                # --------------------------------
                # (4) main iterations
                # --------------------------------

                for i in range(iter_num):

                    # --------------------------------
                    # step 1, FFT
                    # --------------------------------

                    tau = rhos[i].float().repeat(1, 1, 1, 1)
                    x = sr.data_solution(x.float(), FB, FBC, F2B, FBFy, tau, sf)

                    if 'ircnn' in model_name:
                        current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)  # builtin int (np.int was removed from NumPy)
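                        # Editor's note (added comment, not in the original source): IRCNN ships
                        # 25 denoisers covering noise levels up to 50 in steps of 2; the line
                        # above maps the current sigma to a model index in 0..24 (e.g.,
                        # sigma*255 = 25.5 gives ceil(25.5/2) - 1 = 12), and the corresponding
                        # weights are loaded below only when the index changes.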
211 | 212 | if current_idx != former_idx: 213 | model.load_state_dict(model25[str(current_idx)], strict=True) 214 | model.eval() 215 | for _, v in model.named_parameters(): 216 | v.requires_grad = False 217 | model = model.to(device) 218 | former_idx = current_idx 219 | 220 | # -------------------------------- 221 | # step 2, denoiser 222 | # -------------------------------- 223 | 224 | if x8: 225 | x = util.augment_img_tensor4(x, i % 8) 226 | 227 | if 'drunet' in model_name: 228 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1) 229 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16) 230 | elif 'ircnn' in model_name: 231 | x = model(x) 232 | 233 | if x8: 234 | if i % 8 == 3 or i % 8 == 5: 235 | x = util.augment_img_tensor4(x, 8 - i % 8) 236 | else: 237 | x = util.augment_img_tensor4(x, i % 8) 238 | 239 | # -------------------------------- 240 | # (3) img_E 241 | # -------------------------------- 242 | 243 | img_E = util.tensor2uint(x) 244 | 245 | if save_E: 246 | util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_'+model_name+'.png')) 247 | 248 | if n_channels == 1: 249 | img_H = img_H.squeeze() 250 | 251 | # -------------------------------- 252 | # (4) img_LEH 253 | # -------------------------------- 254 | 255 | img_L = util.single2uint(img_L).squeeze() 256 | 257 | if save_LEH: 258 | k_v = k/np.max(k)*1.0 259 | if n_channels==1: 260 | k_v = util.single2uint(k_v) 261 | else: 262 | k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, n_channels])) 263 | k_v = cv2.resize(k_v, (3*k_v.shape[1], 3*k_v.shape[0]), interpolation=cv2.INTER_NEAREST) 264 | img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST) 265 | img_I[:k_v.shape[0], -k_v.shape[1]:, ...] = k_v 266 | img_I[:img_L.shape[0], :img_L.shape[1], ...] 
= img_L
267 | util.imshow(np.concatenate([img_I, img_E, img_H], axis=1), title='LR / Recovered / Ground-truth') if show_img else None
268 | util.imsave(np.concatenate([img_I, img_E, img_H], axis=1), os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LEH.png'))
269 | 
270 | if save_L:
271 | util.imsave(img_L, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LR.png'))
272 | 
273 | psnr = util.calculate_psnr(img_E, img_H, border=border)
274 | test_results['psnr'].append(psnr)
275 | logger.info('{:->4d}--> {:>10s} -- sf:{:>1d} --k:{:>2d} PSNR: {:.2f}dB'.format(idx+1, img_name+ext, sf, k_index, psnr))
276 | 
277 | if n_channels == 3:
278 | img_E_y = util.rgb2ycbcr(img_E, only_y=True)
279 | img_H_y = util.rgb2ycbcr(img_H, only_y=True)
280 | psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
281 | test_results['psnr_y'].append(psnr_y)
282 | 
283 | # --------------------------------
284 | # Average PSNR for all kernels
285 | # --------------------------------
286 | 
287 | ave_psnr_k = sum(test_results['psnr']) / len(test_results['psnr'])
288 | logger.info('------> Average PSNR(RGB) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_k))
289 | test_results_ave['psnr_sf_k'].append(ave_psnr_k)
290 | 
291 | if n_channels == 3: # RGB image
292 | ave_psnr_y_k = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
293 | logger.info('------> Average PSNR(Y) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_y_k))
294 | test_results_ave['psnr_y_sf_k'].append(ave_psnr_y_k)
295 | 
296 | # ---------------------------------------
297 | # Average PSNR for all sf and kernels
298 | # ---------------------------------------
299 | 
300 | ave_psnr_sf_k = sum(test_results_ave['psnr_sf_k']) / len(test_results_ave['psnr_sf_k'])
301 | logger.info('------> Average PSNR of ({}) {:.2f} dB'.format(testset_name, ave_psnr_sf_k))
302 | if n_channels == 3:
303 | ave_psnr_y_sf_k = sum(test_results_ave['psnr_y_sf_k']) / len(test_results_ave['psnr_y_sf_k'])
304 | logger.info('------> Average PSNR(Y) of ({}) {:.2f} dB'.format(testset_name, ave_psnr_y_sf_k))
305 | 
306 | if __name__ == '__main__':
307 | 
308 | main()
309 | 
--------------------------------------------------------------------------------
/main_dpir_sisr_real_applications.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import glob
3 | import cv2
4 | import logging
5 | import time
6 | 
7 | import numpy as np
8 | from datetime import datetime
9 | from collections import OrderedDict
10 | import hdf5storage
11 | 
12 | import torch
13 | 
14 | from utils import utils_deblur
15 | from utils import utils_logger
16 | from utils import utils_model
17 | from utils import utils_pnp as pnp
18 | from utils import utils_sisr as sr
19 | from utils import utils_image as util
20 | 
21 | 
22 | """
23 | Spyder (Python 3.7)
24 | PyTorch 1.6.0
25 | Windows 10 or Linux
26 | Kai Zhang (cskaizhang@gmail.com)
27 | github: https://github.com/cszn/DPIR
28 | https://github.com/cszn/IRCNN
29 | https://github.com/cszn/KAIR
30 | @article{zhang2020plug,
31 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
32 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
33 | journal={arXiv preprint},
34 | year={2020}
35 | }
36 | % If you have any questions, please feel free to contact me.
37 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
38 | by Kai Zhang (01/August/2020)
39 | 
40 | # --------------------------------------------
41 | |--model_zoo # model_zoo
42 | |--drunet_gray # model_name, for grayscale images
43 | |--drunet_color # model_name, for color images
44 | |--testsets # testsets
45 | |--results # results
46 | # --------------------------------------------
47 | """
48 | 
49 | def main():
50 | 
51 | """
52 | # ----------------------------------------------------------------------------------
53 | # In real applications, you should set proper values for
54 | # - "noise_level_img": from [3, 25]; set 3 for a clean LR image, try 15 for a very noisy one
55 | # - "k" (or "kernel_width"): the blur kernel is very important!!! try kernel_width from [0.6, 3.0]
56 | # to get the best performance.
57 | # ----------------------------------------------------------------------------------
58 | """
59 | ##############################################################################
60 | 
61 | testset_name = 'set3c' # set test set, 'set3c' | 'set5'
62 | noise_level_img = 3 # set noise level of image, from [3, 25], set 3 for clean image
63 | model_name = 'drunet_color' # set denoiser, 'drunet_gray' | 'drunet_color' | 'ircnn_gray' | 'ircnn_color'
64 | sf = 2 # set scale factor, 1, 2, 3, 4
65 | iter_num = 24 # set number of iterations, default: 24 for SISR
66 | 
67 | # --------------------------------
68 | # set blur kernel
69 | # --------------------------------
70 | kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
71 | noise_level_model = noise_level_img/255. # noise level of model
72 | kernel_width = kernel_width_default_x1234[sf-1]
73 | 
74 | """
75 | # set your own kernel width !!!!!!!!!!
76 | """
77 | # kernel_width = 1.0
78 | 
79 | 
80 | k = utils_deblur.fspecial('gaussian', 25, kernel_width)
81 | k = sr.shift_pixel(k, sf) # shift the kernel
82 | k /= np.sum(k)
83 | 
84 | ##############################################################################
85 | 
86 | 
87 | show_img = False
88 | util.surf(k) if show_img else None
89 | x8 = True # True to use x8 geometric self-ensemble to boost performance, default: False
90 | modelSigma1 = 49 # set sigma_1, default: 49
91 | modelSigma2 = max(sf, noise_level_model*255.)
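# For intuition, a rough sketch (not a verbatim copy) of what pnp.get_rho_sigma does:
# it returns iter_num pairs in which sigmas decay roughly log-uniformly from
# modelSigma1/255 down to modelSigma2/255, while rhos grow in proportion to
# noise_level_model**2 / sigmas**2, so early iterations apply a strong denoiser and
# later iterations trust the data term more.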
92 | classical_degradation = True # set classical degradation or bicubic degradation 93 | 94 | task_current = 'sr' # 'sr' for super-resolution 95 | n_channels = 1 if 'gray' in model_name else 3 # fixed 96 | model_zoo = 'model_zoo' # fixed 97 | testsets = 'testsets' # fixed 98 | results = 'results' # fixed 99 | result_name = testset_name + '_realapplications_' + task_current + '_' + model_name 100 | model_path = os.path.join(model_zoo, model_name+'.pth') 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | torch.cuda.empty_cache() 103 | 104 | # ---------------------------------------- 105 | # L_path, E_path, H_path 106 | # ---------------------------------------- 107 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images 108 | E_path = os.path.join(results, result_name) # E_path, for Estimated images 109 | util.mkdir(E_path) 110 | 111 | logger_name = result_name 112 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) 113 | logger = logging.getLogger(logger_name) 114 | 115 | # ---------------------------------------- 116 | # load model 117 | # ---------------------------------------- 118 | if 'drunet' in model_name: 119 | from models.network_unet import UNetRes as net 120 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose") 121 | model.load_state_dict(torch.load(model_path), strict=True) 122 | model.eval() 123 | for _, v in model.named_parameters(): 124 | v.requires_grad = False 125 | model = model.to(device) 126 | elif 'ircnn' in model_name: 127 | from models.network_dncnn import IRCNN as net 128 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64) 129 | model25 = torch.load(model_path) 130 | former_idx = 0 131 | 132 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model)) 133 | logger.info('Model path: {:s}'.format(model_path)) 134 | logger.info(L_path) 135 | L_paths = util.get_image_paths(L_path) 136 | 137 | for idx, img in enumerate(L_paths): 138 | 139 | # -------------------------------- 140 | # (1) get img_L 141 | # -------------------------------- 142 | logger.info('Model path: {:s} Image: {:s}'.format(model_path, img)) 143 | img_name, ext = os.path.splitext(os.path.basename(img)) 144 | img_L = util.imread_uint(img, n_channels=n_channels) 145 | img_L = util.uint2single(img_L) 146 | img_L = util.modcrop(img_L, 8) # modcrop 147 | 148 | # -------------------------------- 149 | # (2) get rhos and sigmas 150 | # -------------------------------- 151 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1) 152 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device) 153 | 154 | # -------------------------------- 155 | # (3) initialize x, and pre-calculation 156 | # -------------------------------- 157 | x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC) 158 | 159 | if np.ndim(x)==2: 160 | x = x[..., None] 161 | 162 | if classical_degradation: 163 | x = sr.shift_pixel(x, sf) 164 | x = util.single2tensor4(x).to(device) 165 | 166 | img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2)) 167 | [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device) 168 | FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, 
sf)
169 | 
170 | # --------------------------------
171 | # (4) main iterations
172 | # --------------------------------
173 | for i in range(iter_num):
174 | 
175 | print('Iter: {} / {}'.format(i, iter_num))
176 | 
177 | # --------------------------------
178 | # step 1, FFT
179 | # --------------------------------
180 | tau = rhos[i].float().repeat(1, 1, 1, 1)
181 | x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)
182 | 
183 | if 'ircnn' in model_name:
184 | current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)
185 | 
186 | if current_idx != former_idx:
187 | model.load_state_dict(model25[str(current_idx)], strict=True)
188 | model.eval()
189 | for _, v in model.named_parameters():
190 | v.requires_grad = False
191 | model = model.to(device)
192 | former_idx = current_idx
193 | 
194 | # --------------------------------
195 | # step 2, denoiser
196 | # --------------------------------
197 | if x8:
198 | x = util.augment_img_tensor4(x, i % 8)
199 | 
200 | if 'drunet' in model_name:
201 | x = torch.cat((x, sigmas[i].repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
202 | x = utils_model.test_mode(model, x, mode=2, refield=64, min_size=256, modulo=16)
203 | elif 'ircnn' in model_name:
204 | x = model(x)
205 | 
206 | if x8:
207 | if i % 8 == 3 or i % 8 == 5:
208 | x = util.augment_img_tensor4(x, 8 - i % 8)
209 | else:
210 | x = util.augment_img_tensor4(x, i % 8)
211 | 
212 | # --------------------------------
213 | # (3) img_E
214 | # --------------------------------
215 | img_E = util.tensor2uint(x)
216 | util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'.png'))
217 | 
218 | if __name__ == '__main__':
219 | 
220 | main()
221 | 
--------------------------------------------------------------------------------
/model_zoo/README.md:
--------------------------------------------------------------------------------
1 | 
2 | * Google Drive download link: [https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing)
3 | 
4 | * Tencent Weiyun download link (腾讯微云): [https://share.weiyun.com/5qO32s3](https://share.weiyun.com/5qO32s3)
5 | 
6 | 
7 | -----------------
8 | 
9 | 
10 | |Model|Google Drive|Tencent Weiyun|
11 | |---|:--:|:--:|
12 | |drunet_gray.pth| [Google Drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
13 | |drunet_color.pth| [Google Drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
14 | |ircnn_gray.pth| [Google Drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
15 | |ircnn_color.pth| [Google Drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
16 | 
--------------------------------------------------------------------------------
/models/basicblock.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | '''
8 | # --------------------------------------------
9 | # Advanced nn.Sequential
10 | # https://github.com/xinntao/BasicSR
11 | # --------------------------------------------
12 | '''
13 | 
14 | 
15 | def sequential(*args):
16 | """Advanced nn.Sequential.
17 | 
18 | Args:
19 | nn.Sequential, nn.Module
20 | 
21 | Returns:
22 | nn.Sequential
23 | """
24 | if len(args) == 1:
25 | if isinstance(args[0], OrderedDict):
26 | raise NotImplementedError('sequential does not support OrderedDict input.')
27 | return args[0] # No sequential is needed.
28 | modules = []
29 | for module in args:
30 | if isinstance(module, nn.Sequential):
31 | for submodule in module.children():
32 | modules.append(submodule)
33 | elif isinstance(module, nn.Module):
34 | modules.append(module)
35 | return nn.Sequential(*modules)
36 | 
37 | 
38 | '''
39 | # --------------------------------------------
40 | # Useful blocks
41 | # https://github.com/xinntao/BasicSR
42 | # --------------------------------------------
43 | # conv + normalization + relu (conv)
44 | # (PixelUnShuffle)
45 | # (ConditionalBatchNorm2d)
46 | # concat (ConcatBlock)
47 | # sum (ShortcutBlock)
48 | # resblock (ResBlock)
49 | # Channel Attention (CA) Layer (CALayer)
50 | # Residual Channel Attention Block (RCABlock)
51 | # Residual Channel Attention Group (RCAGroup)
52 | # Residual Dense Block (ResidualDenseBlock_5C)
53 | # Residual in Residual Dense Block (RRDB)
54 | # --------------------------------------------
55 | '''
56 | 
57 | 
58 | # --------------------------------------------
59 | # return nn.Sequential of (Conv + BN + ReLU)
60 | # --------------------------------------------
61 | def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
62 | L = []
63 | for t in mode:
64 | if t == 'C':
65 | L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
66 | elif t == 'T':
67 | L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
68 | elif t == 'B':
69 | L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
70 | elif t == 'I':
71 | L.append(nn.InstanceNorm2d(out_channels, affine=True))
72 | elif t == 'R':
73 | L.append(nn.ReLU(inplace=True))
74 | elif t == 'r':
75 | L.append(nn.ReLU(inplace=False))
76 | elif t == 'L':
77 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
78 | elif t == 'l':
79 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
80 | elif t == '2':
81 | L.append(nn.PixelShuffle(upscale_factor=2))
82 | elif t == '3':
83 | L.append(nn.PixelShuffle(upscale_factor=3))
84 | elif t == '4':
85 | L.append(nn.PixelShuffle(upscale_factor=4))
86 | elif t == 'U':
87 | L.append(nn.Upsample(scale_factor=2, mode='nearest'))
88 | elif t == 'u':
89 | L.append(nn.Upsample(scale_factor=3, mode='nearest'))
90 | elif t == 'v':
91 | L.append(nn.Upsample(scale_factor=4, mode='nearest'))
92 | elif t == 'M':
93 | L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
94 | elif t == 'A':
95 | L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
96 | else:
97 | raise NotImplementedError('Undefined type: {}'.format(t))
98 | return sequential(*L)
99 | 
100 | 
101 | # --------------------------------------------
102 | # inverse of pixel_shuffle
103 | # --------------------------------------------
104 | def pixel_unshuffle(input, upscale_factor):
105 | r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a
106 | tensor of shape :math:`(*, r^2C, H, W)`.
107 | 
108 | Authors:
109 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan
110 | Kai Zhang, https://github.com/cszn/FFDNet
111 | 
112 | Date:
113 | 01/Jan/2019
114 | """
115 | batch_size, channels, in_height, in_width = input.size()
116 | 
117 | out_height = in_height // upscale_factor
118 | out_width = in_width // upscale_factor
119 | 
120 | input_view = input.contiguous().view(
121 | batch_size, channels, out_height, upscale_factor,
122 | out_width, upscale_factor)
123 | 
124 | channels *= upscale_factor ** 2
125 | unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
126 | return unshuffle_out.view(batch_size, channels, out_height, out_width)
127 | 
128 | 
129 | class PixelUnShuffle(nn.Module):
130 | r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a
131 | tensor of shape :math:`(*, r^2C, H, W)`.
132 | 
133 | Authors:
134 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan
135 | Kai Zhang, https://github.com/cszn/FFDNet
136 | 
137 | Date:
138 | 01/Jan/2019
139 | """
140 | 
141 | def __init__(self, upscale_factor):
142 | super(PixelUnShuffle, self).__init__()
143 | self.upscale_factor = upscale_factor
144 | 
145 | def forward(self, input):
146 | return pixel_unshuffle(input, self.upscale_factor)
147 | 
148 | def extra_repr(self):
149 | return 'upscale_factor={}'.format(self.upscale_factor)
150 | 
151 | 
152 | # --------------------------------------------
153 | # conditional batch norm
154 | # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
155 | # --------------------------------------------
156 | class ConditionalBatchNorm2d(nn.Module):
157 | def __init__(self, num_features, num_classes):
158 | super().__init__()
159 | self.num_features = num_features
160 | self.bn = nn.BatchNorm2d(num_features, affine=False)
161 | self.embed = nn.Embedding(num_classes, num_features * 2)
162 | self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
163 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
164 | 
165 | def forward(self, x, y):
166 | out = self.bn(x)
167 | gamma, beta = self.embed(y).chunk(2, 1)
168 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
169 | return out
170 | 
171 | 
172 | # --------------------------------------------
173 | # Concat the output of a submodule to its input
174 | # --------------------------------------------
175 | class ConcatBlock(nn.Module):
176 | def __init__(self, submodule):
177 | super(ConcatBlock, self).__init__()
178 | self.sub = submodule
179 | 
180 | def forward(self, x):
181 | output = torch.cat((x, self.sub(x)), dim=1)
182 | return output
183 | 
184 | def __repr__(self):
185 | return self.sub.__repr__() + 'concat'
186 | 
187 | 
188 | # --------------------------------------------
189 | # sum the output of a submodule to its input
190 | # --------------------------------------------
191 | class ShortcutBlock(nn.Module):
192 | def __init__(self, submodule):
193 | super(ShortcutBlock, self).__init__()
194 | 
195 | self.sub = submodule
196 | 
197 | def forward(self, x):
198 | output = x + self.sub(x)
199 | return output
200 | 
201 | def __repr__(self):
202 | tmpstr = 'Identity + \n|'
203 | modstr = self.sub.__repr__().replace('\n', '\n|')
204 | tmpstr = tmpstr + modstr
205 | return tmpstr
206 | 
207 | 
208 | # --------------------------------------------
209 | # Res Block: x + conv(relu(conv(x)))
210 | # --------------------------------------------
211 | class ResBlock(nn.Module):
212 | def __init__(self, in_channels=64, out_channels=64,
kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2): 213 | super(ResBlock, self).__init__() 214 | 215 | assert in_channels == out_channels, 'Only support in_channels==out_channels.' 216 | if mode[0] in ['R', 'L']: 217 | mode = mode[0].lower() + mode[1:] 218 | 219 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 220 | 221 | def forward(self, x): 222 | #res = self.res(x) 223 | return x + self.res(x) 224 | 225 | 226 | # -------------------------------------------- 227 | # simplified information multi-distillation block (IMDB) 228 | # x + conv1(concat(split(relu(conv(x)))x3)) 229 | # -------------------------------------------- 230 | class IMDBlock(nn.Module): 231 | """ 232 | @inproceedings{hui2019lightweight, 233 | title={Lightweight Image Super-Resolution with Information Multi-distillation Network}, 234 | author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei}, 235 | booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)}, 236 | pages={2024--2032}, 237 | year={2019} 238 | } 239 | @inproceedings{zhang2019aim, 240 | title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results}, 241 | author={Kai Zhang and Shuhang Gu and Radu Timofte and others}, 242 | booktitle={IEEE International Conference on Computer Vision Workshops}, 243 | year={2019} 244 | } 245 | """ 246 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.25, negative_slope=0.05): 247 | super(IMDBlock, self).__init__() 248 | self.d_nc = int(in_channels * d_rate) 249 | self.r_nc = int(in_channels - self.d_nc) 250 | 251 | assert mode[0] == 'C', 'convolutional layer first' 252 | 253 | self.conv1 = conv(in_channels, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 254 | self.conv2 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 255 | self.conv3 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 256 | self.conv4 = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias, mode[0], negative_slope) 257 | self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0], negative_slope=negative_slope) 258 | 259 | def forward(self, x): 260 | d1, r = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1) 261 | d2, r = torch.split(self.conv2(r), (self.d_nc, self.r_nc), dim=1) 262 | d3, r = torch.split(self.conv3(r), (self.d_nc, self.r_nc), dim=1) 263 | r = self.conv4(r) 264 | res = self.conv1x1(torch.cat((d1, d2, d3, r), dim=1)) 265 | return x + res 266 | 267 | 268 | # d1, r1 = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1) 269 | # d2, r2 = torch.split(self.conv2(r1), (self.d_nc, self.r_nc), dim=1) 270 | # d3, r3 = torch.split(self.conv3(r2), (self.d_nc, self.r_nc), dim=1) 271 | # d4 = self.conv4(r3) 272 | # -------------------------------------------- 273 | # Channel Attention (CA) Layer 274 | # -------------------------------------------- 275 | class CALayer(nn.Module): 276 | def __init__(self, channel=64, reduction=16): 277 | super(CALayer, self).__init__() 278 | 279 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 280 | self.conv_fc = nn.Sequential( 281 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 282 | nn.ReLU(inplace=True), 283 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 284 | nn.Sigmoid() 285 | ) 286 | 287 | def forward(self, x): 
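# Squeeze-and-excitation style channel attention: global average pooling squeezes
# each channel to a scalar, the two 1x1 convs form a bottleneck (reduction=16) ending
# in a sigmoid, and the resulting per-channel gate in (0, 1) is broadcast-multiplied
# back onto x, e.g. a (B, 64, H, W) input produces a (B, 64, 1, 1) gate.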
288 | y = self.avg_pool(x) 289 | y = self.conv_fc(y) 290 | return x * y 291 | 292 | 293 | # -------------------------------------------- 294 | # Residual Channel Attention Block (RCAB) 295 | # -------------------------------------------- 296 | class RCABlock(nn.Module): 297 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, negative_slope=0.2): 298 | super(RCABlock, self).__init__() 299 | assert in_channels == out_channels, 'Only support in_channels==out_channels.' 300 | if mode[0] in ['R','L']: 301 | mode = mode[0].lower() + mode[1:] 302 | 303 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 304 | self.ca = CALayer(out_channels, reduction) 305 | 306 | def forward(self, x): 307 | res = self.res(x) 308 | res = self.ca(res) 309 | return res + x 310 | 311 | 312 | # -------------------------------------------- 313 | # Residual Channel Attention Group (RG) 314 | # -------------------------------------------- 315 | class RCAGroup(nn.Module): 316 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, nb=12, negative_slope=0.2): 317 | super(RCAGroup, self).__init__() 318 | assert in_channels == out_channels, 'Only support in_channels==out_channels.' 319 | if mode[0] in ['R','L']: 320 | mode = mode[0].lower() + mode[1:] 321 | 322 | RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding, bias, mode, reduction, negative_slope) for _ in range(nb)] 323 | RG.append(conv(out_channels, out_channels, mode='C')) 324 | self.rg = nn.Sequential(*RG) # self.rg = ShortcutBlock(nn.Sequential(*RG)) 325 | 326 | def forward(self, x): 327 | res = self.rg(x) 328 | return res + x 329 | 330 | 331 | # -------------------------------------------- 332 | # Residual Dense Block 333 | # style: 5 convs 334 | # -------------------------------------------- 335 | class ResidualDenseBlock_5C(nn.Module): 336 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2): 337 | super(ResidualDenseBlock_5C, self).__init__() 338 | # gc: growth channel 339 | self.conv1 = conv(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 340 | self.conv2 = conv(nc+gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 341 | self.conv3 = conv(nc+2*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 342 | self.conv4 = conv(nc+3*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 343 | self.conv5 = conv(nc+4*gc, nc, kernel_size, stride, padding, bias, mode[:-1], negative_slope) 344 | 345 | def forward(self, x): 346 | x1 = self.conv1(x) 347 | x2 = self.conv2(torch.cat((x, x1), 1)) 348 | x3 = self.conv3(torch.cat((x, x1, x2), 1)) 349 | x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) 350 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 351 | return x5.mul_(0.2) + x 352 | 353 | 354 | # -------------------------------------------- 355 | # Residual in Residual Dense Block 356 | # 3x5c 357 | # -------------------------------------------- 358 | class RRDB(nn.Module): 359 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2): 360 | super(RRDB, self).__init__() 361 | 362 | self.RDB1 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 363 | self.RDB2 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 364 | 
self.RDB3 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 365 | 366 | def forward(self, x): 367 | out = self.RDB1(x) 368 | out = self.RDB2(out) 369 | out = self.RDB3(out) 370 | return out.mul_(0.2) + x 371 | 372 | 373 | """ 374 | # -------------------------------------------- 375 | # Upsampler 376 | # Kai Zhang, https://github.com/cszn/KAIR 377 | # -------------------------------------------- 378 | # upsample_pixelshuffle 379 | # upsample_upconv 380 | # upsample_convtranspose 381 | # -------------------------------------------- 382 | """ 383 | 384 | 385 | # -------------------------------------------- 386 | # conv + subp (+ relu) 387 | # -------------------------------------------- 388 | def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): 389 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 390 | up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope) 391 | return up1 392 | 393 | 394 | # -------------------------------------------- 395 | # nearest_upsample + conv (+ R) 396 | # -------------------------------------------- 397 | def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): 398 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR' 399 | if mode[0] == '2': 400 | uc = 'UC' 401 | elif mode[0] == '3': 402 | uc = 'uC' 403 | elif mode[0] == '4': 404 | uc = 'vC' 405 | mode = mode.replace(mode[0], uc) 406 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope) 407 | return up1 408 | 409 | 410 | # -------------------------------------------- 411 | # convTranspose (+ relu) 412 | # -------------------------------------------- 413 | def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): 414 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 415 | kernel_size = int(mode[0]) 416 | stride = int(mode[0]) 417 | mode = mode.replace(mode[0], 'T') 418 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 419 | return up1 420 | 421 | 422 | ''' 423 | # -------------------------------------------- 424 | # Downsampler 425 | # Kai Zhang, https://github.com/cszn/KAIR 426 | # -------------------------------------------- 427 | # downsample_strideconv 428 | # downsample_maxpool 429 | # downsample_avgpool 430 | # -------------------------------------------- 431 | ''' 432 | 433 | 434 | # -------------------------------------------- 435 | # strideconv (+ relu) 436 | # -------------------------------------------- 437 | def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): 438 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 
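# The leading digit of `mode` doubles as the downscaling factor: it is consumed here
# to set kernel_size = stride (so '2R' becomes a 2x2 stride-2 convolution + ReLU) and
# is then replaced by 'C' so that conv() above assembles the actual layers.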
439 | kernel_size = int(mode[0]) 440 | stride = int(mode[0]) 441 | mode = mode.replace(mode[0], 'C') 442 | down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 443 | return down1 444 | 445 | 446 | # -------------------------------------------- 447 | # maxpooling + conv (+ relu) 448 | # -------------------------------------------- 449 | def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2): 450 | assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' 451 | kernel_size_pool = int(mode[0]) 452 | stride_pool = int(mode[0]) 453 | mode = mode.replace(mode[0], 'MC') 454 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 455 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) 456 | return sequential(pool, pool_tail) 457 | 458 | 459 | # -------------------------------------------- 460 | # averagepooling + conv (+ relu) 461 | # -------------------------------------------- 462 | def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): 463 | assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' 464 | kernel_size_pool = int(mode[0]) 465 | stride_pool = int(mode[0]) 466 | mode = mode.replace(mode[0], 'AC') 467 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 468 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) 469 | return sequential(pool, pool_tail) 470 | 471 | 472 | ''' 473 | # -------------------------------------------- 474 | # NonLocalBlock2D: 475 | # embedded_gaussian 476 | # +W(softmax(thetaXphi)Xg) 477 | # -------------------------------------------- 478 | ''' 479 | 480 | 481 | # -------------------------------------------- 482 | # non-local block with embedded_gaussian 483 | # https://github.com/AlexHex7/Non-local_pytorch 484 | # -------------------------------------------- 485 | class NonLocalBlock2D(nn.Module): 486 | def __init__(self, nc=64, kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='maxpool', negative_slope=0.2): 487 | 488 | super(NonLocalBlock2D, self).__init__() 489 | 490 | inter_nc = nc // 2 491 | self.inter_nc = inter_nc 492 | self.W = conv(inter_nc, nc, kernel_size, stride, padding, bias, mode='C'+act_mode) 493 | self.theta = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C') 494 | 495 | if downsample: 496 | if downsample_mode == 'avgpool': 497 | downsample_block = downsample_avgpool 498 | elif downsample_mode == 'maxpool': 499 | downsample_block = downsample_maxpool 500 | elif downsample_mode == 'strideconv': 501 | downsample_block = downsample_strideconv 502 | else: 503 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 504 | self.phi = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2') 505 | self.g = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2') 506 | else: 507 | self.phi = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C') 508 | self.g = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C') 509 | 510 | def forward(self, x): 511 | ''' 512 | :param x: (b, c, t, h, 
w)
513 | :return:
514 | '''
515 | 
516 | batch_size = x.size(0)
517 | 
518 | g_x = self.g(x).view(batch_size, self.inter_nc, -1)
519 | g_x = g_x.permute(0, 2, 1)
520 | 
521 | theta_x = self.theta(x).view(batch_size, self.inter_nc, -1)
522 | theta_x = theta_x.permute(0, 2, 1)
523 | phi_x = self.phi(x).view(batch_size, self.inter_nc, -1)
524 | f = torch.matmul(theta_x, phi_x)
525 | f_div_C = F.softmax(f, dim=-1)
526 | 
527 | y = torch.matmul(f_div_C, g_x)
528 | y = y.permute(0, 2, 1).contiguous()
529 | y = y.view(batch_size, self.inter_nc, *x.size()[2:])
530 | W_y = self.W(y)
531 | z = W_y + x
532 | 
533 | return z
--------------------------------------------------------------------------------
/models/network_dncnn.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch.nn as nn
3 | import models.basicblock as B
4 | 
5 | 
6 | """
7 | # --------------------------------------------
8 | # DnCNN (20 conv layers)
9 | # FDnCNN (20 conv layers)
10 | # --------------------------------------------
11 | # References:
12 | @article{zhang2017beyond,
13 | title={Beyond a Gaussian denoiser: Residual learning of deep CNN for image denoising},
14 | author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
15 | journal={IEEE Transactions on Image Processing},
16 | volume={26},
17 | number={7},
18 | pages={3142--3155},
19 | year={2017},
20 | publisher={IEEE}
21 | }
22 | @article{zhang2018ffdnet,
23 | title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
24 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
25 | journal={IEEE Transactions on Image Processing},
26 | volume={27},
27 | number={9},
28 | pages={4608--4622},
29 | year={2018},
30 | publisher={IEEE}
31 | }
32 | # --------------------------------------------
33 | """
34 | 
35 | 
36 | # --------------------------------------------
37 | # DnCNN
38 | # --------------------------------------------
39 | class DnCNN(nn.Module):
40 | def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
41 | """
42 | # ------------------------------------
43 | in_nc: channel number of input
44 | out_nc: channel number of output
45 | nc: channel number
46 | nb: total number of conv layers
47 | act_mode: batch norm + activation function; 'BR' means BN+ReLU.
48 | # ------------------------------------
49 | Batch normalization and residual learning are
50 | beneficial to Gaussian denoising (especially
51 | for a single noise level).
52 | The residual of a noisy image corrupted by additive white
53 | Gaussian noise (AWGN) follows a constant
54 | Gaussian distribution, which stabilizes batch
55 | normalization during training.
56 | # ------------------------------------
57 | """
58 | super(DnCNN, self).__init__()
59 | assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
60 | bias = True
61 | 
62 | m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
63 | m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
64 | m_tail = B.conv(nc, out_nc, mode='C', bias=bias)
65 | 
66 | self.model = B.sequential(m_head, *m_body, m_tail)
67 | 
68 | def forward(self, x):
69 | n = self.model(x)
70 | return x-n
71 | 
72 | 
73 | 
74 | class IRCNN(nn.Module):
75 | def __init__(self, in_nc=1, out_nc=1, nc=64):
76 | """
77 | # ------------------------------------
78 | denoiser of IRCNN
79 | in_nc: channel number of input
80 | out_nc: channel number of output
81 | nc: channel number
82 | (fixed depth: 7 dilated conv layers,
83 | no batch normalization)
84 | # ------------------------------------
85 | The dilation rates 1, 2, 3, 4, 3, 2, 1
86 | enlarge the receptive field while keeping
87 | the network shallow.
88 | Like DnCNN, the model predicts the residual
89 | (noise) of an input corrupted by additive
90 | white Gaussian noise (AWGN), and forward()
91 | subtracts it from the input.
92 | # ------------------------------------
93 | """
94 | super(IRCNN, self).__init__()
95 | L = []
96 | L.append(nn.Conv2d(in_channels=in_nc, out_channels=nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
97 | L.append(nn.ReLU(inplace=True))
98 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
99 | L.append(nn.ReLU(inplace=True))
100 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
101 | L.append(nn.ReLU(inplace=True))
102 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=4, dilation=4, bias=True))
103 | L.append(nn.ReLU(inplace=True))
104 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
105 | L.append(nn.ReLU(inplace=True))
106 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
107 | L.append(nn.ReLU(inplace=True))
108 | L.append(nn.Conv2d(in_channels=nc, out_channels=out_nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
109 | self.model = B.sequential(*L)
110 | 
111 | def forward(self, x):
112 | n = self.model(x)
113 | return x-n
114 | 
115 | 
116 | 
117 | 
118 | 
119 | 
120 | # --------------------------------------------
121 | # FDnCNN
122 | # --------------------------------------------
123 | # Compared with DnCNN, FDnCNN has three modifications:
124 | # 1) add noise level map as input
125 | # 2) remove residual learning and BN
126 | # 3) train with L1 loss
127 | # it may need more training time, but the final PSNR hardly drops.
128 | # --------------------------------------------
129 | class FDnCNN(nn.Module):
130 | def __init__(self, in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R'):
131 | """
132 | in_nc: channel number of input
133 | out_nc: channel number of output
134 | nc: channel number
135 | nb: total number of conv layers
136 | act_mode: batch norm + activation function; 'BR' means BN+ReLU.
137 | """ 138 | super(FDnCNN, self).__init__() 139 | assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL' 140 | bias = True 141 | 142 | m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias) 143 | m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)] 144 | m_tail = B.conv(nc, out_nc, mode='C', bias=bias) 145 | 146 | self.model = B.sequential(m_head, *m_body, m_tail) 147 | 148 | def forward(self, x): 149 | x = self.model(x) 150 | return x 151 | 152 | 153 | if __name__ == '__main__': 154 | from utils import utils_model 155 | import torch 156 | model1 = DnCNN(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='BR') 157 | print(utils_model.describe_model(model1)) 158 | 159 | model2 = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') 160 | print(utils_model.describe_model(model2)) 161 | 162 | x = torch.randn((1, 1, 240, 240)) 163 | x1 = model1(x) 164 | print(x1.shape) 165 | 166 | x = torch.randn((1, 2, 240, 240)) 167 | x2 = model2(x) 168 | print(x2.shape) 169 | 170 | # run models/network_dncnn.py -------------------------------------------------------------------------------- /models/network_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import models.basicblock as B 4 | import numpy as np 5 | 6 | ''' 7 | # ==================== 8 | # unet 9 | # ==================== 10 | ''' 11 | 12 | 13 | class UNet(nn.Module): 14 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=2, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): 15 | super(UNet, self).__init__() 16 | 17 | self.m_head = B.conv(in_nc, nc[0], mode='C'+act_mode[-1]) 18 | 19 | # downsample 20 | if downsample_mode == 'avgpool': 21 | downsample_block = B.downsample_avgpool 22 | elif downsample_mode == 'maxpool': 23 | downsample_block = B.downsample_maxpool 24 | elif downsample_mode == 'strideconv': 25 | downsample_block = B.downsample_strideconv 26 | else: 27 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 28 | 29 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode)) 30 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode)) 31 | self.m_down3 = B.sequential(*[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode)) 32 | 33 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)]) 34 | 35 | # upsample 36 | if upsample_mode == 'upconv': 37 | upsample_block = B.upsample_upconv 38 | elif upsample_mode == 'pixelshuffle': 39 | upsample_block = B.upsample_pixelshuffle 40 | elif upsample_mode == 'convtranspose': 41 | upsample_block = B.upsample_convtranspose 42 | else: 43 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 44 | 45 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)]) 46 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)]) 47 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)]) 48 | 49 | self.m_tail = B.conv(nc[0], out_nc, bias=True, mode='C') 50 
| 51 | def forward(self, x0): 52 | 53 | x1 = self.m_head(x0) 54 | x2 = self.m_down1(x1) 55 | x3 = self.m_down2(x2) 56 | x4 = self.m_down3(x3) 57 | x = self.m_body(x4) 58 | x = self.m_up3(x+x4) 59 | x = self.m_up2(x+x3) 60 | x = self.m_up1(x+x2) 61 | x = self.m_tail(x+x1) + x0 62 | 63 | 64 | return x 65 | 66 | 67 | class UNetRes(nn.Module): 68 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): 69 | super(UNetRes, self).__init__() 70 | 71 | self.m_head = B.conv(in_nc, nc[0], bias=False, mode='C') 72 | 73 | # downsample 74 | if downsample_mode == 'avgpool': 75 | downsample_block = B.downsample_avgpool 76 | elif downsample_mode == 'maxpool': 77 | downsample_block = B.downsample_maxpool 78 | elif downsample_mode == 'strideconv': 79 | downsample_block = B.downsample_strideconv 80 | else: 81 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 82 | 83 | self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2')) 84 | self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2')) 85 | self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2')) 86 | 87 | self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) 88 | 89 | # upsample 90 | if upsample_mode == 'upconv': 91 | upsample_block = B.upsample_upconv 92 | elif upsample_mode == 'pixelshuffle': 93 | upsample_block = B.upsample_pixelshuffle 94 | elif upsample_mode == 'convtranspose': 95 | upsample_block = B.upsample_convtranspose 96 | else: 97 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 98 | 99 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[B.ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) 100 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[B.ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) 101 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[B.ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) 102 | 103 | self.m_tail = B.conv(nc[0], out_nc, bias=False, mode='C') 104 | 105 | def forward(self, x0): 106 | x1 = self.m_head(x0) 107 | x2 = self.m_down1(x1) 108 | x3 = self.m_down2(x2) 109 | x4 = self.m_down3(x3) 110 | x = self.m_body(x4) 111 | x = self.m_up3(x+x4) 112 | x = self.m_up2(x+x3) 113 | x = self.m_up1(x+x2) 114 | x = self.m_tail(x+x1) 115 | 116 | return x 117 | 118 | 119 | class ResUNet(nn.Module): 120 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='L', downsample_mode='strideconv', upsample_mode='convtranspose'): 121 | super(ResUNet, self).__init__() 122 | 123 | self.m_head = B.conv(in_nc, nc[0], bias=False, mode='C') 124 | 125 | # downsample 126 | if downsample_mode == 'avgpool': 127 | downsample_block = B.downsample_avgpool 128 | elif downsample_mode == 'maxpool': 129 | downsample_block = B.downsample_maxpool 130 | elif downsample_mode == 'strideconv': 131 | downsample_block = B.downsample_strideconv 132 | else: 133 | raise NotImplementedError('downsample mode [{:s}] is not 
found'.format(downsample_mode)) 134 | 135 | self.m_down1 = B.sequential(*[B.IMDBlock(nc[0], nc[0], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2')) 136 | self.m_down2 = B.sequential(*[B.IMDBlock(nc[1], nc[1], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2')) 137 | self.m_down3 = B.sequential(*[B.IMDBlock(nc[2], nc[2], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2')) 138 | 139 | self.m_body = B.sequential(*[B.IMDBlock(nc[3], nc[3], bias=False, mode='C'+act_mode) for _ in range(nb)]) 140 | 141 | # upsample 142 | if upsample_mode == 'upconv': 143 | upsample_block = B.upsample_upconv 144 | elif upsample_mode == 'pixelshuffle': 145 | upsample_block = B.upsample_pixelshuffle 146 | elif upsample_mode == 'convtranspose': 147 | upsample_block = B.upsample_convtranspose 148 | else: 149 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 150 | 151 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[B.IMDBlock(nc[2], nc[2], bias=False, mode='C'+act_mode) for _ in range(nb)]) 152 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[B.IMDBlock(nc[1], nc[1], bias=False, mode='C'+act_mode) for _ in range(nb)]) 153 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[B.IMDBlock(nc[0], nc[0], bias=False, mode='C'+act_mode) for _ in range(nb)]) 154 | 155 | self.m_tail = B.conv(nc[0], out_nc, bias=False, mode='C') 156 | 157 | def forward(self, x): 158 | 159 | h, w = x.size()[-2:] 160 | paddingBottom = int(np.ceil(h/8)*8-h) 161 | paddingRight = int(np.ceil(w/8)*8-w) 162 | x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x) 163 | 164 | x1 = self.m_head(x) 165 | x2 = self.m_down1(x1) 166 | x3 = self.m_down2(x2) 167 | x4 = self.m_down3(x3) 168 | x = self.m_body(x4) 169 | x = self.m_up3(x+x4) 170 | x = self.m_up2(x+x3) 171 | x = self.m_up1(x+x2) 172 | x = self.m_tail(x+x1) 173 | x = x[..., :h, :w] 174 | 175 | return x 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | class UNetResSubP(nn.Module): 191 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=2, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): 192 | super(UNetResSubP, self).__init__() 193 | sf = 2 194 | self.m_ps_down = B.PixelUnShuffle(sf) 195 | self.m_ps_up = nn.PixelShuffle(sf) 196 | self.m_head = B.conv(in_nc*sf*sf, nc[0], mode='C'+act_mode[-1]) 197 | 198 | # downsample 199 | if downsample_mode == 'avgpool': 200 | downsample_block = B.downsample_avgpool 201 | elif downsample_mode == 'maxpool': 202 | downsample_block = B.downsample_maxpool 203 | elif downsample_mode == 'strideconv': 204 | downsample_block = B.downsample_strideconv 205 | else: 206 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 207 | 208 | self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode)) 209 | self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode)) 210 | self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode)) 211 | 212 | self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], 
mode='C'+act_mode+'C') for _ in range(nb+1)]) 213 | 214 | # upsample 215 | if upsample_mode == 'upconv': 216 | upsample_block = B.upsample_upconv 217 | elif upsample_mode == 'pixelshuffle': 218 | upsample_block = B.upsample_pixelshuffle 219 | elif upsample_mode == 'convtranspose': 220 | upsample_block = B.upsample_convtranspose 221 | else: 222 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 223 | 224 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.ResBlock(nc[2], nc[2], mode='C'+act_mode+'C') for _ in range(nb)]) 225 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.ResBlock(nc[1], nc[1], mode='C'+act_mode+'C') for _ in range(nb)]) 226 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.ResBlock(nc[0], nc[0], mode='C'+act_mode+'C') for _ in range(nb)]) 227 | 228 | self.m_tail = B.conv(nc[0], out_nc*sf*sf, bias=False, mode='C') 229 | 230 | def forward(self, x0): 231 | x0_d = self.m_ps_down(x0) 232 | x1 = self.m_head(x0_d) 233 | x2 = self.m_down1(x1) 234 | x3 = self.m_down2(x2) 235 | x4 = self.m_down3(x3) 236 | x = self.m_body(x4) 237 | x = self.m_up3(x+x4) 238 | x = self.m_up2(x+x3) 239 | x = self.m_up1(x+x2) 240 | x = self.m_tail(x+x1) 241 | x = self.m_ps_up(x) + x0 242 | 243 | return x 244 | 245 | 246 | class UNetPlus(nn.Module): 247 | def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=1, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): 248 | super(UNetPlus, self).__init__() 249 | 250 | self.m_head = B.conv(in_nc, nc[0], mode='C') 251 | 252 | # downsample 253 | if downsample_mode == 'avgpool': 254 | downsample_block = B.downsample_avgpool 255 | elif downsample_mode == 'maxpool': 256 | downsample_block = B.downsample_maxpool 257 | elif downsample_mode == 'strideconv': 258 | downsample_block = B.downsample_strideconv 259 | else: 260 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 261 | 262 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode[1])) 263 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode[1])) 264 | self.m_down3 = B.sequential(*[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode[1])) 265 | 266 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)]) 267 | 268 | # upsample 269 | if upsample_mode == 'upconv': 270 | upsample_block = B.upsample_upconv 271 | elif upsample_mode == 'pixelshuffle': 272 | upsample_block = B.upsample_pixelshuffle 273 | elif upsample_mode == 'convtranspose': 274 | upsample_block = B.upsample_convtranspose 275 | else: 276 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 277 | 278 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[2], nc[2], mode='C'+act_mode[1])) 279 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[1], nc[1], mode='C'+act_mode[1])) 280 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[0], nc[0], mode='C'+act_mode[1])) 
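# NB: '2'+act_mode[1] and 'C'+act_mode[1] above index the second character of
# act_mode, so UNetPlus assumes a two-character act_mode such as 'BR'; the
# single-character default act_mode='R' would raise an IndexError at construction.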
281 | 282 | self.m_tail = B.conv(nc[0], out_nc, mode='C') 283 | 284 | def forward(self, x0): 285 | x1 = self.m_head(x0) 286 | x2 = self.m_down1(x1) 287 | x3 = self.m_down2(x2) 288 | x4 = self.m_down3(x3) 289 | x = self.m_body(x4) 290 | x = self.m_up3(x+x4) 291 | x = self.m_up2(x+x3) 292 | x = self.m_up1(x+x2) 293 | x = self.m_tail(x+x1) + x0 294 | return x 295 | 296 | ''' 297 | # ==================== 298 | # nonlocalunet 299 | # ==================== 300 | ''' 301 | 302 | class NonLocalUNet(nn.Module): 303 | def __init__(self, in_nc=3, out_nc=3, nc=[64,128,256,512], nb=1, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): 304 | super(NonLocalUNet, self).__init__() 305 | 306 | down_nonlocal = B.NonLocalBlock2D(nc[2], kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='strideconv') 307 | up_nonlocal = B.NonLocalBlock2D(nc[2], kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='strideconv') 308 | 309 | self.m_head = B.conv(in_nc, nc[0], mode='C'+act_mode[-1]) 310 | 311 | # downsample 312 | if downsample_mode == 'avgpool': 313 | downsample_block = B.downsample_avgpool 314 | elif downsample_mode == 'maxpool': 315 | downsample_block = B.downsample_maxpool 316 | elif downsample_mode == 'strideconv': 317 | downsample_block = B.downsample_strideconv 318 | else: 319 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 320 | 321 | 322 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode)) 323 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode)) 324 | self.m_down3 = B.sequential(down_nonlocal, *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode)) 325 | 326 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)]) 327 | 328 | # upsample 329 | if upsample_mode == 'upconv': 330 | upsample_block = B.upsample_upconv 331 | elif upsample_mode == 'pixelshuffle': 332 | upsample_block = B.upsample_pixelshuffle 333 | elif upsample_mode == 'convtranspose': 334 | upsample_block = B.upsample_convtranspose 335 | else: 336 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 337 | 338 | 339 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], up_nonlocal) 340 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)]) 341 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)]) 342 | 343 | self.m_tail = B.conv(nc[0], out_nc, mode='C') 344 | 345 | def forward(self, x0): 346 | x1 = self.m_head(x0) 347 | x2 = self.m_down1(x1) 348 | x3 = self.m_down2(x2) 349 | x4 = self.m_down3(x3) 350 | x = self.m_body(x4) 351 | x = self.m_up3(x+x4) 352 | x = self.m_up2(x+x3) 353 | x = self.m_up1(x+x2) 354 | x = self.m_tail(x+x1) + x0 355 | return x 356 | 357 | 358 | if __name__ == '__main__': 359 | x = torch.rand(1,3,256,256) 360 | # net = UNet(act_mode='BR') 361 | net = NonLocalUNet() 362 | net.eval() 363 | with torch.no_grad(): 364 | y = net(x) 365 | y.size() 366 | 367 | 
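A minimal smoke test for the UNetRes backbone defined above (the DRUNet architecture used by the main scripts); the 128x128 shape and the 15/255 noise level are illustrative, not prescribed. As in main_dpir_sisr.py, the network takes the image concatenated with a constant noise-level map, so in_nc is the image channel count plus one, and spatial sizes should be multiples of 8 because of the three stride-2 downsamplings.

```python
import torch
from models.network_unet import UNetRes

# Same hyper-parameters the DPIR scripts use for the color model (in_nc = 3 + 1).
model = UNetRes(in_nc=4, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R',
                downsample_mode='strideconv', upsample_mode='convtranspose')
model.eval()

x = torch.rand(1, 3, 128, 128)                    # noisy RGB image in [0, 1]
sigma = torch.full((1, 1, 128, 128), 15.0/255.0)  # constant noise-level map
with torch.no_grad():
    y = model(torch.cat((x, sigma), dim=1))       # denoised estimate
print(y.shape)                                    # torch.Size([1, 3, 128, 128])
```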
-------------------------------------------------------------------------------- /testsets/set12/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/01.png -------------------------------------------------------------------------------- /testsets/set12/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/02.png -------------------------------------------------------------------------------- /testsets/set12/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/03.png -------------------------------------------------------------------------------- /testsets/set12/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/04.png -------------------------------------------------------------------------------- /testsets/set12/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/05.png -------------------------------------------------------------------------------- /testsets/set12/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/06.png -------------------------------------------------------------------------------- /testsets/set12/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/07.png -------------------------------------------------------------------------------- /testsets/set12/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/08.png -------------------------------------------------------------------------------- /testsets/set12/09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/09.png -------------------------------------------------------------------------------- /testsets/set12/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/10.png -------------------------------------------------------------------------------- /testsets/set12/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/11.png -------------------------------------------------------------------------------- /testsets/set12/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/12.png 
-------------------------------------------------------------------------------- /testsets/set3c/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/butterfly.png -------------------------------------------------------------------------------- /testsets/set3c/leaves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/leaves.png -------------------------------------------------------------------------------- /testsets/set3c/starfish.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/starfish.png -------------------------------------------------------------------------------- /testsets/set5/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/baby_GT.bmp -------------------------------------------------------------------------------- /testsets/set5/bird_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/bird_GT.bmp -------------------------------------------------------------------------------- /testsets/set5/butterfly_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/butterfly_GT.bmp -------------------------------------------------------------------------------- /testsets/set5/head_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/head_GT.bmp -------------------------------------------------------------------------------- /testsets/set5/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/woman_GT.bmp -------------------------------------------------------------------------------- /utils/test.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/utils/test.bmp -------------------------------------------------------------------------------- /utils/utils_bnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def deleteLayer(model, layer_type=nn.BatchNorm2d): 6 | for k, m in list(model.named_children()): 7 | if isinstance(m, layer_type): 8 | del model._modules[k] 9 | deleteLayer(m, layer_type) 10 | 11 | 12 | def merge_bn(model): 13 | ''' by Kai Zhang, 11/01/2019. 
14 | ''' 15 | prev_m = None 16 | for k, m in list(model.named_children()): 17 | if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): 18 | 19 | w = prev_m.weight.data 20 | 21 | if prev_m.bias is None: 22 | zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) 23 | prev_m.bias = nn.Parameter(zeros) 24 | b = prev_m.bias.data 25 | 26 | invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) 27 | if isinstance(prev_m, nn.ConvTranspose2d): 28 | w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) 29 | else: 30 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 31 | b.add_(-m.running_mean).mul_(invstd) 32 | if m.affine: 33 | if isinstance(prev_m, nn.ConvTranspose2d): 34 | w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) 35 | else: 36 | w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 37 | b.mul_(m.weight.data).add_(m.bias.data) 38 | 39 | del model._modules[k] 40 | prev_m = m 41 | merge_bn(m) 42 | 43 | 44 | def add_bn(model, for_init=True): 45 | ''' by Kai Zhang, 11/01/2019. 46 | ''' 47 | for k, m in list(model.named_children()): 48 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): 49 | if for_init: 50 | b = nn.BatchNorm2d(m.out_channels, momentum=None, affine=False, eps=1e-04) 51 | else: 52 | b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True, eps=1e-04) 53 | b.weight.data.fill_(1) 54 | 55 | new_m = nn.Sequential(model._modules[k], b) 56 | model._modules[k] = new_m 57 | add_bn(m, for_init) 58 | 59 | 60 | def deploy_sequential(model): 61 | ''' by Kai Zhang, 11/01/2019. 62 | singleton children 63 | ''' 64 | for k, m in list(model.named_children()): 65 | if isinstance(m, nn.Sequential): 66 | if m.__len__() == 1: 67 | model._modules[k] = m.__getitem__(0) 68 | deploy_sequential(m) 69 | 70 | -------------------------------------------------------------------------------- /utils/utils_csmri.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from scipy import fftpack 4 | import torch 5 | from scipy import ndimage 6 | from utils import utils_image as util 7 | from scipy.interpolate import interp2d 8 | from scipy import signal 9 | import scipy.stats as ss 10 | import scipy.io as io 11 | import scipy 12 | 13 | ''' 14 | modified by Kai Zhang (github: https://github.com/cszn) 15 | 03/03/2019 16 | ''' 17 | 18 | 19 | ''' 20 | # ================= 21 | # pytorch 22 | # ================= 23 | ''' 24 | 25 | 26 | def splits(a, sf): 27 | '''split a into sfxsf distinct blocks 28 | 29 | Args: 30 | a: NxCxWxHx2 31 | sf: split factor 32 | 33 | Returns: 34 | b: NxCx(W/sf)x(H/sf)x2x(sf^2) 35 | ''' 36 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=5) 37 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=5) 38 | return b 39 | 40 | 41 | def c2c(x): 42 | return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) 43 | 44 | 45 | def r2c(x): 46 | # convert real to complex 47 | return torch.stack([x, torch.zeros_like(x)], -1) 48 | 49 | 50 | def cdiv(x, y): 51 | # complex division 52 | a, b = x[..., 0], x[..., 1] 53 | c, d = y[..., 0], y[..., 1] 54 | cd2 = c**2 + d**2 55 | return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) 56 | 57 | 58 | def crdiv(x, y): 59 | # complex/real division 60 | a, b = x[..., 0], x[..., 1] 61 | return torch.stack([a/y, b/y], -1) 62 | 63 | 64 | def csum(x, y): 65 | # complex + 
real 66 | return torch.stack([x[..., 0] + y, x[..., 1]], -1) 67 | 68 | 69 | def cabs(x): 70 | # modulus of a complex number 71 | return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) 72 | 73 | 74 | def cabs2(x): 75 | return x[..., 0]**2+x[..., 1]**2 76 | 77 | 78 | def cmul(t1, t2): 79 | '''complex multiplication 80 | 81 | Args: 82 | t1: NxCxHxWx2, complex tensor 83 | t2: NxCxHxWx2 84 | 85 | Returns: 86 | output: NxCxHxWx2 87 | ''' 88 | real1, imag1 = t1[..., 0], t1[..., 1] 89 | real2, imag2 = t2[..., 0], t2[..., 1] 90 | return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) 91 | 92 | 93 | def cconj(t, inplace=False): 94 | '''complex's conjugation 95 | 96 | Args: 97 | t: NxCxHxWx2 98 | 99 | Returns: 100 | output: NxCxHxWx2 101 | ''' 102 | c = t.clone() if not inplace else t 103 | c[..., 1] *= -1 104 | return c 105 | 106 | 107 | def rfft(t): 108 | # Real-to-complex Discrete Fourier Transform 109 | return torch.rfft(t, 2, onesided=False) 110 | 111 | 112 | def irfft(t): 113 | # Complex-to-real Inverse Discrete Fourier Transform 114 | return torch.irfft(t, 2, onesided=False) 115 | 116 | 117 | def fft(t): 118 | # Complex-to-complex Discrete Fourier Transform 119 | return torch.fft(t, 2) 120 | 121 | 122 | def ifft(t): 123 | # Complex-to-complex Inverse Discrete Fourier Transform 124 | return torch.ifft(t, 2) 125 | 126 | 127 | def p2o(psf, shape): 128 | ''' 129 | Convert point-spread function to optical transfer function. 130 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the 131 | point-spread function (PSF) array and creates the optical transfer 132 | function (OTF) array that is not influenced by the PSF off-centering. 133 | 134 | Args: 135 | psf: NxCxhxw 136 | shape: [H, W] 137 | 138 | Returns: 139 | otf: NxCxHxWx2 140 | ''' 141 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 142 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 143 | for axis, axis_size in enumerate(psf.shape[2:]): 144 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) 145 | otf = torch.rfft(otf, 2, onesided=False) 146 | n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 147 | otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 148 | return otf 149 | 150 | 151 | def upsample(x, sf=3): 152 | '''s-fold upsampler 153 | 154 | Upsampling the spatial size by filling the new entries with zeros 155 | 156 | x: tensor image, NxCxWxH 157 | ''' 158 | st = 0 159 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x) 160 | z[..., st::sf, st::sf].copy_(x) 161 | return z 162 | 163 | 164 | def downsample(x, sf=3): 165 | '''s-fold downsampler 166 | 167 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others 168 | 169 | x: tensor image, NxCxWxH 170 | ''' 171 | st = 0 172 | return x[..., st::sf, st::sf] 173 | 174 | 175 | def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf): 176 | FR = FBFy + torch.rfft(alpha*x, 2, onesided=False) 177 | x1 = cmul(FB, FR) 178 | FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False) 179 | invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False) 180 | invWBR = cdiv(FBR, csum(invW, alpha)) 181 | FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1)) 182 | FX = (FR-FCBinvWBR)/alpha.unsqueeze(-1) 183 | Xest = torch.irfft(FX, 2, onesided=False) 184 | return Xest 185 | 186 | 187 | def pre_calculate(x, k, sf): 188 | ''' 189 | Args: 190 | x: NxCxHxW, LR input 191 | k: NxCxhxw 192 | sf: integer 193 | 194 | Returns: 
195 | FB, FBC, F2B, FBFy 196 | will be reused during iterations 197 | ''' 198 | w, h = x.shape[-2:] 199 | FB = p2o(k, (w*sf, h*sf)) 200 | FBC = cconj(FB, inplace=False) 201 | F2B = r2c(cabs2(FB)) 202 | STy = upsample(x, sf=sf) 203 | FBFy = cmul(FBC, torch.rfft(STy, 2, onesided=False)) 204 | return FB, FBC, F2B, FBFy 205 | 206 | 207 | ''' 208 | # ================= 209 | PyTorch 210 | # ================= 211 | ''' 212 | 213 | 214 | def real2complex(x): 215 | return torch.stack([x, torch.zeros_like(x)], -1) 216 | 217 | 218 | def modcrop(img, sf): 219 | ''' 220 | img: tensor image, NxCxWxH or CxWxH or WxH 221 | sf: scale factor 222 | ''' 223 | w, h = img.shape[-2:] 224 | im = img.clone() 225 | return im[..., :w - w % sf, :h - h % sf] 226 | 227 | 228 | def circular_pad(x, pad): 229 | ''' 230 | # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding) 231 | ''' 232 | x = torch.cat([x, x[:, :, 0:pad, :]], dim=2) 233 | x = torch.cat([x, x[:, :, :, 0:pad]], dim=3) 234 | x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2) 235 | x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3) 236 | return x 237 | 238 | 239 | def pad_circular(input, padding): 240 | # type: (Tensor, List[int]) -> Tensor 241 | """ 242 | Arguments 243 | :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))` 244 | :param padding: (tuple): m-elem tuple where m is the degree of convolution 245 | Returns 246 | :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0], 247 | H + 2 * padding[1]], W + 2 * padding[2]))` 248 | """ 249 | offset = 3 250 | for dimension in range(input.dim() - offset + 1): 251 | input = dim_pad_circular(input, padding[dimension], dimension + offset) 252 | return input 253 | 254 | 255 | def dim_pad_circular(input, padding, dimension): 256 | # type: (Tensor, int, int) -> Tensor 257 | input = torch.cat([input, input[[slice(None)] * (dimension - 1) + 258 | [slice(0, padding)]]], dim=dimension - 1) 259 | input = torch.cat([input[[slice(None)] * (dimension - 1) + 260 | [slice(-2 * padding, -padding)]], input], dim=dimension - 1) 261 | return input 262 | 263 | 264 | def imfilter(x, k): 265 | ''' 266 | x: image, NxcxHxW 267 | k: kernel, cx1xhxw 268 | ''' 269 | x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2)) 270 | x = torch.nn.functional.conv2d(x, k, groups=x.shape[1]) 271 | return x 272 | 273 | 274 | def G(x, k, sf=3): 275 | ''' 276 | x: image, NxcxHxW 277 | k: kernel, cx1xhxw 278 | sf: scale factor 279 | center: the first one or the middle one 280 | 281 | Matlab function: 282 | tmp = imfilter(x,h,'circular'); 283 | y = downsample2(tmp,K); 284 | ''' 285 | x = downsample(imfilter(x, k), sf=sf) 286 | return x 287 | 288 | 289 | def Gt(x, k, sf=3): 290 | ''' 291 | x: image, NxcxHxW 292 | k: kernel, cx1xhxw 293 | sf: scale factor 294 | center: the first one or the middle one 295 | 296 | Matlab function: 297 | tmp = upsample2(x,K); 298 | y = imfilter(tmp,h,'circular'); 299 | ''' 300 | x = imfilter(upsample(x, sf=sf), k) 301 | return x 302 | 303 | 304 | def interpolation_down(x, sf, center=False): 305 | mask = torch.zeros_like(x) 306 | if center: 307 | start = torch.tensor((sf-1)//2) 308 | mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x) 309 | LR = x[..., start::sf, start::sf] 310 | else: 311 | mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x) 312 | LR = x[..., ::sf, ::sf] 313 | y = x.mul(mask) 314 | 315 | return LR, y, mask 316 | 317 | 318 | """ 319 | # -------------------------------------------- 320 | # degradation models 321 | # 
-------------------------------------------- 322 | """ 323 | 324 | 325 | def csmri_degradation(x, M): 326 | ''' 327 | Args: 328 | x: 1x1xWxH image, [0, 1] 329 | M: mask, WxH 330 | n: noise, WxHx2 331 | ''' 332 | x = rfft(x).mul(M.unsqueeze(-1).unsqueeze(0).unsqueeze(0)) # + n.unsqueeze(0).unsqueeze(0) 333 | return x 334 | 335 | 336 | if __name__ == '__main__': 337 | 338 | weight = torch.tensor([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.],[7.,8.,9.],[7.,8.,9.]]).view(1,1,5,3) 339 | input = torch.linspace(1,9,9).view(1,1,3,3) 340 | input = pad_circular(input, (2,1)) 341 | 342 | 343 | 344 | -------------------------------------------------------------------------------- /utils/utils_deblur.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import scipy 4 | from scipy import fftpack 5 | import torch 6 | 7 | from math import cos, sin 8 | from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round 9 | from numpy.random import randn, rand 10 | from scipy.signal import convolve2d 11 | 12 | # import utils_image as util 13 | 14 | ''' 15 | modified by Kai Zhang (github: https://github.com/cszn) 16 | 03/03/2019 17 | ''' 18 | 19 | 20 | def get_uperleft_denominator(img, kernel): 21 | ''' 22 | img: HxWxC 23 | kernel: hxw 24 | denominator: HxWx1 25 | upperleft: HxWxC 26 | ''' 27 | V = psf2otf(kernel, img.shape[:2]) 28 | denominator = np.expand_dims(np.abs(V)**2, axis=2) 29 | upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1]) 30 | return upperleft, denominator 31 | 32 | 33 | def get_uperleft_denominator_pytorch(img, kernel): 34 | ''' 35 | img: NxCxHxW 36 | kernel: Nx1xhxw 37 | denominator: Nx1xHxW 38 | upperleft: NxCxHxWx2 39 | ''' 40 | V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2 41 | denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW 42 | upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2 43 | return upperleft, denominator 44 | 45 | 46 | def c2c(x): 47 | return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) 48 | 49 | 50 | def r2c(x): 51 | return torch.stack([x, torch.zeros_like(x)], -1) 52 | 53 | 54 | def cdiv(x, y): 55 | a, b = x[..., 0], x[..., 1] 56 | c, d = y[..., 0], y[..., 1] 57 | cd2 = c**2 + d**2 58 | return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) 59 | 60 | 61 | def cabs(x): 62 | return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) 63 | 64 | 65 | def cmul(t1, t2): 66 | ''' 67 | complex multiplication 68 | t1: NxCxHxWx2 69 | output: NxCxHxWx2 70 | ''' 71 | real1, imag1 = t1[..., 0], t1[..., 1] 72 | real2, imag2 = t2[..., 0], t2[..., 1] 73 | return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) 74 | 75 | 76 | def cconj(t, inplace=False): 77 | ''' 78 | # complex's conjugation 79 | t: NxCxHxWx2 80 | output: NxCxHxWx2 81 | ''' 82 | c = t.clone() if not inplace else t 83 | c[..., 1] *= -1 84 | return c 85 | 86 | 87 | def rfft(t): 88 | return torch.rfft(t, 2, onesided=False) 89 | 90 | 91 | def irfft(t): 92 | return torch.irfft(t, 2, onesided=False) 93 | 94 | 95 | def fft(t): 96 | return torch.fft(t, 2) 97 | 98 | 99 | def ifft(t): 100 | return torch.ifft(t, 2) 101 | 102 | 103 | def p2o(psf, shape): 104 | ''' 105 | # psf: NxCxhxw 106 | # shape: [H,W] 107 | # otf: NxCxHxWx2 108 | ''' 109 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 110 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 111 | for axis, axis_size in enumerate(psf.shape[2:]): 112 | otf = torch.roll(otf, 
-int(axis_size / 2), dims=axis+2) 113 | otf = torch.rfft(otf, 2, onesided=False) 114 | n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 115 | otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 116 | return otf 471 | maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)] 472 | maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)] 473 | minxy = np.zeros(x.shape) 474 | minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)] 475 | minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)] 476 | m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\ 477 | (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\ 478 | np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2) 479 | m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\ 480 | (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\ 481 | np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2) 482 | h = None 483 | return h 484 | 485 | 486 | def fspecial_gaussian(hsize, sigma): 487 | hsize = [hsize, hsize] 488 | siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] 489 | std = sigma 490 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) 491 | arg = -(x*x + y*y)/(2*std*std) 492 | h = np.exp(arg) 493 | h[h < np.finfo(float).eps * h.max()] = 0 494 | sumh = h.sum() 495 | if sumh != 0: 496 | h = h/sumh 497 | return h 498 | 499 | 500 | def fspecial_laplacian(alpha): 501 | alpha = max([0, min([alpha,1])]) 502 | h1 = alpha/(alpha+1) 503 | h2 = (1-alpha)/(alpha+1) 504 | h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] 505 | h = np.array(h) 506 | return h 507 | 508 | 509 | def fspecial_log(hsize, sigma): 510 | raise NotImplementedError 511 | 512 | 513 | def fspecial_motion(motion_len, theta): 514 | raise NotImplementedError 515 | 516 | 517 | def fspecial_prewitt(): 518 | return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) 519 | 520 | 521 | def fspecial_sobel(): 522 | return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 523 | 524 | 525 | def fspecial(filter_type, *args, **kwargs): 526 | ''' 527 | python code from: 528 | https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py 529 | ''' 530 | if filter_type == 'average': 531 | return fspecial_average(*args, **kwargs) 532 | if filter_type == 'disk': 533 | return fspecial_disk(*args, **kwargs) 534 | if filter_type == 'gaussian': 535 | return fspecial_gaussian(*args, **kwargs) 536 | if filter_type == 'laplacian': 537 | return fspecial_laplacian(*args, **kwargs) 538 | if filter_type == 'log': 539 | return fspecial_log(*args, **kwargs) 540 | if filter_type == 'motion': 541 | return fspecial_motion(*args, **kwargs) 542 | if filter_type == 'prewitt': 543 | return fspecial_prewitt(*args, **kwargs) 544 | if filter_type == 'sobel': 545 | return fspecial_sobel(*args, **kwargs) 546 | 547 | 548 | def fspecial_gauss(size, sigma): 549 | x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1] 550 | g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) 551 | return g / g.sum() 552 | 553 | 554 | def blurkernel_synthesis(h=37, w=None): 555 | w = h if w is None else w 556 | kdims = [h, w] 557 | x = randomTrajectory(150) 558 | k = None 559 | while k is None: 560 | k = kernelFromTrajectory(x) 561 | 562 | # center pad to kdims 563 | pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2) 564 | pad_width = [(pad_width[0],), (pad_width[1],)] 565 | if pad_width[0][0]<0 or pad_width[1][0]<0: 566 | k = k[0:h, 0:h] 567 | else: 568 | k = pad(k, pad_width, "constant") 569 | # import matplotlib.pyplot as plt 570 | # 
plt.imshow(k, interpolation="nearest", cmap="gray") 571 | # plt.show() 572 | #print(k.dtype) 573 | return k 574 | 575 | 576 | def kernelFromTrajectory(x): 577 | h = 5 - log(rand()) / 0.15 578 | h = round(min([h, 27])).astype(int) 579 | h = h + 1 - h % 2 580 | w = h 581 | k = zeros((h, w)) 582 | 583 | xmin = min(x[0]) 584 | xmax = max(x[0]) 585 | ymin = min(x[1]) 586 | ymax = max(x[1]) 587 | xthr = arange(xmin, xmax, (xmax - xmin) / w) 588 | ythr = arange(ymin, ymax, (ymax - ymin) / h) 589 | 590 | for i in range(1, xthr.size): 591 | for j in range(1, ythr.size): 592 | idx = ( 593 | (x[0, :] >= xthr[i - 1]) 594 | & (x[0, :] < xthr[i]) 595 | & (x[1, :] >= ythr[j - 1]) 596 | & (x[1, :] < ythr[j]) 597 | ) 598 | k[i - 1, j - 1] = sum(idx) 599 | if sum(k) == 0: 600 | return 601 | k = k / sum(k) 602 | k = convolve2d(k, fspecial_gauss(3, 1), "same") 603 | k = k / sum(k) 604 | return k 605 | 606 | 607 | def randomTrajectory(T): 608 | x = zeros((3, T)) 609 | v = randn(3, T) 610 | r = zeros((3, T)) 611 | trv = 1 / 1 612 | trr = 2 * pi / T 613 | for t in range(1, T): 614 | F_rot = randn(3) / (t + 1) + r[:, t - 1] 615 | F_trans = randn(3) / (t + 1) 616 | r[:, t] = r[:, t - 1] + trr * F_rot 617 | v[:, t] = v[:, t - 1] + trv * F_trans 618 | st = v[:, t] 619 | st = rot3D(st, r[:, t]) 620 | x[:, t] = x[:, t - 1] + st 621 | return x 622 | 623 | 624 | def rot3D(x, r): 625 | Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]]) 626 | Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]]) 627 | Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]]) 628 | R = Rz @ Ry @ Rx 629 | x = R @ x 630 | return x 631 | 632 | 633 | if __name__ == '__main__': 634 | # a = opt_fft_size([111]) 635 | # print(a) 636 | # 637 | # print(fspecial('gaussian', 5, 1)) 638 | # 639 | # print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape) 640 | 641 | k = blurkernel_synthesis(25) 642 | print(k.shape) 643 | print(sum(k)) 644 | import matplotlib.pyplot as plt 645 | plt.imshow(k, interpolation="nearest", cmap="gray") 646 | plt.show() 647 | 648 | # kernel = fspecial('gaussian', 3, 1) 649 | # img = np.random.randn(5,5,1) 650 | # a, b = get_uperleft_denominator(img, kernel) 651 | # print(a) 652 | ## print(b) 653 | # 654 | # a, b = get_uperleft_denominator_pytorch(util.single2tensor4(img), util.single2tensor4(kernel[...,np.newaxis])) 655 | # print(a.squeeze()) 656 | # print(b.squeeze()) 657 | 658 | 659 | 660 | 661 | 662 | 663 | 664 | # get_uperleft_denominator_pytorch( 665 | -------------------------------------------------------------------------------- /utils/utils_inpaint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from utils import utils_image as util 4 | 5 | ''' 6 | modified by Kai Zhang (github: https://github.com/cszn) 7 | 03/03/2019 8 | ''' 9 | 10 | 11 | # -------------------------------- 12 | # get rho and sigma 13 | # -------------------------------- 14 | def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma2=2.55): 15 | ''' 16 | Kai Zhang (github: https://github.com/cszn) 17 | 03/03/2019 18 | ''' 19 | modelSigma1 = 49.0 20 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num) 21 | sigmas = modelSigmaS/255. 
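# HQS trade-off: mu_k = sigma^2 / (3 * sigma_k^2); since sigma_k decays log-uniformly from 49/255 down to modelSigma2/255, mu_k grows over the iterations, weighting the data-fidelity term more and more heavily.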
22 | mus = list(map(lambda x: (sigma**2)/(x**2)/3, sigmas)) 23 | rhos = mus 24 | return rhos, sigmas 25 | 26 | 27 | def shepard_initialize(image, measurement_mask, window=5, p=2): 28 | wing = np.floor(window/2).astype(int) # Length of each "wing" of the window. 29 | h, w = image.shape[0:2] 30 | ch = 3 if image.ndim == 3 and image.shape[-1] == 3 else 1 31 | x = np.copy(image) # ML initialization 32 | for i in range(h): 33 | i_lower_limit = -np.min([wing, i]) 34 | i_upper_limit = np.min([wing, h-i-1]) 35 | for j in range(w): 36 | if measurement_mask[i, j] == 0: # checking if there's a need to interpolate 37 | j_lower_limit = -np.min([wing, j]) 38 | j_upper_limit = np.min([wing, w-j-1]) 39 | 40 | count = 0 # keeps track of how many measured pixels are within the neighborhood 41 | sum_IPD = 0 42 | interpolated_value = 0 43 | 44 | num_zeros = window**2 45 | IPD = np.zeros([num_zeros]) 46 | pixel = np.zeros([num_zeros,ch]) 47 | 48 | for neighborhood_i in range(i+i_lower_limit, i+i_upper_limit): 49 | for neighborhood_j in range(j+j_lower_limit, j+j_upper_limit): 50 | if measurement_mask[neighborhood_i, neighborhood_j] == 1: 51 | # IPD: "inverse pth-power distance". 52 | IPD[count] = 1.0/((neighborhood_i - i)**p + (neighborhood_j - j)**p) 53 | sum_IPD = sum_IPD + IPD[count] 54 | pixel[count] = image[neighborhood_i, neighborhood_j] 55 | count = count + 1 56 | 57 | for c in range(count): 58 | weight = IPD[c]/sum_IPD 59 | interpolated_value = interpolated_value + weight*pixel[c] 60 | x[i, j] = interpolated_value 61 | 62 | return x 63 | 64 | 65 | if __name__ == '__main__': 66 | # image path & sampling ratio 67 | import matplotlib.pyplot as mplot 68 | import matplotlib.image as mpimg 69 | Im = mpimg.imread('test.bmp') 70 | #Im = Im[:,:,1] 71 | Im = np.squeeze(Im) 72 | 73 | SmpRatio = 0.2 74 | # create mask 75 | mask_Array = np.random.rand(Im.shape[0],Im.shape[1]) 76 | mask_Array = (mask_Array < SmpRatio) 77 | print(mask_Array.dtype) 78 | 79 | # sampled image 80 | print('The sampling ratio is', SmpRatio) 81 | Im_sampled = np.multiply(np.expand_dims(mask_Array,2), Im) 82 | util.imshow(Im_sampled) 83 | 84 | a = shepard_initialize(Im_sampled.astype(np.float32), mask_Array, window=9) 85 | a = np.clip(a,0,255) 86 | 87 | 88 | print(a.dtype) 89 | 90 | 91 | util.imshow(np.concatenate((a,Im_sampled),1)/255.0) 92 | util.imsave(np.concatenate((a,Im_sampled),1),'a.png') 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /utils/utils_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | import logging 5 | 6 | 7 | ''' 8 | modified by Kai Zhang (github: https://github.com/cszn) 9 | 03/03/2019 10 | https://github.com/xinntao/BasicSR 11 | ''' 12 | 13 | 14 | def log(*args, **kwargs): 15 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 16 | 17 | 18 | ''' 19 | # =============================== 20 | # logger 21 | # logger_name = None = 'base' ???
22 | # =============================== 23 | ''' 24 | 25 | 26 | def logger_info(logger_name, log_path='default_logger.log'): 27 | ''' set up logger 28 | modified by Kai Zhang (github: https://github.com/cszn) 29 | ''' 30 | log = logging.getLogger(logger_name) 31 | if log.hasHandlers(): 32 | print('LogHandlers exists!') 33 | else: 34 | print('LogHandlers setup!') 35 | level = logging.INFO 36 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 37 | fh = logging.FileHandler(log_path, mode='a') 38 | fh.setFormatter(formatter) 39 | log.setLevel(level) 40 | log.addHandler(fh) 41 | # print(len(log.handlers)) 42 | 43 | sh = logging.StreamHandler() 44 | sh.setFormatter(formatter) 45 | log.addHandler(sh) 46 | 47 | 48 | ''' 49 | # =============================== 50 | # print to file and std_out simultaneously 51 | # =============================== 52 | ''' 53 | 54 | 55 | class logger_print(object): 56 | def __init__(self, log_path="default.log"): 57 | self.terminal = sys.stdout 58 | self.log = open(log_path, 'a') 59 | 60 | def write(self, message): 61 | self.terminal.write(message) 62 | self.log.write(message) # write the message 63 | 64 | def flush(self): 65 | pass 66 | -------------------------------------------------------------------------------- /utils/utils_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | from utils import utils_image as util 5 | 6 | 7 | ''' 8 | modified by Kai Zhang (github: https://github.com/cszn) 9 | 03/03/2019 10 | ''' 11 | 12 | 13 | def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): 14 | ''' 15 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 16 | # Some testing modes 17 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 18 | # (0) normal: test(model, L) 19 | # (1) pad: test_pad(model, L, modulo=16) 20 | # (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) 21 | # (3) x8: test_x8(model, L, modulo=1) 22 | # (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) 23 | # (5) split only once: test_onesplit(model, L, refield=32, min_size=256, sf=1, modulo=1) 24 | # --------------------------------------- 25 | ''' 26 | if mode == 0: 27 | E = test(model, L) 28 | elif mode == 1: 29 | E = test_pad(model, L, modulo) 30 | elif mode == 2: 31 | E = test_split(model, L, refield, min_size, sf, modulo) 32 | elif mode == 3: 33 | E = test_x8(model, L, modulo) 34 | elif mode == 4: 35 | E = test_split_x8(model, L, refield, min_size, sf, modulo) 36 | elif mode == 5: 37 | E = test_onesplit(model, L, refield, min_size, sf, modulo) 38 | return E 39 | 40 | 41 | ''' 42 | # --------------------------------------- 43 | # normal (0) 44 | # --------------------------------------- 45 | ''' 46 | 47 | 48 | def test(model, L): 49 | E = model(L) 50 | return E 51 | 52 | 53 | ''' 54 | # --------------------------------------- 55 | # pad (1) 56 | # --------------------------------------- 57 | ''' 58 | 59 | 60 | def test_pad(model, L, modulo=16): 61 | h, w = L.size()[-2:] 62 | paddingBottom = int(np.ceil(h/modulo)*modulo-h) 63 | paddingRight = int(np.ceil(w/modulo)*modulo-w) 64 | L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) 65 | E = model(L) 66 | E = E[..., :h, :w] 67 | return E 68 | 69 | 70 | ''' 71 | # --------------------------------------- 72 | # split (function) 73 | # --------------------------------------- 74 | ''' 75 | 76 | 77 | def test_split_fn(model, L, 
refield=32, min_size=256, sf=1, modulo=1): 78 | ''' 79 | model: 80 | L: input Low-quality image 81 | refield: effective receptive field of the network, 32 is enough 82 | min_size: min_sizeXmin_size image, e.g., 256X256 image 83 | sf: scale factor for super-resolution, otherwise 1 84 | modulo: pad so that H and W are multiples of modulo; use 1 when splitting 85 | ''' 86 | h, w = L.size()[-2:] 87 | if h*w <= min_size**2: 88 | L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) 89 | E = model(L) 90 | E = E[..., :h*sf, :w*sf] 91 | else: 92 | top = slice(0, (h//2//refield+1)*refield) 93 | bottom = slice(h - (h//2//refield+1)*refield, h) 94 | left = slice(0, (w//2//refield+1)*refield) 95 | right = slice(w - (w//2//refield+1)*refield, w) 96 | Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] 97 | 98 | if h * w <= 4*(min_size**2): 99 | Es = [model(Ls[i]) for i in range(4)] 100 | else: 101 | Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] 102 | 103 | b, c = Es[0].size()[:2] 104 | E = torch.zeros(b, c, sf * h, sf * w).type_as(L) 105 | 106 | E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] 107 | E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] 108 | E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] 109 | E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] 110 | return E 111 | 112 | 113 | 114 | def test_onesplit(model, L, refield=32, min_size=256, sf=1, modulo=1): 115 | ''' 116 | model: 117 | L: input Low-quality image 118 | refield: effective receptive field of the network, 32 is enough 119 | min_size: min_sizeXmin_size image, e.g., 256X256 image 120 | sf: scale factor for super-resolution, otherwise 1 121 | modulo: pad so that H and W are multiples of modulo; use 1 when splitting 122 | ''' 123 | h, w = L.size()[-2:] 124 | 125 | top = slice(0, (h//2//refield+1)*refield) 126 | bottom = slice(h - (h//2//refield+1)*refield, h) 127 | left = slice(0, (w//2//refield+1)*refield) 128 | right = slice(w - (w//2//refield+1)*refield, w) 129 | Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] 130 | Es = [model(Ls[i]) for i in range(4)] 131 | b, c = Es[0].size()[:2] 132 | E = torch.zeros(b, c, sf * h, sf * w).type_as(L) 133 | E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] 134 | E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] 135 | E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] 136 | E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] 137 | return E 138 | 139 | 140 | 141 | ''' 142 | # --------------------------------------- 143 | # split (2) 144 | # --------------------------------------- 145 | ''' 146 | 147 | 148 | def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): 149 | E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) 150 | return E 151 | 152 | 153 | ''' 154 | # --------------------------------------- 155 | # x8 (3) 156 | # --------------------------------------- 157 | ''' 158 | 159 | 160 | def test_x8(model, L, modulo=1): 161 | E_list = [test_pad(model, util.augment_img_tensor(L, mode=i), modulo=modulo) for i in range(8)] 162 | for i in range(len(E_list)): 163 | if i == 3 or i == 5: 164 | E_list[i] = util.augment_img_tensor(E_list[i], mode=8 - i) 165 | else: 166 | E_list[i] = util.augment_img_tensor(E_list[i], mode=i) 167 | output_cat = torch.stack(E_list, dim=0) 168 | E = output_cat.mean(dim=0, 
keepdim=False) 169 | return E 170 | 171 | 172 | ''' 173 | # --------------------------------------- 174 | # split and x8 (4) 175 | # --------------------------------------- 176 | ''' 177 | 178 | 179 | def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): 180 | E_list = [test_split_fn(model, util.augment_img_tensor(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] 181 | for i in range(len(E_list)): 182 | if i==3 or i==5: 183 | E_list[i] = util.augment_img_tensor(E_list[i], mode=8-i) 184 | else: 185 | E_list[i] = util.augment_img_tensor(E_list[i], mode=i) 186 | output_cat = torch.stack(E_list, dim=0) 187 | E = output_cat.mean(dim=0, keepdim=False) 188 | return E 189 | 190 | 191 | ''' 192 | # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ 193 | # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_ 194 | # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ 195 | ''' 196 | 197 | 198 | ''' 199 | # --------------------------------------- 200 | # print 201 | # --------------------------------------- 202 | ''' 203 | 204 | 205 | # ------------------- 206 | # print model 207 | # ------------------- 208 | def print_model(model): 209 | msg = describe_model(model) 210 | print(msg) 211 | 212 | 213 | # ------------------- 214 | # print params 215 | # ------------------- 216 | def print_params(model): 217 | msg = describe_params(model) 218 | print(msg) 219 | 220 | 221 | ''' 222 | # --------------------------------------- 223 | # information 224 | # --------------------------------------- 225 | ''' 226 | 227 | 228 | # ------------------- 229 | # model information 230 | # ------------------- 231 | def info_model(model): 232 | msg = describe_model(model) 233 | return msg 234 | 235 | 236 | # ------------------- 237 | # params information 238 | # ------------------- 239 | def info_params(model): 240 | msg = describe_params(model) 241 | return msg 242 | 243 | 244 | ''' 245 | # --------------------------------------- 246 | # description 247 | # --------------------------------------- 248 | ''' 249 | 250 | 251 | # ---------------------------------------------- 252 | # model name and total number of parameters 253 | # ---------------------------------------------- 254 | def describe_model(model): 255 | if isinstance(model, torch.nn.DataParallel): 256 | model = model.module 257 | msg = '\n' 258 | msg += 'model name: {}'.format(model.__class__.__name__) + '\n' 259 | msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n' 260 | msg += 'Net structure:\n{}'.format(str(model)) + '\n' 261 | return msg 262 | 263 | 264 | # ---------------------------------------------- 265 | # parameters description 266 | # ---------------------------------------------- 267 | def describe_params(model): 268 | if isinstance(model, torch.nn.DataParallel): 269 | model = model.module 270 | msg = '\n' 271 | msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'param_name') + '\n' 272 | for name, param in model.state_dict().items(): 273 | if 'num_batches_tracked' not in name: 274 | v = param.data.clone().float() 275 | msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), name) + '\n' 276 | return msg 277 | 278 | 279 | if __name__ == '__main__': 280 | 281 | class Net(torch.nn.Module): 282 | def __init__(self, in_channels=3, out_channels=3): 283 | super(Net, self).__init__() 284 | self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
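# toy single-conv model: just enough network to exercise the test_mode wrappers in the loop below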
285 | 286 | def forward(self, x): 287 | x = self.conv(x) 288 | return x 289 | 290 | start = torch.cuda.Event(enable_timing=True) 291 | end = torch.cuda.Event(enable_timing=True) 292 | 293 | model = Net() 294 | model = model.eval() 295 | print_model(model) 296 | print_params(model) 297 | x = torch.randn((2,3,400,400)) 298 | torch.cuda.empty_cache() 299 | with torch.no_grad(): 300 | for mode in range(5): 301 | y = test_mode(model, x, mode) 302 | print(y.shape) 303 | -------------------------------------------------------------------------------- /utils/utils_mosaic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from utils import utils_image as util 4 | #import utils_image as util 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | ''' 11 | modified by Kai Zhang (github: https://github.com/cszn) 12 | 03/03/2019 13 | ''' 14 | def dm(imgs): 15 | """ bilinear demosaicking 16 | Args: 17 | imgs: Nx4xW/2xH/2 18 | 19 | Returns: 20 | output: Nx3xWxH 21 | """ 22 | k_r = 1/4 * torch.FloatTensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]]).type_as(imgs) 23 | k_g = 1/4 * torch.FloatTensor([[0, 1, 0], [1, 4, 1], [0, 1, 0]]).type_as(imgs) 24 | k = torch.stack((k_r,k_g,k_r), dim=0).unsqueeze(1) 25 | 26 | rgb = torch.zeros(imgs.size(0), 3, imgs.size(2)*2, imgs.size(3)*2).type_as(imgs) 27 | rgb[:, 0, 0::2, 0::2] = imgs[:, 0, :, :] 28 | rgb[:, 1, 0::2, 1::2] = imgs[:, 1, :, :] 29 | rgb[:, 1, 1::2, 0::2] = imgs[:, 2, :, :] 30 | rgb[:, 2, 1::2, 1::2] = imgs[:, 3, :, :] 31 | 32 | rgb = nn.functional.pad(rgb, (1, 1, 1, 1), mode='circular') 33 | rgb = nn.functional.conv2d(rgb, k, groups=3, padding=0, bias=None) 34 | 35 | return rgb 36 | 37 | 38 | def dm_matlab(imgs): 39 | """ matlab demosaicking 40 | Args: 41 | imgs: Nx4xW/2xH/2 42 | 43 | Returns: 44 | output: Nx3xWxH 45 | """ 46 | 47 | kgrb = 1/8*torch.FloatTensor([[0, 0, -1, 0, 0], 48 | [0, 0, 2, 0, 0], 49 | [-1, 2, 4, 2, -1], 50 | [0, 0, 2, 0, 0], 51 | [0, 0, -1, 0, 0]]).type_as(imgs) 52 | krbg0 = 1/8*torch.FloatTensor([[0, 0, 1/2, 0, 0], 53 | [0, -1, 0, -1, 0], 54 | [-1, 4, 5, 4, -1], 55 | [0, -1, 0, -1, 0], 56 | [0, 0, 1/2, 0, 0]]).type_as(imgs) 57 | krbg1 = krbg0.t() 58 | krbbr = 1/8*torch.FloatTensor([[0, 0, -3/2, 0, 0], 59 | [0, 2, 0, 2, 0], 60 | [-3/2, 0, 6, 0, -3/2], 61 | [0, 2, 0, 2, 0], 62 | [0, 0, -3/2, 0, 0]]).type_as(imgs) 63 | 64 | k = torch.stack((kgrb, krbg0, krbg1, krbbr), 0).unsqueeze(1) 65 | 66 | cfa = torch.zeros(imgs.size(0), 1, imgs.size(2)*2, imgs.size(3)*2).type_as(imgs) 67 | cfa[:, 0, 0::2, 0::2] = imgs[:, 0, :, :] 68 | cfa[:, 0, 0::2, 1::2] = imgs[:, 1, :, :] 69 | cfa[:, 0, 1::2, 0::2] = imgs[:, 2, :, :] 70 | cfa[:, 0, 1::2, 1::2] = imgs[:, 3, :, :] 71 | rgb = cfa.repeat(1, 3, 1, 1) 72 | 73 | cfa = nn.functional.pad(cfa, (2, 2, 2, 2), mode='reflect') 74 | conv_cfa = nn.functional.conv2d(cfa, k, padding=0, bias=None) 75 | 76 | # fill G 77 | rgb[:, 1, 0::2, 0::2] = conv_cfa[:, 0, 0::2, 0::2] 78 | rgb[:, 1, 1::2, 1::2] = conv_cfa[:, 0, 1::2, 1::2] 79 | 80 | # fill R 81 | rgb[:, 0, 0::2, 1::2] = conv_cfa[:, 1, 0::2, 1::2] 82 | rgb[:, 0, 1::2, 0::2] = conv_cfa[:, 2, 1::2, 0::2] 83 | rgb[:, 0, 1::2, 1::2] = conv_cfa[:, 3, 1::2, 1::2] 84 | 85 | # fill B 86 | rgb[:, 2, 0::2, 1::2] = conv_cfa[:, 2, 0::2, 1::2] 87 | rgb[:, 2, 1::2, 0::2] = conv_cfa[:, 1, 1::2, 0::2] 88 | rgb[:, 2, 0::2, 0::2] = conv_cfa[:, 3, 0::2, 0::2] 89 | 90 | return rgb 91 | 92 | 93 | def tstack(a): # cv2.merge() 94 | a = np.asarray(a) 95 | return np.concatenate([x[..., 
np.newaxis] for x in a], axis=-1) 96 | 97 | 98 | def tsplit(a): # cv2.split() 99 | a = np.asarray(a) 100 | return np.array([a[..., x] for x in range(a.shape[-1])]) 101 | 102 | 103 | def masks_CFA_Bayer(shape): 104 | pattern = 'RGGB' 105 | channels = dict((channel, np.zeros(shape)) for channel in 'RGB') 106 | for channel, (y, x) in zip(pattern, [(0, 0), (0, 1), (1, 0), (1, 1)]): 107 | channels[channel][y::2, x::2] = 1 108 | return tuple(channels[c].astype(bool) for c in 'RGB') 109 | 110 | 111 | def mosaic_CFA_Bayer(RGB): 112 | R_m, G_m, B_m = masks_CFA_Bayer(RGB.shape[0:2]) 113 | mask = np.concatenate((R_m[..., np.newaxis], G_m[..., np.newaxis], B_m[..., np.newaxis]), axis=-1) 114 | # mask = tstack((R_m, G_m, B_m)) 115 | mosaic = np.multiply(mask, RGB) # mask*RGB 116 | CFA = mosaic.sum(2).astype(np.uint8) 117 | 118 | CFA4 = np.zeros((RGB.shape[0]//2, RGB.shape[1]//2, 4), dtype=np.uint8) 119 | CFA4[:, :, 0] = CFA[0::2, 0::2] 120 | CFA4[:, :, 1] = CFA[0::2, 1::2] 121 | CFA4[:, :, 2] = CFA[1::2, 0::2] 122 | CFA4[:, :, 3] = CFA[1::2, 1::2] 123 | 124 | return CFA, CFA4, mosaic, mask 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | Im = util.imread_uint('test.bmp', 3) 130 | 131 | CFA, CFA4, mosaic, mask = mosaic_CFA_Bayer(Im) 132 | convertedImage = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA) 133 | 134 | util.imshow(CFA) 135 | util.imshow(mosaic) 136 | util.imshow(mask.astype(np.float32)) 137 | util.imshow(convertedImage) 138 | 139 | util.imsave(mask.astype(np.float32)*255,'bayer.png') 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /utils/utils_pnp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | ''' 6 | modified by Kai Zhang (github: https://github.com/cszn) 7 | 03/03/2019 8 | ''' 9 | 10 | 11 | # -------------------------------- 12 | # get rho and sigma 13 | # -------------------------------- 14 | #def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55): 15 | # ''' 16 | # One can change the sigma to implicitly change the trade-off parameter 17 | # between fidelity term and prior term 18 | # ''' 19 | # modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32) 20 | # sigmas = modelSigmaS/255. 21 | # rhos = list(map(lambda x: 0.23*(sigma**2)/(x**2), sigmas)) 22 | # return rhos, sigmas 23 | 24 | # -------------------------------- 25 | # get rho and sigma 26 | # -------------------------------- 27 | def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, w=1.0): 28 | ''' 29 | One can change the sigma to implicitly change the trade-off parameter 30 | between fidelity term and prior term 31 | ''' 32 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32) 33 | modelSigmaS_lin = np.linspace(modelSigma1, modelSigma2, iter_num).astype(np.float32) 34 | sigmas = (modelSigmaS*w+modelSigmaS_lin*(1-w))/255. 35 | rhos = list(map(lambda x: 0.23*(sigma**2)/(x**2), sigmas)) 36 | return rhos, sigmas 37 | 38 | 39 | def get_rho_sigma1(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, lamda=3.0): 40 | ''' 41 | One can change the sigma to implicitly change the trade-off parameter 42 | between fidelity term and prior term 43 | ''' 44 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32) 45 | sigmas = modelSigmaS/255. 
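# penalty schedule: rho_k = sigma^2 / (lamda * sigma_k^2); a larger lamda scales every rho_k down, i.e. leans more on the denoiser prior relative to the data-fidelity term.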
46 | rhos = list(map(lambda x: (sigma**2)/(x**2)/lamda, sigmas)) 47 | return rhos, sigmas 48 | 49 | 50 | if __name__ == '__main__': 51 | rhos, sigmas = get_rho_sigma(sigma=2.55/255, iter_num=30, modelSigma2=2.55) 52 | print(rhos) 53 | print(sigmas*255) 54 | -------------------------------------------------------------------------------- /utils/utils_sisr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.fft 3 | import torch 4 | 5 | 6 | def splits(a, sf): 7 | '''split a into sfxsf distinct blocks 8 | Args: 9 | a: NxCxWxH 10 | sf: split factor 11 | Returns: 12 | b: NxCx(W/sf)x(H/sf)x(sf^2) 13 | ''' 14 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=4) 15 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=4) 16 | return b 17 | 18 | 19 | def p2o(psf, shape): 20 | ''' 21 | Convert point-spread function to optical transfer function. 22 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the 23 | point-spread function (PSF) array and creates the optical transfer 24 | function (OTF) array that is not influenced by the PSF off-centering. 25 | Args: 26 | psf: NxCxhxw 27 | shape: [H, W] 28 | Returns: 29 | otf: NxCxHxWx2 30 | ''' 31 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 32 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 33 | for axis, axis_size in enumerate(psf.shape[2:]): 34 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) 35 | otf = torch.fft.fftn(otf, dim=(-2,-1)) 36 | #n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 37 | #otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 38 | return otf 39 | 40 | 41 | def upsample(x, sf=3): 42 | '''s-fold upsampler 43 | Upsampling the spatial size by filling the new entries with zeros 44 | x: tensor image, NxCxWxH 45 | ''' 46 | st = 0 47 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x) 48 | z[..., st::sf, st::sf].copy_(x) 49 | return z 50 | 51 | 52 | def downsample(x, sf=3): 53 | '''s-fold downsampler 54 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others 55 | x: tensor image, NxCxWxH 56 | ''' 57 | st = 0 58 | return x[..., st::sf, st::sf] 59 | 60 | 61 | 62 | def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf): 63 | FR = FBFy + torch.fft.fftn(alpha*x, dim=(-2,-1)) 64 | x1 = FB.mul(FR) 65 | FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False) 66 | invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False) 67 | invWBR = FBR.div(invW + alpha) 68 | FCBinvWBR = FBC*invWBR.repeat(1, 1, sf, sf) 69 | FX = (FR-FCBinvWBR)/alpha 70 | Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) 71 | 72 | return Xest 73 | 74 | 75 | def pre_calculate(x, k, sf): 76 | ''' 77 | Args: 78 | x: NxCxHxW, LR input 79 | k: NxCxhxw 80 | sf: integer 81 | 82 | Returns: 83 | FB, FBC, F2B, FBFy 84 | will be reused during iterations 85 | ''' 86 | w, h = x.shape[-2:] 87 | FB = p2o(k, (w*sf, h*sf)) 88 | FBC = torch.conj(FB) 89 | F2B = torch.pow(torch.abs(FB), 2) 90 | STy = upsample(x, sf=sf) 91 | FBFy = FBC*torch.fft.fftn(STy, dim=(-2, -1)) 92 | return FB, FBC, F2B, FBFy 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /utils/utils_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 
6 | import cv2 7 | import utils_image as util 8 | 9 | ''' 10 | # ======================================= 11 | # image processing on numpy images 12 | # augment(img_list, hflip=True, rot=True): 13 | # ======================================= 14 | ''' 15 | # ---------------------------------------- 16 | # get uint8 image of size HxWxn_channels (RGB) 17 | # ---------------------------------------- 18 | def imread_uint(path, n_channels=3): 19 | # input: path 20 | # output: HxWx3(RGB or GGG), or HxWx1 (G) 21 | if n_channels == 1: 22 | img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE 23 | img = np.expand_dims(img, axis=2) # HxWx1 24 | elif n_channels == 3: 25 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G 26 | if img.ndim == 2: 27 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG 28 | else: 29 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB 30 | return img 31 | 32 | 33 | def augment_img(img, mode=0): 34 | if mode == 0: 35 | return img 36 | elif mode == 1: 37 | return np.flipud(np.rot90(img)) 38 | elif mode == 2: 39 | return np.flipud(img) 40 | elif mode == 3: 41 | return np.rot90(img, k=3) 42 | elif mode == 4: 43 | return np.flipud(np.rot90(img, k=2)) 44 | elif mode == 5: 45 | return np.rot90(img) 46 | elif mode == 6: 47 | return np.rot90(img, k=2) 48 | elif mode == 7: 49 | return np.flipud(np.rot90(img, k=3)) 50 | 51 | 52 | def augment_img_tensor4(img, mode=0): # 4-D NxCxHxW tensor analogue of augment_img: rotate/flip on the H, W axes 53 | if mode == 0: 54 | return img 55 | elif mode == 1: 56 | return img.rot90(1, [2, 3]).flip([2]) 57 | elif mode == 2: 58 | return img.flip([2]) 59 | elif mode == 3: 60 | return img.rot90(3, [2, 3]) 61 | elif mode == 4: 62 | return img.rot90(2, [2, 3]).flip([2]) 63 | elif mode == 5: 64 | return img.rot90(1, [2, 3]) 65 | elif mode == 6: 66 | return img.rot90(2, [2, 3]) 67 | elif mode == 7: 68 | return img.rot90(3, [2, 3]).flip([2]) 69 | 70 | 71 | def augment_img_np3(img, mode=0): 72 | if mode == 0: 73 | return img 74 | elif mode == 1: 75 | return img.transpose(1, 0, 2) 76 | elif mode == 2: 77 | return img[::-1, :, :] 78 | elif mode == 3: 79 | img = img[::-1, :, :] 80 | img = img.transpose(1, 0, 2) 81 | return img 82 | elif mode == 4: 83 | return img[:, ::-1, :] 84 | elif mode == 5: 85 | img = img[:, ::-1, :] 86 | img = img.transpose(1, 0, 2) 87 | return img 88 | elif mode == 6: 89 | img = img[:, ::-1, :] 90 | img = img[::-1, :, :] 91 | return img 92 | elif mode == 7: 93 | img = img[:, ::-1, :] 94 | img = img[::-1, :, :] 95 | img = img.transpose(1, 0, 2) 96 | return img 97 | 98 | 99 | def augment_img_tensor(img, mode=0): 100 | img_size = img.size() 101 | img_np = img.data.cpu().numpy() 102 | if len(img_size) == 3: 103 | img_np = np.transpose(img_np, (1, 2, 0)) 104 | elif len(img_size) == 4: 105 | img_np = np.transpose(img_np, (2, 3, 1, 0)) 106 | img_np = augment_img(img_np, mode=mode) 107 | img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) 108 | if len(img_size) == 3: 109 | img_tensor = img_tensor.permute(2, 0, 1) 110 | elif len(img_size) == 4: 111 | img_tensor = img_tensor.permute(3, 2, 0, 1) 112 | 113 | return img_tensor.type_as(img) 114 | 115 | 116 | def augment_imgs(img_list, hflip=True, rot=True): 117 | # horizontal flip OR rotate 118 | hflip = hflip and random.random() < 0.5 119 | vflip = rot and random.random() < 0.5 120 | rot90 = rot and random.random() < 0.5 121 | 122 | def _augment(img): 123 | if hflip: 124 | img = img[:, ::-1, :] 125 | if vflip: 126 | img = img[::-1, :, :] 127 | if rot90: 128 | img = img.transpose(1, 0, 2) 129 | return img 130 | 131 | return [_augment(img) for img in img_list] 132 | 133 | 134 | 
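# ----------------------------------------
# inverting the 8 augmentation modes above: modes 3 and 5 (rot90 k=3 / k=1)
# undo each other, while every other mode undoes itself; this is why
# utils_model.test_x8 applies mode=8-i only for i in {3, 5}.
# a minimal self-check (sketch):
# img = np.arange(12).reshape(3, 4, 1)
# for i in range(8):
#     inv = 8 - i if i in (3, 5) else i
#     assert np.array_equal(augment_img(augment_img(img, mode=i), mode=inv), img)
# ----------------------------------------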
135 | def p2o(psf, shape): 136 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 137 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 138 | for axis, axis_size in enumerate(psf.shape[2:]): 139 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) 140 | otf = torch.rfft(otf, 2, onesided=False) 141 | n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 142 | otf[...,1][torch.abs(otf[...,1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 143 | return otf
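if __name__ == '__main__':
    # minimal sanity check for p2o (a sketch; assumes torch<1.8, where torch.rfft
    # still exists -- cf. utils_sisr_beforepytorchversion8.py in this repo):
    # a symmetric PSF circularly shifted to the origin has a real OTF, so the
    # imaginary channel returned by p2o should be numerically zero.
    psf = torch.ones(1, 1, 3, 3) / 9.
    otf = p2o(psf, (8, 8))
    print(otf.shape)                 # torch.Size([1, 1, 8, 8, 2])
    print(otf[..., 1].abs().max())   # ~0 for this symmetric kernel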