├── LICENSE
├── README.md
├── figs
│   ├── color_psnr.png
│   ├── deblur_1.png
│   ├── demosaic_1.png
│   ├── demosaic_2.png
│   ├── denoiser_arch.png
│   ├── grayscale_psnr.png
│   ├── sisr_1.png
│   ├── test_03_noisy_0750.png
│   └── test_03_usrnet_2355.png
├── kernels
│   ├── Levin09.mat
│   ├── kernels_12.mat
│   └── kernels_bicubicx234.mat
├── main_download_pretrained_models.py
├── main_dpir_deblocking_color.py
├── main_dpir_deblocking_grayscale.py
├── main_dpir_deblur.py
├── main_dpir_demosaick.py
├── main_dpir_denoising.py
├── main_dpir_sisr.py
├── main_dpir_sisr_real_applications.py
├── model_zoo
│   └── README.md
├── models
│   ├── basicblock.py
│   ├── network_dncnn.py
│   └── network_unet.py
├── testsets
│   ├── set12
│   │   ├── 01.png
│   │   ├── 02.png
│   │   ├── 03.png
│   │   ├── 04.png
│   │   ├── 05.png
│   │   ├── 06.png
│   │   ├── 07.png
│   │   ├── 08.png
│   │   ├── 09.png
│   │   ├── 10.png
│   │   ├── 11.png
│   │   └── 12.png
│   ├── set3c
│   │   ├── butterfly.png
│   │   ├── leaves.png
│   │   └── starfish.png
│   └── set5
│       ├── baby_GT.bmp
│       ├── bird_GT.bmp
│       ├── butterfly_GT.bmp
│       ├── head_GT.bmp
│       └── woman_GT.bmp
└── utils
    ├── test.bmp
    ├── utils_bnorm.py
    ├── utils_csmri.py
    ├── utils_deblur.py
    ├── utils_image.py
    ├── utils_inpaint.py
    ├── utils_logger.py
    ├── utils_model.py
    ├── utils_mosaic.py
    ├── utils_pnp.py
    ├── utils_sisr.py
    ├── utils_sisr_beforepytorchversion8.py
    └── utils_test.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Kai Zhang (cskaizhang@gmail.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Plug-and-Play Image Restoration
2 |
3 | 
4 |
5 | [**Kai Zhang**](https://cszn.github.io/), Yawei Li, Wangmeng Zuo, Lei Zhang, Luc Van Gool, Radu Timofte
6 |
7 | _[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland_
8 |
9 | [[paper arxiv](https://arxiv.org/pdf/2008.13751.pdf)] [[paper tpami](https://ieeexplore.ieee.org/abstract/document/9454311)]
10 |
11 |
12 |
13 |
14 | Denoising results on CBSD68 and Urban100 datasets
15 | ----------
16 | | Dataset | Noise Level | FFDNet-PSNR(RGB) | FFDNet-PSNR(Y) | **DRUNet-PSNR(RGB)** | **DRUNet-PSNR(Y)** |
17 | |:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
18 | | CBSD68 | 30 | 30.32 | 32.05 | 30.81 | 32.44 |
19 | | CBSD68 | 50 | 27.97 | 29.65 | 28.51 | 30.09 |
20 | | Urban100| 30 | 30.53 | 32.72 | 31.83 | 33.93 |
21 | | Urban100| 50 | 28.05 | 30.09 | 29.61 | 31.57 |
22 |
23 | ```PSNR(Y) means the PSNR is calculated on the Y channel of the YCbCr color space.```
24 |
25 |
26 | Abstract
27 | ----------
28 | Recent works on plug-and-play image restoration have shown that a denoiser can implicitly serve as the image prior for
29 | model-based methods to solve many inverse problems. Such a property brings considerable advantages to plug-and-play image
30 | restoration (e.g., integrating the flexibility of model-based methods with the effectiveness of learning-based methods) when the denoiser is
31 | discriminatively learned via a deep convolutional neural network (CNN) with large modeling capacity. However, while deeper and larger
32 | CNN models are rapidly gaining popularity, the performance of existing plug-and-play image restoration is hindered by the lack of a suitable
33 | denoiser prior. In order to push the limits of plug-and-play image restoration, we set up a benchmark deep denoiser prior by training a
34 | highly flexible and effective CNN denoiser. We then plug the deep denoiser prior as a modular part into a half quadratic splitting based
35 | iterative algorithm to solve various image restoration problems. Meanwhile, we provide a thorough analysis of parameter setting,
36 | intermediate results and empirical convergence to better understand the working mechanism. Experimental results on three
37 | representative image restoration tasks, including deblurring, super-resolution and demosaicing, demonstrate that the proposed
38 | plug-and-play image restoration with deep denoiser prior not only significantly outperforms other state-of-the-art model-based methods
39 | but also achieves competitive or even superior performance against state-of-the-art learning-based methods.
40 |
41 |
42 | The DRUNet Denoiser (state-of-the-art Gaussian denoising performance!)
43 | ----------
44 | * Network architecture
45 |
46 |
47 |
48 | * Grayscale image denoising
49 |
50 |
51 |
52 | * Color image denoising
53 |
54 |
55 |
56 | |<img src="figs/test_03_noisy_0750.png"/>|<img src="figs/test_03_usrnet_2355.png"/>|
57 | |:---:|:---:|
58 | |(a) Noisy image with noise level 200|(b) Result by the proposed DRUNet denoiser|
59 |
60 | **Even though it was trained on the noise level range [0, 50], DRUNet still performs well on an extremely large unseen noise level of 200.**
61 |
62 | Image Deblurring
63 | ----------
64 | * Visual comparison
65 |
66 |
67 |
68 |
69 | Single Image Super-Resolution
70 | ----------
71 | * Visual comparison
72 |
73 |
74 |
75 |
76 | Color Image Demosaicing
77 | ----------
78 | * PSNR
79 |
80 |
81 |
82 | * Visual comparison
83 |
84 |
85 |
86 |
87 |
88 | Citation
89 | ----------
90 | ```BibTex
91 | @article{zhang2021plug,
92 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
93 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
94 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
95 | volume={44},
96 | number={10},
97 | pages={6360--6376},
98 | year={2021}
99 | }
100 | @inproceedings{zhang2017learning,
101 | title={Learning Deep CNN Denoiser Prior for Image Restoration},
102 | author={Zhang, Kai and Zuo, Wangmeng and Gu, Shuhang and Zhang, Lei},
103 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
104 | pages={3929--3938},
105 | year={2017},
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
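A minimal sketch of the half quadratic splitting (HQS) loop that the README's abstract describes: a task-specific data subproblem alternating with a deep denoiser prior on a decaying noise-level schedule. The `data_step` and `denoiser` callables are placeholders, and the log-spaced schedule with `rho = lam * sigma_n**2 / sigma_k**2` is modeled on `utils/utils_pnp.py` (`get_rho_sigma`); `lam=0.23` is an assumed default, not a verified repository value.

```python
import numpy as np
import torch

def hqs_restore(y, data_step, denoiser, sigma_n, iter_num=8, sigma1=49/255., lam=0.23):
    """Plug-and-play HQS sketch.
    data_step(x, rho): solves argmin_z ||H z - y||^2 + rho ||z - x||^2 for the task's H
    denoiser(x, sigma): deep Gaussian denoiser (e.g., DRUNet)
    sigma_n: noise level of the observation y, in [0, 1]
    """
    sigma_n = max(sigma_n, 0.255/255.)  # same floor as in the demo scripts
    sigmas = np.logspace(np.log10(sigma1), np.log10(sigma_n), iter_num)  # 49/255 -> sigma_n
    x = y.clone()
    for sigma_k in sigmas:
        rho = lam * sigma_n**2 / sigma_k**2  # penalty grows as sigma_k decays
        x = data_step(x, rho)                # step 1: enforce data consistency
        x = denoiser(x, sigma_k)             # step 2: implicit image prior
    return torch.clamp(x, 0., 1.)
```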
/figs/color_psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/color_psnr.png
--------------------------------------------------------------------------------
/figs/deblur_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/deblur_1.png
--------------------------------------------------------------------------------
/figs/demosaic_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/demosaic_1.png
--------------------------------------------------------------------------------
/figs/demosaic_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/demosaic_2.png
--------------------------------------------------------------------------------
/figs/denoiser_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/denoiser_arch.png
--------------------------------------------------------------------------------
/figs/grayscale_psnr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/grayscale_psnr.png
--------------------------------------------------------------------------------
/figs/sisr_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/sisr_1.png
--------------------------------------------------------------------------------
/figs/test_03_noisy_0750.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/test_03_noisy_0750.png
--------------------------------------------------------------------------------
/figs/test_03_usrnet_2355.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/figs/test_03_usrnet_2355.png
--------------------------------------------------------------------------------
/kernels/Levin09.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/Levin09.mat
--------------------------------------------------------------------------------
/kernels/kernels_12.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/kernels_12.mat
--------------------------------------------------------------------------------
/kernels/kernels_bicubicx234.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/kernels/kernels_bicubicx234.mat
--------------------------------------------------------------------------------
/main_download_pretrained_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import requests
4 | import re
5 |
6 |
7 | """
8 | How to use:
9 |
10 | download models:
11 | python main_download_pretrained_models.py --models "DPIR IRCNN" --model_dir "model_zoo"
12 |
13 | """
14 |
15 |
16 | def download_pretrained_model(model_dir='model_zoo', model_name='dncnn3.pth'):
17 | if os.path.exists(os.path.join(model_dir, model_name)):
18 | print(f'already exists, skip downloading [{model_name}]')
19 | else:
20 | os.makedirs(model_dir, exist_ok=True)
21 | if 'SwinIR' in model_name:
22 | url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(model_name)
23 | else:
24 | url = 'https://github.com/cszn/KAIR/releases/download/v1.0/{}'.format(model_name)
25 | print(f'downloading [{model_dir}/{model_name}] ...')
26 | r = requests.get(url, allow_redirects=True)
27 | with open(os.path.join(model_dir, model_name), 'wb') as f: f.write(r.content)
28 | print('done!')
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--models',
34 | type=lambda s: re.split(' |, ', s),
35 | default = "dncnn3.pth",
36 | help='comma or space delimited list of method or model names, e.g., "DnCNN", "DnCNN BSRGAN.pth", "dncnn_15.pth dncnn_50.pth"')
37 | parser.add_argument('--model_dir', type=str, default='model_zoo', help='path of model_zoo')
38 | args = parser.parse_args()
39 |
40 | print(f'trying to download {args.models}')
41 |
42 | method_model_zoo = {'DnCNN': ['dncnn_15.pth', 'dncnn_25.pth', 'dncnn_50.pth', 'dncnn3.pth', 'dncnn_color_blind.pth', 'dncnn_gray_blind.pth'],
43 | 'SRMD': ['srmdnf_x2.pth', 'srmdnf_x3.pth', 'srmdnf_x4.pth', 'srmd_x2.pth', 'srmd_x3.pth', 'srmd_x4.pth'],
44 | 'DPSR': ['dpsr_x2.pth', 'dpsr_x3.pth', 'dpsr_x4.pth', 'dpsr_x4_gan.pth'],
45 | 'FFDNet': ['ffdnet_color.pth', 'ffdnet_gray.pth', 'ffdnet_color_clip.pth', 'ffdnet_gray_clip.pth'],
46 | 'USRNet': ['usrgan.pth', 'usrgan_tiny.pth', 'usrnet.pth', 'usrnet_tiny.pth'],
47 | 'DPIR': ['drunet_gray.pth', 'drunet_color.pth', 'drunet_deblocking_color.pth', 'drunet_deblocking_grayscale.pth'],
48 | 'BSRGAN': ['BSRGAN.pth', 'BSRNet.pth', 'BSRGANx2.pth'],
49 | 'IRCNN': ['ircnn_color.pth', 'ircnn_gray.pth'],
50 | 'SwinIR': ['001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x3.pth',
51 | '001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth',
52 | '001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x3.pth',
53 | '001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth',
54 | '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth',
55 | '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
56 | '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_PSNR.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth',
57 | '004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth',
58 | '005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth',
59 | '005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth',
60 | '006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth',
61 | '006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth'],
62 | 'others': ['msrresnet_x4_psnr.pth', 'msrresnet_x4_gan.pth', 'imdn_x4.pth', 'RRDB.pth', 'ESRGAN.pth',
63 | 'FSSR_DPED.pth', 'FSSR_JPEG.pth', 'RealSR_DPED.pth', 'RealSR_JPEG.pth']
64 | }
65 |
66 | method_zoo = list(method_model_zoo.keys())
67 | model_zoo = []
68 | for b in list(method_model_zoo.values()):
69 | model_zoo += b
70 |
71 | if 'all' in args.models:
72 | for method in method_zoo:
73 | for model_name in method_model_zoo[method]:
74 | download_pretrained_model(args.model_dir, model_name)
75 | else:
76 | for method_model in args.models:
77 | if method_model in method_zoo: # method, need for loop
78 | for model_name in method_model_zoo[method_model]:
79 | if 'SwinIR' in model_name:
80 | download_pretrained_model(os.path.join(args.model_dir, 'swinir'), model_name)
81 | else:
82 | download_pretrained_model(args.model_dir, model_name)
83 | elif method_model in model_zoo: # model, do not need for loop
84 | if 'SwinIR' in method_model:
85 | download_pretrained_model(os.path.join(args.model_dir, 'swinir'), method_model)
86 | else:
87 | download_pretrained_model(args.model_dir, method_model)
88 | else:
89 | print(f'Cannot find {method_model} in the pre-trained model zoo!')
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
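The script is also usable programmatically; a small usage sketch that fetches the two denoiser priors the DPIR demos need, reusing `download_pretrained_model` exactly as defined above (assumes the repository root is on the Python path):

```python
from main_download_pretrained_models import download_pretrained_model

# Fetch the DRUNet denoiser priors listed under the 'DPIR' entry above.
for name in ['drunet_gray.pth', 'drunet_color.pth']:
    download_pretrained_model(model_dir='model_zoo', model_name=name)
```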
/main_dpir_deblocking_color.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import logging
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | import torch
7 | from utils import utils_logger
8 | from utils import utils_image as util
9 | import cv2
10 |
11 |
12 |
13 |
14 | def main():
15 |
16 | # ----------------------------------------
17 | # Preparation
18 | # ----------------------------------------
19 | model_name = 'drunet_color'
20 | quality_factors = [10, 20, 30, 40]
21 | testset_name = 'LIVE1' # test set, 'LIVE1'
22 | need_degradation = True # default: True
23 |
24 | task_current = 'db' # 'db' for JPEG image deblocking
25 | n_channels = 3 # fixed
26 | model_pool = 'model_zoo' # fixed
27 | testsets = 'testsets' # fixed
28 | results = 'results' # fixed
29 | noise_level_img = 0 # fixed: 0, noise level for the input image
30 | result_name = testset_name + '_' + model_name + '_' + task_current
31 | border = 0 # shave border to calculate PSNR and SSIM
32 |
33 | # ----------------------------------------
34 | # L_path, E_path, H_path
35 | # ----------------------------------------
36 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
37 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
38 | util.mkdir(E_path)
39 |
40 | logger_name = result_name
41 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
42 | logger = logging.getLogger(logger_name)
43 |
44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45 |
46 | # ----------------------------------------
47 | # load model
48 | # ----------------------------------------
49 | model_path = os.path.join('model_zoo', 'drunet_deblocking_color.pth')
50 | from models.network_unet import UNetRes as net
51 | model = net(in_nc=4, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=False) # define network
52 |
53 | model.load_state_dict(torch.load(model_path), strict=True)
54 | model.eval()
55 | for k, v in model.named_parameters():
56 | v.requires_grad = False
57 |
58 | model = model.to(device)
59 | logger.info('Model path: {:s}'.format(model_path))
60 | number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
61 | logger.info('Params number: {}'.format(number_parameters))
62 | L_paths = util.get_image_paths(L_path)
63 |
64 | for quality_factor in quality_factors:
65 |
66 | test_results = OrderedDict()
67 | test_results['psnr'] = []
68 | test_results['ssim'] = []
69 |
70 | logger.info('model_name:{}, quality factor:{}'.format(model_name, quality_factor))
71 | logger.info(L_path)
72 |
73 | for idx, img in enumerate(L_paths):
74 |
75 | # ------------------------------------
76 | # (1) img_L
77 | # ------------------------------------
78 | img_name, ext = os.path.splitext(os.path.basename(img))
79 | logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
80 | img_L = util.imread_uint(img, n_channels=n_channels)
81 | img_H = img_L.copy()
82 |
83 | # ------------------------------------
84 | # Do the JPEG compression
85 | # ------------------------------------
86 | if need_degradation:
87 | img_L = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
88 | result, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
89 | img_L = cv2.imdecode(encimg, 1)
90 | img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB)
91 |
92 | img_L = util.uint2tensor4(img_L)
93 |
94 | noise_level = (100-quality_factor)/100.0
95 | noise_level = torch.FloatTensor([noise_level])
96 | noise_level_map = torch.ones((1,1, img_L.shape[2], img_L.shape[3])).mul_(noise_level).float()
97 | img_L = torch.cat((img_L, noise_level_map), 1)
98 |
99 | img_L = img_L.to(device)
100 |
101 | # ------------------------------------
102 | # (2) img_E
103 | # ------------------------------------
104 | img_E = model(img_L)
105 | img_E = util.tensor2uint(img_E)
106 |
107 | if need_degradation:
108 |
109 | img_H = img_H.squeeze()
110 | # --------------------------------
111 | # PSNR and SSIM
112 | # --------------------------------
113 | psnr = util.calculate_psnr(img_E, img_H, border=border)
114 | ssim = util.calculate_ssim(img_E, img_H, border=border)
115 | test_results['psnr'].append(psnr)
116 | test_results['ssim'].append(ssim)
117 | logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))
118 |
119 | # ------------------------------------
120 | # save results
121 | # ------------------------------------
122 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'.png'))
123 |
124 | if need_degradation:
125 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
126 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
127 | logger.info('Average PSNR/SSIM(RGB) - {} - qf{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, quality_factor, ave_psnr, ave_ssim))
128 |
129 |
130 | if __name__ == '__main__':
131 |
132 | main()
133 |
--------------------------------------------------------------------------------
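The quality-factor conditioning on lines 94-97 above is the core of this script: the JPEG quality factor is mapped to a pseudo noise level `(100 - QF) / 100` and appended as a constant fourth input channel. A self-contained sketch of that mapping (the helper name is ours, not the repository's):

```python
import torch

def add_qf_map(img_tensor, quality_factor):
    """Append the (100 - QF)/100 map used by the deblocking script above.
    img_tensor: 1 x 3 x H x W in [0, 1]; returns a 1 x 4 x H x W tensor."""
    noise_level = (100 - quality_factor) / 100.0  # QF 10 -> 0.9, QF 40 -> 0.6
    qf_map = torch.full((1, 1, img_tensor.shape[2], img_tensor.shape[3]), noise_level)
    return torch.cat((img_tensor, qf_map), dim=1)
```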
/main_dpir_deblocking_grayscale.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import logging
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | import torch
7 |
8 | from utils import utils_logger
9 | from utils import utils_image as util
10 | import cv2
11 |
12 | '''
13 | Spyder (Python 3.7)
14 | PyTorch 1.8.1
15 | Windows 10 or Linux
16 |
17 | If you have any questions, please feel free to contact me.
18 | Kai Zhang (e-mail: cskaizhang@gmail.com)
19 | (github: https://github.com/cszn/DPIR)
20 | (github: https://github.com/cszn/KAIR)
21 | by Kai Zhang (06/June/2021)
22 |
23 |
24 | How to run to get the results in Table 3:
25 | Step 1: download 'classic5' and 'LIVE1' testing dataset from https://github.com/cszn/DnCNN/tree/master/testsets
26 | Step 2: download the 'drunet_deblocking_grayscale.pth' and 'dncnn3.pth' models and put them into 'model_zoo'
27 | 'drunet_deblocking_grayscale.pth': https://drive.google.com/file/d/1ySemeOINvVfraFi_SZxZ93UuV4hMzk8g/view?usp=sharing
28 | 'dncnn3.pth': https://drive.google.com/file/d/1wwTFLFbS3AWowuNbe1XsEd_VCa2kof5I/view?usp=sharing
29 | '''
30 |
31 |
32 | def main():
33 |
34 | # ----------------------------------------
35 | # Preparation
36 | # ----------------------------------------
37 | model_name = 'drunet'
38 | quality_factors = [10, 20, 30, 40]
39 | testset_name = 'classic5' # test set, 'classic5' | 'LIVE1'
40 | need_degradation = True # default: True
41 |
42 | task_current = 'db' # 'db' for JPEG image deblocking
43 |
44 | model_pool = 'model_zoo' # fixed
45 | testsets = 'testsets' # fixed
46 | results = 'results' # fixed
47 | result_name = testset_name + '_' + model_name + '_' + task_current
48 | border = 0 # shave border to calculate PSNR and SSIM
49 |
50 | # ----------------------------------------
51 | # L_path, E_path, H_path
52 | # ----------------------------------------
53 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
54 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
55 | util.mkdir(E_path)
56 |
57 | logger_name = result_name
58 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
59 | logger = logging.getLogger(logger_name)
60 |
61 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
62 |
63 | # ----------------------------------------
64 | # load model
65 | # ----------------------------------------
66 | if model_name == 'dncnn3':
67 | model_path = os.path.join(model_pool, model_name+'.pth')
68 | from models.network_dncnn import DnCNN as net
69 | model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
70 |
71 | else:
72 | model_name = 'drunet'
73 | model_path = os.path.join('model_zoo', 'drunet_deblocking_grayscale.pth')
74 | from models.network_unet import UNetRes as net
75 | model = net(in_nc=2, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=False)
76 |
77 | model.load_state_dict(torch.load(model_path), strict=True)
78 | model.eval()
79 | for k, v in model.named_parameters():
80 | v.requires_grad = False
81 |
82 | model = model.to(device)
83 | logger.info('Model path: {:s}'.format(model_path))
84 | number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
85 | logger.info('Params number: {}'.format(number_parameters))
86 | L_paths = util.get_image_paths(L_path)
87 |
88 | for quality_factor in quality_factors:
89 |
90 | test_results = OrderedDict()
91 | test_results['psnr'] = []
92 | test_results['ssim'] = []
93 | test_results['psnr_y'] = []
94 | test_results['ssim_y'] = []
95 |
96 | logger.info('model_name:{}, quality factor:{}'.format(model_name, quality_factor))
97 |
98 | for idx, img in enumerate(L_paths):
99 |
100 | # ------------------------------------
101 | # (1) img_L
102 | # ------------------------------------
103 | img_name, ext = os.path.splitext(os.path.basename(img))
104 | logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
105 |
106 | img_L = cv2.imread(img, cv2.IMREAD_UNCHANGED) # BGR or G
107 | grayscale = img_L.ndim == 2
108 | if not grayscale:
109 | img_L = cv2.cvtColor(img_L, cv2.COLOR_BGR2RGB) # RGB
110 | img_L_ycbcr = util.rgb2ycbcr(img_L, only_y=False)
111 | img_L = img_L_ycbcr[..., 0] # we operate on Y channel for color images
112 |
113 | img_H = img_L.copy()
114 |
115 | # ------------------------------------
116 | # Do the JPEG compression
117 | # ------------------------------------
118 | if need_degradation:
119 | result, encimg = cv2.imencode('.jpg', img_L, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
120 | img_L = cv2.imdecode(encimg, 0)
121 |
122 | img_L = util.uint2tensor4(img_L[..., np.newaxis])
123 |
124 | if model_name == 'drunet':
125 | noise_level = (100-quality_factor)/100.0
126 | noise_level = torch.FloatTensor([noise_level])
127 | noise_level_map = torch.ones((1,1, img_L.shape[2], img_L.shape[3])).mul_(noise_level).float()
128 | img_L = torch.cat((img_L, noise_level_map), 1)
129 |
130 | img_L = img_L.to(device)
131 |
132 | # ------------------------------------
133 | # (2) img_E
134 | # ------------------------------------
135 | img_E = model(img_L)
136 | img_E = util.tensor2uint(img_E)
137 |
138 | if need_degradation:
139 |
140 | # --------------------------------
141 | # PSNR and SSIM
142 | # --------------------------------
143 |
144 | psnr = util.calculate_psnr(img_E, img_H, border=border)
145 | ssim = util.calculate_ssim(img_E, img_H, border=border)
146 | test_results['psnr'].append(psnr)
147 | test_results['ssim'].append(ssim)
148 | logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))
149 |
150 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'.png'))
151 | if not grayscale:
152 | img_L_ycbcr[..., 0] = img_E
153 | img_E_rgb = util.ycbcr2rgb(img_L_ycbcr)
154 | util.imsave(img_E_rgb, os.path.join(E_path, img_name+'_'+model_name+'_'+str(quality_factor)+'_rgb.png'))
155 |
156 | if need_degradation:
157 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
158 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
159 | logger.info('Average PSNR/SSIM - {} - qf{} --PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, quality_factor, ave_psnr, ave_ssim))
160 |
161 | if __name__ == '__main__':
162 |
163 | main()
164 |
--------------------------------------------------------------------------------
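The degradation on lines 119-120 above is an OpenCV JPEG round trip; a standalone sketch (helper name ours) of how the blocky input `img_L` is produced:

```python
import cv2
import numpy as np

def jpeg_degrade(img_uint8, quality_factor, grayscale=True):
    """Encode/decode round trip that simulates JPEG compression artifacts."""
    ok, enc = cv2.imencode('.jpg', img_uint8, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    assert ok, 'JPEG encoding failed'
    return cv2.imdecode(enc, 0 if grayscale else 1)  # 0: grayscale, 1: BGR

# Example: quality factor 10 on a random 8-bit image.
img = (np.random.rand(64, 64) * 255).astype(np.uint8)
img_jpeg = jpeg_degrade(img, 10)
```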
/main_dpir_deblur.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import cv2
3 | import logging
4 |
5 | import numpy as np
6 | from datetime import datetime
7 | from collections import OrderedDict
8 | import hdf5storage
9 | from scipy import ndimage
10 |
11 | import torch
12 |
13 | from utils import utils_deblur
14 | from utils import utils_logger
15 | from utils import utils_model
16 | from utils import utils_pnp as pnp
17 | from utils import utils_sisr as sr
18 | from utils import utils_image as util
19 |
20 |
21 | """
22 | Spyder (Python 3.7)
23 | PyTorch 1.6.0
24 | Windows 10 or Linux
25 | Kai Zhang (cskaizhang@gmail.com)
26 | github: https://github.com/cszn/DPIR
27 | https://github.com/cszn/IRCNN
28 | https://github.com/cszn/KAIR
29 | @article{zhang2020plug,
30 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
31 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
32 | journal={arXiv preprint},
33 | year={2020}
34 | }
35 | % If you have any questions, please feel free to contact me.
36 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
37 | by Kai Zhang (01/August/2020)
38 |
39 | # --------------------------------------------
40 | |--model_zoo # model_zoo
41 | |--drunet_gray # model_name, for grayscale images
42 | |--drunet_color
43 | |--testset # testsets
44 | |--results # results
45 | # --------------------------------------------
46 | """
47 |
48 | def main():
49 |
50 | # ----------------------------------------
51 | # Preparation
52 | # ----------------------------------------
53 |
54 | noise_level_img = 7.65/255.0 # default: 0, noise level for LR image
55 | noise_level_model = noise_level_img # noise level of model, default 0
56 | model_name = 'drunet_gray' # 'drunet_gray' | 'drunet_color' | 'ircnn_gray' | 'ircnn_color'
57 | testset_name = 'set3c' # test set, 'set3c' | 'set5' | 'srbsd68'
58 | x8 = True # default: False, x8 to boost performance
59 | iter_num = 8 # number of iterations
60 | modelSigma1 = 49
61 | modelSigma2 = noise_level_model*255.
62 |
63 | show_img = False # default: False
64 | save_L = True # save LR image
65 | save_E = True # save estimated image
66 | save_LEH = False # save zoomed LR, E and H images
67 | border = 0
68 |
69 | # --------------------------------
70 | # load kernel
71 | # --------------------------------
72 |
73 | kernels = hdf5storage.loadmat(os.path.join('kernels', 'Levin09.mat'))['kernels']
74 |
75 | sf = 1
76 | task_current = 'deblur' # 'deblur' for deblurring
77 | n_channels = 3 if 'color' in model_name else 1 # fixed
78 | model_zoo = 'model_zoo' # fixed
79 | testsets = 'testsets' # fixed
80 | results = 'results' # fixed
81 | result_name = testset_name + '_' + task_current + '_' + model_name
82 | model_path = os.path.join(model_zoo, model_name+'.pth')
83 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84 | torch.cuda.empty_cache()
85 |
86 | # ----------------------------------------
87 | # L_path, E_path, H_path
88 | # ----------------------------------------
89 |
90 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
91 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
92 | util.mkdir(E_path)
93 |
94 | logger_name = result_name
95 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
96 | logger = logging.getLogger(logger_name)
97 |
98 | # ----------------------------------------
99 | # load model
100 | # ----------------------------------------
101 |
102 | if 'drunet' in model_name:
103 | from models.network_unet import UNetRes as net
104 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
105 | model.load_state_dict(torch.load(model_path), strict=True)
106 | model.eval()
107 | for _, v in model.named_parameters():
108 | v.requires_grad = False
109 | model = model.to(device)
110 | elif 'ircnn' in model_name:
111 | from models.network_dncnn import IRCNN as net
112 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
113 | model25 = torch.load(model_path)
114 | former_idx = 0
115 |
116 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
117 | logger.info('Model path: {:s}'.format(model_path))
118 | logger.info(L_path)
119 | L_paths = util.get_image_paths(L_path)
120 |
121 | test_results_ave = OrderedDict()
122 | test_results_ave['psnr'] = [] # record average PSNR for each kernel
123 |
124 | for k_index in range(kernels.shape[1]):
125 |
126 | logger.info('-------k:{:>2d} ---------'.format(k_index))
127 | test_results = OrderedDict()
128 | test_results['psnr'] = []
129 | k = kernels[0, k_index].astype(np.float64)
130 | util.imshow(k) if show_img else None
131 |
132 | for idx, img in enumerate(L_paths):
133 |
134 | # --------------------------------
135 | # (1) get img_L
136 | # --------------------------------
137 |
138 | img_name, ext = os.path.splitext(os.path.basename(img))
139 | img_H = util.imread_uint(img, n_channels=n_channels)
140 | img_H = util.modcrop(img_H, 8) # modcrop
141 |
142 | img_L = ndimage.convolve(img_H, np.expand_dims(k, axis=2), mode='wrap')
143 | util.imshow(img_L) if show_img else None
144 | img_L = util.uint2single(img_L)
145 |
146 | np.random.seed(seed=0) # for reproducibility
147 | img_L += np.random.normal(0, noise_level_img, img_L.shape) # add AWGN
148 |
149 | # --------------------------------
150 | # (2) get rhos and sigmas
151 | # --------------------------------
152 |
153 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1.0)
154 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)
155 |
156 | # --------------------------------
157 | # (3) initialize x, and pre-calculation
158 | # --------------------------------
159 |
160 | x = util.single2tensor4(img_L).to(device)
161 |
162 | img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2))
163 | [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device)
164 | FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)
165 |
166 | # --------------------------------
167 | # (4) main iterations
168 | # --------------------------------
169 |
170 | for i in range(iter_num):
171 |
172 | # --------------------------------
173 | # step 1, FFT
174 | # --------------------------------
175 |
176 | tau = rhos[i].float().repeat(1, 1, 1, 1)
177 | x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)
178 |
179 | if 'ircnn' in model_name:
180 | current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)
181 |
182 | if current_idx != former_idx:
183 | model.load_state_dict(model25[str(current_idx)], strict=True)
184 | model.eval()
185 | for _, v in model.named_parameters():
186 | v.requires_grad = False
187 | model = model.to(device)
188 | former_idx = current_idx
189 |
190 | # --------------------------------
191 | # step 2, denoiser
192 | # --------------------------------
193 |
194 | if x8:
195 | x = util.augment_img_tensor4(x, i % 8)
196 |
197 | if 'drunet' in model_name:
198 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
199 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16)
200 | elif 'ircnn' in model_name:
201 | x = model(x)
202 |
203 | if x8:
204 | if i % 8 == 3 or i % 8 == 5:
205 | x = util.augment_img_tensor4(x, 8 - i % 8)
206 | else:
207 | x = util.augment_img_tensor4(x, i % 8)
208 |
209 | # --------------------------------
210 | # (3) img_E
211 | # --------------------------------
212 |
213 | img_E = util.tensor2uint(x)
214 | if n_channels == 1:
215 | img_H = img_H.squeeze()
216 |
217 | if save_E:
218 | util.imsave(img_E, os.path.join(E_path, img_name+'_k'+str(k_index)+'_'+model_name+'.png'))
219 |
220 | # --------------------------------
221 | # (4) img_LEH
222 | # --------------------------------
223 |
224 | if save_LEH:
225 | img_L = util.single2uint(img_L)
226 | k_v = k/np.max(k)*1.0
227 | k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3]))
228 | k_v = cv2.resize(k_v, (3*k_v.shape[1], 3*k_v.shape[0]), interpolation=cv2.INTER_NEAREST)
229 | img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST)
230 | img_I[:k_v.shape[0], -k_v.shape[1]:, :] = k_v
231 | img_I[:img_L.shape[0], :img_L.shape[1], :] = img_L
232 | util.imshow(np.concatenate([img_I, img_E, img_H], axis=1), title='LR / Recovered / Ground-truth') if show_img else None
233 | util.imsave(np.concatenate([img_I, img_E, img_H], axis=1), os.path.join(E_path, img_name+'_k'+str(k_index)+'_LEH.png'))
234 |
235 | if save_L:
236 | util.imsave(util.single2uint(img_L), os.path.join(E_path, img_name+'_k'+str(k_index)+'_LR.png'))
237 |
238 | psnr = util.calculate_psnr(img_E, img_H, border=border) # change with your own border
239 | test_results['psnr'].append(psnr)
240 | logger.info('{:->4d}--> {:>10s} --k:{:>2d} PSNR: {:.2f}dB'.format(idx+1, img_name+ext, k_index, psnr))
241 |
242 |
243 | # --------------------------------
244 | # Average PSNR
245 | # --------------------------------
246 |
247 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
248 | logger.info('------> Average PSNR of ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, k_index, noise_level_model, ave_psnr))
249 | test_results_ave['psnr'].append(ave_psnr)
250 |
251 | if __name__ == '__main__':
252 |
253 | main()
254 |
--------------------------------------------------------------------------------
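Conceptually, `sr.pre_calculate` and `sr.data_solution` implement the closed-form FFT solution of the HQS data subproblem; for deblurring (`sf = 1`) it reduces to the Wiener-like update below. This numpy sketch is derived from the subproblem under a periodic-boundary assumption; it is not the repository's exact implementation:

```python
import numpy as np

def deblur_data_step(x, y, k, rho):
    """Solve argmin_z ||k (*) z - y||^2 + rho ||z - x||^2 in the FFT domain,
    where (*) is circular convolution, i.e. step 1 of the iterations above."""
    Fk = np.fft.fft2(k, s=y.shape)                             # kernel transfer function
    num = np.conj(Fk) * np.fft.fft2(y) + rho * np.fft.fft2(x)  # FBC*Fy + rho*Fx
    den = np.abs(Fk)**2 + rho                                  # F2B + rho
    return np.real(np.fft.ifft2(num / den))
```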
/main_dpir_demosaick.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import cv2
3 | import logging
4 |
5 | import numpy as np
6 | from collections import OrderedDict
7 |
8 | import torch
9 |
10 | from utils import utils_model
11 | from utils import utils_mosaic
12 | from utils import utils_logger
13 | from utils import utils_pnp as pnp
14 | from utils import utils_image as util
15 |
16 |
17 | """
18 | Spyder (Python 3.7)
19 | PyTorch 1.6.0
20 | Windows 10 or Linux
21 | Kai Zhang (cskaizhang@gmail.com)
22 | github: https://github.com/cszn/DPIR
23 | https://github.com/cszn/IRCNN
24 | https://github.com/cszn/KAIR
25 | @article{zhang2020plug,
26 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
27 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
28 | journal={arXiv preprint},
29 | year={2020}
30 | }
31 | % If you have any questions, please feel free to contact me.
32 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
33 | by Kai Zhang (01/August/2020)
34 |
35 | # --------------------------------------------
36 | |--model_zoo # model_zoo
37 | |--drunet_gray # model_name, for grayscale images
38 | |--drunet_color
39 | |--testset # testsets
40 | |--results # results
41 | # --------------------------------------------
42 |
43 | How to run:
44 | step 1: download [drunet_color.pth, ircnn_color.pth] from https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D
45 | step 2: set your own testset 'testset_name' and parameter setting such as 'noise_level_model', 'iter_num'.
46 | step 3: 'python main_dpir_demosaick.py'
47 |
48 | """
49 |
50 | def main():
51 |
52 | # ----------------------------------------
53 | # Preparation
54 | # ----------------------------------------
55 |
56 | noise_level_img = 0/255.0 # set AWGN noise level for LR image, default: 0
57 | noise_level_model = noise_level_img # set noise level of model, default: 0
58 | model_name = 'ircnn_color' # set denoiser, 'drunet_color' | 'ircnn_color'
59 | testset_name = 'set18' # set test set, 'set18' | 'set24'
60 | x8 = True # set periodical geometric self-ensemble (PGSE) to boost performance, default: True
61 | iter_num = 40 # set number of iterations, default: 40 for demosaicing
62 | modelSigma1 = 49 # set sigma_1, default: 49
63 | modelSigma2 = max(0.6, noise_level_model*255.) # set sigma_2, default
64 | matlab_init = True
65 |
66 | show_img = False # default: False
67 | save_L = True # save LR image
68 | save_E = True # save estimated image
69 | save_LEH = False # save zoomed LR, E and H images
70 | border = 10 # default 10 for demosaicing
71 |
72 | task_current = 'dm' # 'dm' for demosaicing
73 | n_channels = 3 # fixed
74 | model_zoo = 'model_zoo' # fixed
75 | testsets = 'testsets' # fixed
76 | results = 'results' # fixed
77 | result_name = testset_name + '_' + task_current + '_' + model_name
78 | model_path = os.path.join(model_zoo, model_name+'.pth')
79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
80 | torch.cuda.empty_cache()
81 |
82 | # ----------------------------------------
83 | # L_path, E_path, H_path
84 | # ----------------------------------------
85 |
86 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
87 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
88 | util.mkdir(E_path)
89 |
90 | logger_name = result_name
91 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
92 | logger = logging.getLogger(logger_name)
93 |
94 | # ----------------------------------------
95 | # load model
96 | # ----------------------------------------
97 |
98 | if 'drunet' in model_name:
99 | from models.network_unet import UNetRes as net
100 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
101 | model.load_state_dict(torch.load(model_path), strict=True)
102 | model.eval()
103 | for _, v in model.named_parameters():
104 | v.requires_grad = False
105 | model = model.to(device)
106 | elif 'ircnn' in model_name:
107 | from models.network_dncnn import IRCNN as net
108 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
109 | model25 = torch.load(model_path)
110 | former_idx = 0
111 |
112 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
113 | logger.info('Model path: {:s}'.format(model_path))
114 | logger.info(L_path)
115 | L_paths = util.get_image_paths(L_path)
116 |
117 | test_results = OrderedDict()
118 | test_results['psnr'] = []
119 |
120 | for idx, img in enumerate(L_paths):
121 |
122 | # --------------------------------
123 | # (1) get img_H and img_L
124 | # --------------------------------
125 |
126 | idx += 1
127 | img_name, ext = os.path.splitext(os.path.basename(img))
128 | img_H = util.imread_uint(img, n_channels=n_channels)
129 | CFA, CFA4, mosaic, mask = utils_mosaic.mosaic_CFA_Bayer(img_H)
130 |
131 | # --------------------------------
132 | # (2) initialize x
133 | # --------------------------------
134 |
135 | if matlab_init: # matlab demosaicing for initialization
136 | CFA4 = util.uint2tensor4(CFA4).to(device)
137 | x = utils_mosaic.dm_matlab(CFA4)
138 | else:
139 | x = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA)
140 | x = util.uint2tensor4(x).to(device)
141 |
142 | img_L = util.tensor2uint(x)
143 | y = util.uint2tensor4(mosaic).to(device)
144 |
145 | util.imshow(img_L) if show_img else None
146 | mask = util.single2tensor4(mask.astype(np.float32)).to(device)
147 |
148 | # --------------------------------
149 | # (3) get rhos and sigmas
150 | # --------------------------------
151 |
152 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_img), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1.0)
153 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)
154 |
155 | # --------------------------------
156 | # (4) main iterations
157 | # --------------------------------
158 |
159 | for i in range(iter_num):
160 |
161 | # --------------------------------
162 | # step 1, closed-form solution
163 | # --------------------------------
164 |
165 | x = (y+rhos[i].float()*x).div(mask+rhos[i])
166 |
167 | # --------------------------------
168 | # step 2, denoiser
169 | # --------------------------------
170 |
171 | if 'ircnn' in model_name:
172 | current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)
173 | if current_idx != former_idx:
174 | model.load_state_dict(model25[str(current_idx)], strict=True)
175 | model.eval()
176 | for _, v in model.named_parameters():
177 | v.requires_grad = False
178 | model = model.to(device)
179 | former_idx = current_idx
180 |
181 | x = torch.clamp(x, 0, 1)
182 | if x8:
183 | x = util.augment_img_tensor4(x, i % 8)
184 |
185 | if 'drunet' in model_name:
186 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
187 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16)
188 | # x = model(x)
189 | elif 'ircnn' in model_name:
190 | x = model(x)
191 |
192 | if x8:
193 | if i % 8 == 3 or i % 8 == 5:
194 | x = util.augment_img_tensor4(x, 8 - i % 8)
195 | else:
196 | x = util.augment_img_tensor4(x, i % 8)
197 |
198 | x[mask.to(torch.bool)] = y[mask.to(torch.bool)]
199 |
200 | # --------------------------------
201 | # (4) img_E
202 | # --------------------------------
203 |
204 | img_E = util.tensor2uint(x)
205 | psnr = util.calculate_psnr(img_E, img_H, border=border)
206 | test_results['psnr'].append(psnr)
207 | logger.info('{:->4d}--> {:>10s} -- PSNR: {:.2f}dB'.format(idx, img_name+ext, psnr))
208 |
209 | if save_E:
210 | util.imsave(img_E, os.path.join(E_path, img_name+'_'+model_name+'.png'))
211 |
212 | if save_L:
213 | util.imsave(img_L, os.path.join(E_path, img_name+'_L.png'))
214 |
215 | if save_LEH:
216 | util.imsave(np.concatenate([img_L, img_E, img_H], axis=1), os.path.join(E_path, img_name+model_name+'_LEH.png'))
217 |
218 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
219 | logger.info('------> Average PSNR(RGB) of ({}) is : {:.2f} dB'.format(testset_name, ave_psnr))
220 |
221 |
222 | if __name__ == '__main__':
223 |
224 | main()
225 |
--------------------------------------------------------------------------------
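Step 1 of the loop above (line 165) has a simple closed form because the demosaicing operator is an elementwise binary mask: setting the gradient of `||m*z - y||^2 + rho*||z - x||^2` to zero and using `m*m = m` and `y = m*y` gives `z = (y + rho*x) / (m + rho)`. As a one-liner (helper name ours):

```python
import torch

def demosaick_data_step(x, y, mask, rho):
    """Elementwise HQS data step for a binary sampling mask (cf. line 165 above)."""
    return (y + rho * x) / (mask + rho)
```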
/main_dpir_denoising.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import logging
3 |
4 | import numpy as np
5 | from collections import OrderedDict
6 |
7 | import torch
8 |
9 | from utils import utils_logger
10 | from utils import utils_model
11 | from utils import utils_image as util
12 |
13 |
14 | """
15 | Spyder (Python 3.7)
16 | PyTorch 1.6.0
17 | Windows 10 or Linux
18 | Kai Zhang (cskaizhang@gmail.com)
19 | github: https://github.com/cszn/DPIR
20 | https://github.com/cszn/IRCNN
21 | https://github.com/cszn/KAIR
22 | @article{zhang2020plug,
23 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
24 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
25 | journal={arXiv preprint},
26 | year={2020}
27 | }
28 | % If you have any questions, please feel free to contact me.
29 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
30 | by Kai Zhang (01/August/2020)
31 |
32 | # --------------------------------------------
33 | |--model_zoo # model_zoo
34 | |--drunet_gray # model_name, for grayscale images
35 | |--drunet_color
36 | |--testset # testsets
37 | |--set12 # testset_name
38 | |--bsd68
39 | |--cbsd68
40 | |--results # results
41 | |--set12_dn_drunet_gray # result_name = testset_name + '_' + 'dn' + model_name
42 | |--set12_dn_drunet_color
43 | # --------------------------------------------
44 | """
45 |
46 |
47 | def main():
48 |
49 | # ----------------------------------------
50 | # Preparation
51 | # ----------------------------------------
52 |
53 | noise_level_img = 15 # set AWGN noise level for noisy image
54 | noise_level_model = noise_level_img # set noise level for model
55 | model_name = 'drunet_gray' # set denoiser model, 'drunet_gray' | 'drunet_color'
56 | testset_name = 'bsd68' # set test set, 'bsd68' | 'cbsd68' | 'set12'
57 | x8 = False # default: False, x8 to boost performance
58 | show_img = False # default: False
59 | border = 0 # shave border to calculate PSNR and SSIM
60 |
61 | if 'color' in model_name:
62 | n_channels = 3 # 3 for color image
63 | else:
64 | n_channels = 1 # 1 for grayscale image
65 |
66 | model_pool = 'model_zoo' # fixed
67 | testsets = 'testsets' # fixed
68 | results = 'results' # fixed
69 | task_current = 'dn' # 'dn' for denoising
70 | result_name = testset_name + '_' + task_current + '_' + model_name
71 |
72 | model_path = os.path.join(model_pool, model_name+'.pth')
73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
74 | torch.cuda.empty_cache()
75 |
76 | # ----------------------------------------
77 | # L_path, E_path, H_path
78 | # ----------------------------------------
79 |
80 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
81 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
82 | util.mkdir(E_path)
83 |
84 | logger_name = result_name
85 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
86 | logger = logging.getLogger(logger_name)
87 |
88 | # ----------------------------------------
89 | # load model
90 | # ----------------------------------------
91 |
92 | from models.network_unet import UNetRes as net
93 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
94 | model.load_state_dict(torch.load(model_path), strict=True)
95 | model.eval()
96 | for k, v in model.named_parameters():
97 | v.requires_grad = False
98 | model = model.to(device)
99 | logger.info('Model path: {:s}'.format(model_path))
100 | number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
101 | logger.info('Params number: {}'.format(number_parameters))
102 |
103 | test_results = OrderedDict()
104 | test_results['psnr'] = []
105 | test_results['ssim'] = []
106 |
107 | logger.info('model_name:{}, image sigma:{}, model sigma:{}'.format(model_name, noise_level_img, noise_level_model))
108 | logger.info(L_path)
109 | L_paths = util.get_image_paths(L_path)
110 |
111 | for idx, img in enumerate(L_paths):
112 |
113 | # ------------------------------------
114 | # (1) img_L
115 | # ------------------------------------
116 |
117 | img_name, ext = os.path.splitext(os.path.basename(img))
118 | # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
119 | img_H = util.imread_uint(img, n_channels=n_channels)
120 | img_L = util.uint2single(img_H)
121 |
122 | # Add noise without clipping
123 | np.random.seed(seed=0) # for reproducibility
124 | img_L += np.random.normal(0, noise_level_img/255., img_L.shape)
125 |
126 | util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(noise_level_img)) if show_img else None
127 |
128 | img_L = util.single2tensor4(img_L)
129 | img_L = torch.cat((img_L, torch.FloatTensor([noise_level_model/255.]).repeat(1, 1, img_L.shape[2], img_L.shape[3])), dim=1)
130 | img_L = img_L.to(device)
131 |
132 | # ------------------------------------
133 | # (2) img_E
134 | # ------------------------------------
135 |
136 | if not x8 and img_L.size(2)%8==0 and img_L.size(3)%8==0: # spatial size divisible by 8
137 | img_E = model(img_L)
138 | elif not x8 and (img_L.size(2)%8!=0 or img_L.size(3)%8!=0):
139 | img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)
140 | elif x8:
141 | img_E = utils_model.test_mode(model, img_L, mode=3)
142 |
143 | img_E = util.tensor2uint(img_E)
144 |
145 | # --------------------------------
146 | # PSNR and SSIM
147 | # --------------------------------
148 |
149 | if n_channels == 1:
150 | img_H = img_H.squeeze()
151 | psnr = util.calculate_psnr(img_E, img_H, border=border)
152 | ssim = util.calculate_ssim(img_E, img_H, border=border)
153 | test_results['psnr'].append(psnr)
154 | test_results['ssim'].append(ssim)
155 | logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))
156 |
157 | # ------------------------------------
158 | # save results
159 | # ------------------------------------
160 |
161 | util.imsave(img_E, os.path.join(E_path, img_name+ext))
162 |
163 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
164 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
165 | logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))
166 |
167 |
168 | if __name__ == '__main__':
169 |
170 | main()
171 |
--------------------------------------------------------------------------------
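DRUNet takes the noise level as an extra constant input channel (hence `in_nc=n_channels+1` on line 93), which is what the `torch.cat` on line 129 builds. A minimal sketch of that conditioning (helper name ours):

```python
import torch

def drunet_input(img, noise_level):
    """Concatenate a constant noise-level map to an N x C x H x W tensor,
    matching DRUNet's C+1 input channels (cf. line 129 above)."""
    sigma_map = torch.full((img.shape[0], 1, img.shape[2], img.shape[3]), noise_level / 255.)
    return torch.cat((img, sigma_map), dim=1)

# Example: condition a grayscale batch on sigma = 15.
x = torch.rand(1, 1, 128, 128)
x_in = drunet_input(x, 15)  # shape: 1 x 2 x 128 x 128
```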
/main_dpir_sisr.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import glob
3 | import cv2
4 | import logging
5 | import time
6 |
7 | import numpy as np
8 | from datetime import datetime
9 | from collections import OrderedDict
10 | import hdf5storage
11 |
12 | import torch
13 |
14 | from utils import utils_deblur
15 | from utils import utils_logger
16 | from utils import utils_model
17 | from utils import utils_pnp as pnp
18 | from utils import utils_sisr as sr
19 | from utils import utils_image as util
20 |
21 |
22 | """
23 | Spyder (Python 3.7)
24 | PyTorch 1.6.0
25 | Windows 10 or Linux
26 | Kai Zhang (cskaizhang@gmail.com)
27 | github: https://github.com/cszn/DPIR
28 | https://github.com/cszn/IRCNN
29 | https://github.com/cszn/KAIR
30 | @article{zhang2020plug,
31 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
32 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
33 | journal={arXiv preprint},
34 | year={2020}
35 | }
36 | % If you have any questions, please feel free to contact me.
37 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
38 | by Kai Zhang (01/August/2020)
39 |
40 | # --------------------------------------------
41 | |--model_zoo # model_zoo
42 | |--drunet_color # model_name, for color images
43 | |--drunet_gray
44 | |--testset # testsets
45 | |--results # results
46 | # --------------------------------------------
47 |
48 |
49 | How to run:
50 | step 1: download [drunet_gray.pth, drunet_color.pth, ircnn_gray.pth, ircnn_color.pth] from https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D
51 | step 2: set your own testset 'testset_name' and parameter setting such as 'noise_level_img', 'iter_num'.
52 | step 3: 'python main_dpir_sisr.py'
53 |
54 | """
55 |
56 | def main():
57 |
58 | # ----------------------------------------
59 | # Preparation
60 | # ----------------------------------------
61 |
62 | noise_level_img = 0/255.0 # set AWGN noise level for LR image, default: 0,
63 | noise_level_model = noise_level_img # set noise level of model, default 0
64 | model_name = 'drunet_color' # set denoiser, 'drunet_gray' | 'drunet_color' | 'ircnn_gray' | 'ircnn_color'
65 | testset_name = 'srbsd68' # set test set, 'set5' | 'srbsd68'
66 | x8 = True # default: False, x8 to boost performance
67 | test_sf = [2] # set scale factor, default: [2, 3, 4], [2], [3], [4]
68 | iter_num = 24 # set number of iterations, default: 24 for SISR
69 | modelSigma1 = 49 # set sigma_1, default: 49
70 | classical_degradation = True # set classical degradation or bicubic degradation
71 |
72 | show_img = False # default: False
73 | save_L = True # save LR image
74 | save_E = True # save estimated image
75 | save_LEH = False # save zoomed LR, E and H images
76 |
77 | task_current = 'sr' # 'sr' for super-resolution
78 | n_channels = 1 if 'gray' in model_name else 3 # fixed
79 | model_zoo = 'model_zoo' # fixed
80 | testsets = 'testsets' # fixed
81 | results = 'results' # fixed
82 | result_name = testset_name + '_' + task_current + '_' + model_name
83 | model_path = os.path.join(model_zoo, model_name+'.pth')
84 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
85 | torch.cuda.empty_cache()
86 |
87 | # ----------------------------------------
88 | # L_path, E_path, H_path
89 | # ----------------------------------------
90 |
91 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
92 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
93 | util.mkdir(E_path)
94 |
95 | logger_name = result_name
96 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
97 | logger = logging.getLogger(logger_name)
98 |
99 | # ----------------------------------------
100 | # load model
101 | # ----------------------------------------
102 |
103 | if 'drunet' in model_name:
104 | from models.network_unet import UNetRes as net
105 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
106 | model.load_state_dict(torch.load(model_path), strict=True)
107 | model.eval()
108 | for _, v in model.named_parameters():
109 | v.requires_grad = False
110 | model = model.to(device)
111 | elif 'ircnn' in model_name:
112 | from models.network_dncnn import IRCNN as net
113 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
114 | model25 = torch.load(model_path)
115 | former_idx = 0
116 |
117 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
118 | logger.info('Model path: {:s}'.format(model_path))
119 | logger.info(L_path)
120 | L_paths = util.get_image_paths(L_path)
121 |
122 | # --------------------------------
123 | # load kernel
124 | # --------------------------------
125 |
126 | # kernels = hdf5storage.loadmat(os.path.join('kernels', 'Levin09.mat'))['kernels']
127 | if classical_degradation:
128 | kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels_12.mat'))['kernels']
129 | else:
130 | kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernels_bicubicx234.mat'))['kernels']
131 |
132 | test_results_ave = OrderedDict()
133 | test_results_ave['psnr_sf_k'] = []
134 | test_results_ave['psnr_y_sf_k'] = []
135 |
136 | for sf in test_sf:
137 | border = sf
138 | modelSigma2 = max(sf, noise_level_model*255.)
139 | k_num = 8 if classical_degradation else 1
140 |
141 | for k_index in range(k_num):
142 | logger.info('--------- sf:{:>1d} --k:{:>2d} ---------'.format(sf, k_index))
143 | test_results = OrderedDict()
144 | test_results['psnr'] = []
145 | test_results['psnr_y'] = []
146 |
147 | if not classical_degradation: # for bicubic degradation
148 | k_index = sf-2
149 | k = kernels[0, k_index].astype(np.float64)
150 |
151 | util.surf(k) if show_img else None
152 |
153 | for idx, img in enumerate(L_paths):
154 |
155 | # --------------------------------
156 | # (1) get img_L
157 | # --------------------------------
158 |
159 | img_name, ext = os.path.splitext(os.path.basename(img))
160 | img_H = util.imread_uint(img, n_channels=n_channels)
161 | img_H = util.modcrop(img_H, sf) # modcrop
162 |
163 | if classical_degradation:
164 | img_L = sr.classical_degradation(img_H, k, sf)
165 | util.imshow(img_L) if show_img else None
166 | img_L = util.uint2single(img_L)
167 | else:
168 | img_L = util.imresize_np(util.uint2single(img_H), 1/sf)
169 |
170 | np.random.seed(seed=0) # for reproducibility
171 | img_L += np.random.normal(0, noise_level_img, img_L.shape) # add AWGN
172 |
173 | # --------------------------------
174 | # (2) get rhos and sigmas
175 | # --------------------------------
176 |
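# Note: utils_pnp.get_rho_sigma builds the HQS schedule of the DPIR paper:
# `sigmas` decays log-uniformly from modelSigma1/255 down to modelSigma2/255
# over iter_num steps, and `rhos` holds the matching penalty weights
# (roughly proportional to (noise_level_model/sigmas)**2).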
177 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
178 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)
179 |
180 | # --------------------------------
181 | # (3) initialize x and pre-calculate
182 | # --------------------------------
183 |
184 | x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC)
185 | if np.ndim(x)==2:
186 | x = x[..., None]
187 |
188 | if classical_degradation:
189 | x = sr.shift_pixel(x, sf)
190 | x = util.single2tensor4(x).to(device)
191 |
192 | img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2))
193 | [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device)
194 | FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)
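# FB, FBC, F2B and FBFy cache, respectively, the FFT of the blur kernel,
# its complex conjugate, its squared magnitude, and FBC * FFT(upsampled y);
# with these precomputed, each data subproblem below has a closed-form
# FFT-domain solution (sr.data_solution).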
195 |
196 | # --------------------------------
197 | # (4) main iterations
198 | # --------------------------------
199 |
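# Each iteration alternates a data-consistency step solved in the FFT
# domain (step 1) with a denoising prior step (step 2), i.e., the
# half-quadratic splitting (HQS) scheme of the DPIR paper.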
200 | for i in range(iter_num):
201 |
202 | # --------------------------------
203 | # step 1, FFT
204 | # --------------------------------
205 |
206 | tau = rhos[i].float().repeat(1, 1, 1, 1)
207 | x = sr.data_solution(x.float(), FB, FBC, F2B, FBFy, tau, sf)
208 |
209 | if 'ircnn' in model_name:
210 | current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)
211 |
212 | if current_idx != former_idx:
213 | model.load_state_dict(model25[str(current_idx)], strict=True)
214 | model.eval()
215 | for _, v in model.named_parameters():
216 | v.requires_grad = False
217 | model = model.to(device)
218 | former_idx = current_idx
219 |
220 | # --------------------------------
221 | # step 2, denoiser
222 | # --------------------------------
223 |
224 | if x8:
225 | x = util.augment_img_tensor4(x, i % 8)
226 |
227 | if 'drunet' in model_name:
228 | x = torch.cat((x, sigmas[i].float().repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
229 | x = utils_model.test_mode(model, x, mode=2, refield=32, min_size=256, modulo=16)
230 | elif 'ircnn' in model_name:
231 | x = model(x)
232 |
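# Undo the geometric transform applied before denoising: the transforms
# indexed 3 and 5 are not self-inverse, so the complementary index
# 8 - i % 8 is used; all other transforms are involutions.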
233 | if x8:
234 | if i % 8 == 3 or i % 8 == 5:
235 | x = util.augment_img_tensor4(x, 8 - i % 8)
236 | else:
237 | x = util.augment_img_tensor4(x, i % 8)
238 |
239 | # --------------------------------
240 | # (5) img_E
241 | # --------------------------------
242 |
243 | img_E = util.tensor2uint(x)
244 |
245 | if save_E:
246 | util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_'+model_name+'.png'))
247 |
248 | if n_channels == 1:
249 | img_H = img_H.squeeze()
250 |
251 | # --------------------------------
252 | # (6) img_LEH
253 | # --------------------------------
254 |
255 | img_L = util.single2uint(img_L).squeeze()
256 |
257 | if save_LEH:
258 | k_v = k/np.max(k)*1.0
259 | if n_channels==1:
260 | k_v = util.single2uint(k_v)
261 | else:
262 | k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, n_channels]))
263 | k_v = cv2.resize(k_v, (3*k_v.shape[1], 3*k_v.shape[0]), interpolation=cv2.INTER_NEAREST)
264 | img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST)
265 | img_I[:k_v.shape[0], -k_v.shape[1]:, ...] = k_v
266 | img_I[:img_L.shape[0], :img_L.shape[1], ...] = img_L
267 | util.imshow(np.concatenate([img_I, img_E, img_H], axis=1), title='LR / Recovered / Ground-truth') if show_img else None
268 | util.imsave(np.concatenate([img_I, img_E, img_H], axis=1), os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LEH.png'))
269 |
270 | if save_L:
271 | util.imsave(img_L, os.path.join(E_path, img_name+'_x'+str(sf)+'_k'+str(k_index)+'_LR.png'))
272 |
273 | psnr = util.calculate_psnr(img_E, img_H, border=border)
274 | test_results['psnr'].append(psnr)
275 | logger.info('{:->4d}--> {:>10s} -- sf:{:>1d} --k:{:>2d} PSNR: {:.2f}dB'.format(idx+1, img_name+ext, sf, k_index, psnr))
276 |
277 | if n_channels == 3:
278 | img_E_y = util.rgb2ycbcr(img_E, only_y=True)
279 | img_H_y = util.rgb2ycbcr(img_H, only_y=True)
280 | psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
281 | test_results['psnr_y'].append(psnr_y)
282 |
283 | # --------------------------------
284 | # Average PSNR for all kernels
285 | # --------------------------------
286 |
287 | ave_psnr_k = sum(test_results['psnr']) / len(test_results['psnr'])
288 | logger.info('------> Average PSNR(RGB) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_k))
289 | test_results_ave['psnr_sf_k'].append(ave_psnr_k)
290 |
291 | if n_channels == 3: # RGB image
292 | ave_psnr_y_k = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
293 | logger.info('------> Average PSNR(Y) of ({}) scale factor: ({}), kernel: ({}) sigma: ({:.2f}): {:.2f} dB'.format(testset_name, sf, k_index, noise_level_model, ave_psnr_y_k))
294 | test_results_ave['psnr_y_sf_k'].append(ave_psnr_y_k)
295 |
296 | # ---------------------------------------
297 | # Average PSNR for all sf and kernels
298 | # ---------------------------------------
299 |
300 | ave_psnr_sf_k = sum(test_results_ave['psnr_sf_k']) / len(test_results_ave['psnr_sf_k'])
301 | logger.info('------> Average PSNR of ({}) {:.2f} dB'.format(testset_name, ave_psnr_sf_k))
302 | if n_channels == 3:
303 | ave_psnr_y_sf_k = sum(test_results_ave['psnr_y_sf_k']) / len(test_results_ave['psnr_y_sf_k'])
304 | logger.info('------> Average PSNR(Y) of ({}) {:.2f} dB'.format(testset_name, ave_psnr_y_sf_k))
305 |
306 | if __name__ == '__main__':
307 |
308 | main()
309 |
--------------------------------------------------------------------------------
/main_dpir_sisr_real_applications.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import glob
3 | import cv2
4 | import logging
5 | import time
6 |
7 | import numpy as np
8 | from datetime import datetime
9 | from collections import OrderedDict
10 | import hdf5storage
11 |
12 | import torch
13 |
14 | from utils import utils_deblur
15 | from utils import utils_logger
16 | from utils import utils_model
17 | from utils import utils_pnp as pnp
18 | from utils import utils_sisr as sr
19 | from utils import utils_image as util
20 |
21 |
22 | """
23 | Spyder (Python 3.7)
24 | PyTorch 1.6.0
25 | Windows 10 or Linux
26 | Kai Zhang (cskaizhang@gmail.com)
27 | github: https://github.com/cszn/DPIR
28 | https://github.com/cszn/IRCNN
29 | https://github.com/cszn/KAIR
30 | @article{zhang2020plug,
31 | title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
32 | author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
33 | journal={arXiv preprint},
34 | year={2020}
35 | }
36 | % If you have any questions, please feel free to contact me.
37 | % Kai Zhang (e-mail: cskaizhang@gmail.com; homepage: https://cszn.github.io/)
38 | by Kai Zhang (01/August/2020)
39 |
40 | # --------------------------------------------
41 | |--model_zoo # model_zoo
42 | |--drunet_gray # model_name, for grayscale images
43 | |--drunet_color # for color images
44 | |--testsets # testsets
45 | |--results # results
46 | # --------------------------------------------
47 | """
48 |
49 | def main():
50 |
51 | """
52 | # ----------------------------------------------------------------------------------
53 | # In real applications, you should set proper
54 | # - "noise_level_img": from [3, 25], set 3 for clean image, try 15 for very noisy LR images
55 | # - "k" (or "kernel_width"): blur kernel is very important!!! kernel_width from [0.6, 3.0]
56 | # to get the best performance.
57 | # ----------------------------------------------------------------------------------
58 | """
59 | ##############################################################################
60 |
61 | testset_name = 'set3c' # set test set, 'set3c' | 'set5' | 'srbsd68'
62 | noise_level_img = 3 # set noise level of image, from [3, 25], set 3 for clean image
63 | model_name = 'drunet_color' # set denoiser, 'drunet_color' | 'ircnn_gray' | 'drunet_gray' | 'ircnn_color'
64 | sf = 2 # set scale factor, 1, 2, 3, 4
65 | iter_num = 24 # set number of iterations, default: 24 for SISR
66 |
67 | # --------------------------------
68 | # set blur kernel
69 | # --------------------------------
70 | kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2] # Gaussian kernel widths for x1, x2, x3, x4
71 | noise_level_model = noise_level_img/255. # noise level of model
72 | kernel_width = kernel_width_default_x1234[sf-1]
73 |
74 | """
75 | # set your own kernel width here if the default does not fit your image!
76 | """
77 | # kernel_width = 1.0
78 |
79 |
80 | k = utils_deblur.fspecial('gaussian', 25, kernel_width)
81 | k = sr.shift_pixel(k, sf) # shift the kernel
82 | k /= np.sum(k)
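# An assumed manual workflow (not in the original script): when the blur of a
# real image is unknown, sweep a few kernel widths and compare the outputs, e.g.
# for kernel_width in np.arange(0.6, 3.1, 0.4):
#     k = utils_deblur.fspecial('gaussian', 25, kernel_width)
#     k = sr.shift_pixel(k, sf)
#     k /= np.sum(k)
#     # ... run the restoration below and inspect the saved results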
83 |
84 | ##############################################################################
85 |
86 |
87 | show_img = False
88 | util.surf(k) if show_img else None
89 | x8 = True # set True to boost performance with per-iteration geometric transforms, default: False
90 | modelSigma1 = 49 # set sigma_1, default: 49
91 | modelSigma2 = max(sf, noise_level_model*255.)
92 | classical_degradation = True # set classical degradation or bicubic degradation
93 |
94 | task_current = 'sr' # 'sr' for super-resolution
95 | n_channels = 1 if 'gray' in model_name else 3 # fixed
96 | model_zoo = 'model_zoo' # fixed
97 | testsets = 'testsets' # fixed
98 | results = 'results' # fixed
99 | result_name = testset_name + '_realapplications_' + task_current + '_' + model_name
100 | model_path = os.path.join(model_zoo, model_name+'.pth')
101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102 | torch.cuda.empty_cache()
103 |
104 | # ----------------------------------------
105 | # L_path, E_path, H_path
106 | # ----------------------------------------
107 | L_path = os.path.join(testsets, testset_name) # L_path, for Low-quality images
108 | E_path = os.path.join(results, result_name) # E_path, for Estimated images
109 | util.mkdir(E_path)
110 |
111 | logger_name = result_name
112 | utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
113 | logger = logging.getLogger(logger_name)
114 |
115 | # ----------------------------------------
116 | # load model
117 | # ----------------------------------------
118 | if 'drunet' in model_name:
119 | from models.network_unet import UNetRes as net
120 | model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
121 | model.load_state_dict(torch.load(model_path), strict=True)
122 | model.eval()
123 | for _, v in model.named_parameters():
124 | v.requires_grad = False
125 | model = model.to(device)
126 | elif 'ircnn' in model_name:
127 | from models.network_dncnn import IRCNN as net
128 | model = net(in_nc=n_channels, out_nc=n_channels, nc=64)
129 | model25 = torch.load(model_path)
130 | former_idx = 0
131 |
132 | logger.info('model_name:{}, image sigma:{:.3f}, model sigma:{:.3f}'.format(model_name, noise_level_img, noise_level_model))
133 | logger.info('Model path: {:s}'.format(model_path))
134 | logger.info(L_path)
135 | L_paths = util.get_image_paths(L_path)
136 |
137 | for idx, img in enumerate(L_paths):
138 |
139 | # --------------------------------
140 | # (1) get img_L
141 | # --------------------------------
142 | logger.info('Model path: {:s} Image: {:s}'.format(model_path, img))
143 | img_name, ext = os.path.splitext(os.path.basename(img))
144 | img_L = util.imread_uint(img, n_channels=n_channels)
145 | img_L = util.uint2single(img_L)
146 | img_L = util.modcrop(img_L, 8) # modcrop
147 |
148 | # --------------------------------
149 | # (2) get rhos and sigmas
150 | # --------------------------------
151 | rhos, sigmas = pnp.get_rho_sigma(sigma=max(0.255/255., noise_level_model), iter_num=iter_num, modelSigma1=modelSigma1, modelSigma2=modelSigma2, w=1)
152 | rhos, sigmas = torch.tensor(rhos).to(device), torch.tensor(sigmas).to(device)
153 |
154 | # --------------------------------
155 | # (3) initialize x and pre-calculate
156 | # --------------------------------
157 | x = cv2.resize(img_L, (img_L.shape[1]*sf, img_L.shape[0]*sf), interpolation=cv2.INTER_CUBIC)
158 |
159 | if np.ndim(x)==2:
160 | x = x[..., None]
161 |
162 | if classical_degradation:
163 | x = sr.shift_pixel(x, sf)
164 | x = util.single2tensor4(x).to(device)
165 |
166 | img_L_tensor, k_tensor = util.single2tensor4(img_L), util.single2tensor4(np.expand_dims(k, 2))
167 | [k_tensor, img_L_tensor] = util.todevice([k_tensor, img_L_tensor], device)
168 | FB, FBC, F2B, FBFy = sr.pre_calculate(img_L_tensor, k_tensor, sf)
169 |
170 | # --------------------------------
171 | # (4) main iterations
172 | # --------------------------------
173 | for i in range(iter_num):
174 |
175 | print('Iter: {} / {}'.format(i+1, iter_num))
176 |
177 | # --------------------------------
178 | # step 1, FFT
179 | # --------------------------------
180 | tau = rhos[i].float().repeat(1, 1, 1, 1)
181 | x = sr.data_solution(x, FB, FBC, F2B, FBFy, tau, sf)
182 |
183 | if 'ircnn' in model_name:
184 | current_idx = int(np.ceil(sigmas[i].cpu().numpy()*255./2.)-1)
185 |
186 | if current_idx != former_idx:
187 | model.load_state_dict(model25[str(current_idx)], strict=True)
188 | model.eval()
189 | for _, v in model.named_parameters():
190 | v.requires_grad = False
191 | model = model.to(device)
192 | former_idx = current_idx
193 |
194 | # --------------------------------
195 | # step 2, denoiser
196 | # --------------------------------
197 | if x8:
198 | x = util.augment_img_tensor4(x, i % 8)
199 |
200 | if 'drunet' in model_name:
201 | x = torch.cat((x, sigmas[i].repeat(1, 1, x.shape[2], x.shape[3])), dim=1)
202 | x = utils_model.test_mode(model, x, mode=2, refield=64, min_size=256, modulo=16)
203 | elif 'ircnn' in model_name:
204 | x = model(x)
205 |
206 | if x8:
207 | if i % 8 == 3 or i % 8 == 5:
208 | x = util.augment_img_tensor4(x, 8 - i % 8)
209 | else:
210 | x = util.augment_img_tensor4(x, i % 8)
211 |
212 | # --------------------------------
213 | # (5) img_E
214 | # --------------------------------
215 | img_E = util.tensor2uint(x)
216 | util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'.png'))
217 |
218 | if __name__ == '__main__':
219 |
220 | main()
221 |
--------------------------------------------------------------------------------
/model_zoo/README.md:
--------------------------------------------------------------------------------
1 |
2 | * Google drive download link: [https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing)
3 |
4 | * Tencent Weiyun download link: [https://share.weiyun.com/5qO32s3](https://share.weiyun.com/5qO32s3)
5 |
6 |
7 | -----------------
8 |
9 |
10 | |Model|Download link|Download link|
11 | |---|:--:|:--:|
12 | |drunet_gray.pth| [Google drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
13 | |drunet_color.pth| [Google drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
14 | |ircnn_gray.pth| [Google drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
15 | |ircnn_color.pth| [Google drive download link](https://drive.google.com/drive/folders/13kfr3qny7S2xwG9h7v95F5mkWs0OmU0D?usp=sharing) | [Tencent Weiyun download link](https://share.weiyun.com/5qO32s3) |
16 |
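* `main_download_pretrained_models.py` in the repository root is intended to fetch these models automatically; see that script for the exact file list. Manually downloaded `.pth` files should be placed in this `model_zoo` folder, since the test scripts load them via `os.path.join('model_zoo', model_name+'.pth')`.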
--------------------------------------------------------------------------------
/models/basicblock.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | '''
8 | # --------------------------------------------
9 | # Advanced nn.Sequential
10 | # https://github.com/xinntao/BasicSR
11 | # --------------------------------------------
12 | '''
13 |
14 |
15 | def sequential(*args):
16 | """Advanced nn.Sequential.
17 |
18 | Args:
19 | nn.Sequential, nn.Module
20 |
21 | Returns:
22 | nn.Sequential
23 | """
24 | if len(args) == 1:
25 | if isinstance(args[0], OrderedDict):
26 | raise NotImplementedError('sequential does not support OrderedDict input.')
27 | return args[0] # No sequential is needed.
28 | modules = []
29 | for module in args:
30 | if isinstance(module, nn.Sequential):
31 | for submodule in module.children():
32 | modules.append(submodule)
33 | elif isinstance(module, nn.Module):
34 | modules.append(module)
35 | return nn.Sequential(*modules)
36 |
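# Usage sketch (assumed example): nested nn.Sequential containers are
# flattened into a single flat one:
#   m = sequential(nn.Conv2d(1, 64, 3, padding=1),
#                  nn.Sequential(nn.ReLU(), nn.Conv2d(64, 1, 3, padding=1)))
#   len(m)  # -> 3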
37 |
38 | '''
39 | # --------------------------------------------
40 | # Useful blocks
41 | # https://github.com/xinntao/BasicSR
42 | # --------------------------------
43 | # conv + normalization + relu (conv)
44 | # (PixelUnShuffle)
45 | # (ConditionalBatchNorm2d)
46 | # concat (ConcatBlock)
47 | # sum (ShortcutBlock)
48 | # resblock (ResBlock)
49 | # Channel Attention (CA) Layer (CALayer)
50 | # Residual Channel Attention Block (RCABlock)
51 | # Residual Channel Attention Group (RCAGroup)
52 | # Residual Dense Block (ResidualDenseBlock_5C)
53 | # Residual in Residual Dense Block (RRDB)
54 | # --------------------------------------------
55 | '''
56 |
57 |
58 | # --------------------------------------------
59 | # return nn.Sequential of (Conv + BN + ReLU)
60 | # --------------------------------------------
61 | def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
62 | L = []
63 | for t in mode:
64 | if t == 'C':
65 | L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
66 | elif t == 'T':
67 | L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
68 | elif t == 'B':
69 | L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
70 | elif t == 'I':
71 | L.append(nn.InstanceNorm2d(out_channels, affine=True))
72 | elif t == 'R':
73 | L.append(nn.ReLU(inplace=True))
74 | elif t == 'r':
75 | L.append(nn.ReLU(inplace=False))
76 | elif t == 'L':
77 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
78 | elif t == 'l':
79 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
80 | elif t == '2':
81 | L.append(nn.PixelShuffle(upscale_factor=2))
82 | elif t == '3':
83 | L.append(nn.PixelShuffle(upscale_factor=3))
84 | elif t == '4':
85 | L.append(nn.PixelShuffle(upscale_factor=4))
86 | elif t == 'U':
87 | L.append(nn.Upsample(scale_factor=2, mode='nearest'))
88 | elif t == 'u':
89 | L.append(nn.Upsample(scale_factor=3, mode='nearest'))
90 | elif t == 'v':
91 | L.append(nn.Upsample(scale_factor=4, mode='nearest'))
92 | elif t == 'M':
93 | L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
94 | elif t == 'A':
95 | L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
96 | else:
97 | raise NotImplementedError('Undefined type: {}'.format(t))
98 | return sequential(*L)
99 |
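# Usage sketch (assumed examples): each character of `mode` appends one layer,
#   conv(3, 64, mode='CBR')  # Conv2d(3->64) + BatchNorm2d(64) + ReLU
#   conv(64, 64, mode='CL')  # Conv2d + LeakyReLU(negative_slope=0.2)
#   conv(64, 12, mode='C2')  # Conv2d to 12 channels + PixelShuffle(2) -> 3 channels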
100 |
101 | # --------------------------------------------
102 | # inverse of pixel_shuffle
103 | # --------------------------------------------
104 | def pixel_unshuffle(input, upscale_factor):
105 | r"""Rearranges elements in a Tensor of shape :math:`(C, rH, rW)` to a
106 | tensor of shape :math:`(*, r^2C, H, W)`.
107 |
108 | Authors:
109 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan
110 | Kai Zhang, https://github.com/cszn/FFDNet
111 |
112 | Date:
113 | 01/Jan/2019
114 | """
115 | batch_size, channels, in_height, in_width = input.size()
116 |
117 | out_height = in_height // upscale_factor
118 | out_width = in_width // upscale_factor
119 |
120 | input_view = input.contiguous().view(
121 | batch_size, channels, out_height, upscale_factor,
122 | out_width, upscale_factor)
123 |
124 | channels *= upscale_factor ** 2
125 | unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
126 | return unshuffle_out.view(batch_size, channels, out_height, out_width)
127 |
128 |
129 | class PixelUnShuffle(nn.Module):
130 | r"""Rearranges elements in a Tensor of shape :math:`(C, rH, rW)` to a
131 | tensor of shape :math:`(*, r^2C, H, W)`.
132 |
133 | Authors:
134 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan
135 | Kai Zhang, https://github.com/cszn/FFDNet
136 |
137 | Date:
138 | 01/Jan/2019
139 | """
140 |
141 | def __init__(self, upscale_factor):
142 | super(PixelUnShuffle, self).__init__()
143 | self.upscale_factor = upscale_factor
144 |
145 | def forward(self, input):
146 | return pixel_unshuffle(input, self.upscale_factor)
147 |
148 | def extra_repr(self):
149 | return 'upscale_factor={}'.format(self.upscale_factor)
150 |
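# Usage sketch (assumed example): pixel_unshuffle is the exact inverse of
# nn.PixelShuffle with the same factor:
#   x = torch.randn(1, 4, 8, 8)
#   y = pixel_unshuffle(x, 2)                 # -> shape (1, 16, 4, 4)
#   torch.allclose(nn.PixelShuffle(2)(y), x)  # -> True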
151 |
152 | # --------------------------------------------
153 | # conditional batch norm
154 | # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
155 | # --------------------------------------------
156 | class ConditionalBatchNorm2d(nn.Module):
157 | def __init__(self, num_features, num_classes):
158 | super().__init__()
159 | self.num_features = num_features
160 | self.bn = nn.BatchNorm2d(num_features, affine=False)
161 | self.embed = nn.Embedding(num_classes, num_features * 2)
162 | self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
163 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
164 |
165 | def forward(self, x, y):
166 | out = self.bn(x)
167 | gamma, beta = self.embed(y).chunk(2, 1)
168 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
169 | return out
170 |
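# Usage sketch (assumed example): class-conditional scale and shift on top of
# a parameter-free BatchNorm2d:
#   cbn = ConditionalBatchNorm2d(num_features=64, num_classes=10)
#   x = torch.randn(8, 64, 16, 16)
#   y = torch.randint(0, 10, (8,))
#   out = cbn(x, y)  # same shape as x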
171 |
172 | # --------------------------------------------
173 | # Concat the output of a submodule to its input
174 | # --------------------------------------------
175 | class ConcatBlock(nn.Module):
176 | def __init__(self, submodule):
177 | super(ConcatBlock, self).__init__()
178 | self.sub = submodule
179 |
180 | def forward(self, x):
181 | output = torch.cat((x, self.sub(x)), dim=1)
182 | return output
183 |
184 | def __repr__(self):
185 | return self.sub.__repr__() + 'concat'
186 |
187 |
188 | # --------------------------------------------
189 | # sum the output of a submodule to its input
190 | # --------------------------------------------
191 | class ShortcutBlock(nn.Module):
192 | def __init__(self, submodule):
193 | super(ShortcutBlock, self).__init__()
194 |
195 | self.sub = submodule
196 |
197 | def forward(self, x):
198 | output = x + self.sub(x)
199 | return output
200 |
201 | def __repr__(self):
202 | tmpstr = 'Identity + \n|'
203 | modstr = self.sub.__repr__().replace('\n', '\n|')
204 | tmpstr = tmpstr + modstr
205 | return tmpstr
206 |
207 |
208 | # --------------------------------------------
209 | # Res Block: x + conv(relu(conv(x)))
210 | # --------------------------------------------
211 | class ResBlock(nn.Module):
212 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
213 | super(ResBlock, self).__init__()
214 |
215 | assert in_channels == out_channels, 'Only support in_channels==out_channels.'
216 | if mode[0] in ['R', 'L']:
217 | mode = mode[0].lower() + mode[1:]
218 |
219 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
220 |
221 | def forward(self, x):
222 | #res = self.res(x)
223 | return x + self.res(x)
224 |
225 |
226 | # --------------------------------------------
227 | # simplified information multi-distillation block (IMDB)
228 | # x + conv1(concat(split(relu(conv(x)))x3))
229 | # --------------------------------------------
230 | class IMDBlock(nn.Module):
231 | """
232 | @inproceedings{hui2019lightweight,
233 | title={Lightweight Image Super-Resolution with Information Multi-distillation Network},
234 | author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei},
235 | booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)},
236 | pages={2024--2032},
237 | year={2019}
238 | }
239 | @inproceedings{zhang2019aim,
240 | title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results},
241 | author={Kai Zhang and Shuhang Gu and Radu Timofte and others},
242 | booktitle={IEEE International Conference on Computer Vision Workshops},
243 | year={2019}
244 | }
245 | """
246 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.25, negative_slope=0.05):
247 | super(IMDBlock, self).__init__()
248 | self.d_nc = int(in_channels * d_rate)
249 | self.r_nc = int(in_channels - self.d_nc)
250 |
251 | assert mode[0] == 'C', 'convolutional layer first'
252 |
253 | self.conv1 = conv(in_channels, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
254 | self.conv2 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
255 | self.conv3 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
256 | self.conv4 = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias, mode[0], negative_slope)
257 | self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0], negative_slope=negative_slope)
258 |
259 | def forward(self, x):
260 | d1, r = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1)
261 | d2, r = torch.split(self.conv2(r), (self.d_nc, self.r_nc), dim=1)
262 | d3, r = torch.split(self.conv3(r), (self.d_nc, self.r_nc), dim=1)
263 | r = self.conv4(r)
264 | res = self.conv1x1(torch.cat((d1, d2, d3, r), dim=1))
265 | return x + res
266 |
267 |
268 | # d1, r1 = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1)
269 | # d2, r2 = torch.split(self.conv2(r1), (self.d_nc, self.r_nc), dim=1)
270 | # d3, r3 = torch.split(self.conv3(r2), (self.d_nc, self.r_nc), dim=1)
271 | # d4 = self.conv4(r3)
272 | # --------------------------------------------
273 | # Channel Attention (CA) Layer
274 | # --------------------------------------------
275 | class CALayer(nn.Module):
276 | def __init__(self, channel=64, reduction=16):
277 | super(CALayer, self).__init__()
278 |
279 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
280 | self.conv_fc = nn.Sequential(
281 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
282 | nn.ReLU(inplace=True),
283 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
284 | nn.Sigmoid()
285 | )
286 |
287 | def forward(self, x):
288 | y = self.avg_pool(x)
289 | y = self.conv_fc(y)
290 | return x * y
291 |
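# Usage sketch (assumed example): squeeze-and-excitation style channel
# re-weighting; the output keeps the input shape, with each channel scaled
# by a learned weight in (0, 1):
#   ca = CALayer(channel=64, reduction=16)
#   out = ca(torch.randn(2, 64, 16, 16))  # -> (2, 64, 16, 16)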
292 |
293 | # --------------------------------------------
294 | # Residual Channel Attention Block (RCAB)
295 | # --------------------------------------------
296 | class RCABlock(nn.Module):
297 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, negative_slope=0.2):
298 | super(RCABlock, self).__init__()
299 | assert in_channels == out_channels, 'Only support in_channels==out_channels.'
300 | if mode[0] in ['R','L']:
301 | mode = mode[0].lower() + mode[1:]
302 |
303 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
304 | self.ca = CALayer(out_channels, reduction)
305 |
306 | def forward(self, x):
307 | res = self.res(x)
308 | res = self.ca(res)
309 | return res + x
310 |
311 |
312 | # --------------------------------------------
313 | # Residual Channel Attention Group (RG)
314 | # --------------------------------------------
315 | class RCAGroup(nn.Module):
316 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, nb=12, negative_slope=0.2):
317 | super(RCAGroup, self).__init__()
318 | assert in_channels == out_channels, 'Only support in_channels==out_channels.'
319 | if mode[0] in ['R','L']:
320 | mode = mode[0].lower() + mode[1:]
321 |
322 | RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding, bias, mode, reduction, negative_slope) for _ in range(nb)]
323 | RG.append(conv(out_channels, out_channels, mode='C'))
324 | self.rg = nn.Sequential(*RG) # self.rg = ShortcutBlock(nn.Sequential(*RG))
325 |
326 | def forward(self, x):
327 | res = self.rg(x)
328 | return res + x
329 |
330 |
331 | # --------------------------------------------
332 | # Residual Dense Block
333 | # style: 5 convs
334 | # --------------------------------------------
335 | class ResidualDenseBlock_5C(nn.Module):
336 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
337 | super(ResidualDenseBlock_5C, self).__init__()
338 | # gc: growth channel
339 | self.conv1 = conv(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
340 | self.conv2 = conv(nc+gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
341 | self.conv3 = conv(nc+2*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
342 | self.conv4 = conv(nc+3*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
343 | self.conv5 = conv(nc+4*gc, nc, kernel_size, stride, padding, bias, mode[:-1], negative_slope)
344 |
345 | def forward(self, x):
346 | x1 = self.conv1(x)
347 | x2 = self.conv2(torch.cat((x, x1), 1))
348 | x3 = self.conv3(torch.cat((x, x1, x2), 1))
349 | x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
350 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
351 | return x5.mul_(0.2) + x
352 |
353 |
354 | # --------------------------------------------
355 | # Residual in Residual Dense Block
356 | # 3x5c
357 | # --------------------------------------------
358 | class RRDB(nn.Module):
359 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
360 | super(RRDB, self).__init__()
361 |
362 | self.RDB1 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
363 | self.RDB2 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
364 | self.RDB3 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
365 |
366 | def forward(self, x):
367 | out = self.RDB1(x)
368 | out = self.RDB2(out)
369 | out = self.RDB3(out)
370 | return out.mul_(0.2) + x
371 |
372 |
373 | """
374 | # --------------------------------------------
375 | # Upsampler
376 | # Kai Zhang, https://github.com/cszn/KAIR
377 | # --------------------------------------------
378 | # upsample_pixelshuffle
379 | # upsample_upconv
380 | # upsample_convtranspose
381 | # --------------------------------------------
382 | """
383 |
384 |
385 | # --------------------------------------------
386 | # conv + subp (+ relu)
387 | # --------------------------------------------
388 | def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
389 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
390 | up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope)
391 | return up1
392 |
393 |
394 | # --------------------------------------------
395 | # nearest_upsample + conv (+ R)
396 | # --------------------------------------------
397 | def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
398 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
399 | if mode[0] == '2':
400 | uc = 'UC'
401 | elif mode[0] == '3':
402 | uc = 'uC'
403 | elif mode[0] == '4':
404 | uc = 'vC'
405 | mode = mode.replace(mode[0], uc)
406 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
407 | return up1
408 |
409 |
410 | # --------------------------------------------
411 | # convTranspose (+ relu)
412 | # --------------------------------------------
413 | def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
414 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
415 | kernel_size = int(mode[0])
416 | stride = int(mode[0])
417 | mode = mode.replace(mode[0], 'T')
418 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
419 | return up1
420 |
421 |
422 | '''
423 | # --------------------------------------------
424 | # Downsampler
425 | # Kai Zhang, https://github.com/cszn/KAIR
426 | # --------------------------------------------
427 | # downsample_strideconv
428 | # downsample_maxpool
429 | # downsample_avgpool
430 | # --------------------------------------------
431 | '''
432 |
433 |
434 | # --------------------------------------------
435 | # strideconv (+ relu)
436 | # --------------------------------------------
437 | def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
438 | assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
439 | kernel_size = int(mode[0])
440 | stride = int(mode[0])
441 | mode = mode.replace(mode[0], 'C')
442 | down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
443 | return down1
444 |
445 |
446 | # --------------------------------------------
447 | # maxpooling + conv (+ relu)
448 | # --------------------------------------------
449 | def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
450 | assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
451 | kernel_size_pool = int(mode[0])
452 | stride_pool = int(mode[0])
453 | mode = mode.replace(mode[0], 'MC')
454 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
455 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
456 | return sequential(pool, pool_tail)
457 |
458 |
459 | # --------------------------------------------
460 | # averagepooling + conv (+ relu)
461 | # --------------------------------------------
462 | def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
463 | assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
464 | kernel_size_pool = int(mode[0])
465 | stride_pool = int(mode[0])
466 | mode = mode.replace(mode[0], 'AC')
467 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
468 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
469 | return sequential(pool, pool_tail)
470 |
471 |
472 | '''
473 | # --------------------------------------------
474 | # NonLocalBlock2D:
475 | # embedded_gaussian
476 | # +W(softmax(thetaXphi)Xg)
477 | # --------------------------------------------
478 | '''
479 |
480 |
481 | # --------------------------------------------
482 | # non-local block with embedded_gaussian
483 | # https://github.com/AlexHex7/Non-local_pytorch
484 | # --------------------------------------------
485 | class NonLocalBlock2D(nn.Module):
486 | def __init__(self, nc=64, kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='maxpool', negative_slope=0.2):
487 |
488 | super(NonLocalBlock2D, self).__init__()
489 |
490 | inter_nc = nc // 2
491 | self.inter_nc = inter_nc
492 | self.W = conv(inter_nc, nc, kernel_size, stride, padding, bias, mode='C'+act_mode)
493 | self.theta = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
494 |
495 | if downsample:
496 | if downsample_mode == 'avgpool':
497 | downsample_block = downsample_avgpool
498 | elif downsample_mode == 'maxpool':
499 | downsample_block = downsample_maxpool
500 | elif downsample_mode == 'strideconv':
501 | downsample_block = downsample_strideconv
502 | else:
503 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
504 | self.phi = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
505 | self.g = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
506 | else:
507 | self.phi = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
508 | self.g = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
509 |
510 | def forward(self, x):
511 | '''
512 | :param x: (b, c, t, h, w)
513 | :return:
514 | '''
515 |
516 | batch_size = x.size(0)
517 |
518 | g_x = self.g(x).view(batch_size, self.inter_nc, -1)
519 | g_x = g_x.permute(0, 2, 1)
520 |
521 | theta_x = self.theta(x).view(batch_size, self.inter_nc, -1)
522 | theta_x = theta_x.permute(0, 2, 1)
523 | phi_x = self.phi(x).view(batch_size, self.inter_nc, -1)
524 | f = torch.matmul(theta_x, phi_x)
525 | f_div_C = F.softmax(f, dim=-1)
526 |
527 | y = torch.matmul(f_div_C, g_x)
528 | y = y.permute(0, 2, 1).contiguous()
529 | y = y.view(batch_size, self.inter_nc, *x.size()[2:])
530 | W_y = self.W(y)
531 | z = W_y + x
532 |
533 | return z
534 |
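# Usage sketch (assumed example): global self-attention over all spatial
# positions, z = x + W(softmax(theta(x) @ phi(x)^T) @ g(x)):
#   nlb = NonLocalBlock2D(nc=64)
#   x = torch.randn(1, 64, 32, 32)
#   z = nlb(x)  # -> (1, 64, 32, 32)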
--------------------------------------------------------------------------------
/models/network_dncnn.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import models.basicblock as B
4 |
5 |
6 | """
7 | # --------------------------------------------
8 | # DnCNN (20 conv layers)
9 | # FDnCNN (20 conv layers)
10 | # --------------------------------------------
11 | # References:
12 | @article{zhang2017beyond,
13 | title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
14 | author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
15 | journal={IEEE Transactions on Image Processing},
16 | volume={26},
17 | number={7},
18 | pages={3142--3155},
19 | year={2017},
20 | publisher={IEEE}
21 | }
22 | @article{zhang2018ffdnet,
23 | title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
24 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
25 | journal={IEEE Transactions on Image Processing},
26 | volume={27},
27 | number={9},
28 | pages={4608--4622},
29 | year={2018},
30 | publisher={IEEE}
31 | }
32 | # --------------------------------------------
33 | """
34 |
35 |
36 | # --------------------------------------------
37 | # DnCNN
38 | # --------------------------------------------
39 | class DnCNN(nn.Module):
40 | def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
41 | """
42 | # ------------------------------------
43 | in_nc: channel number of input
44 | out_nc: channel number of output
45 | nc: channel number
46 | nb: total number of conv layers
47 | act_mode: batch norm + activation function; 'BR' means BN+ReLU.
48 | # ------------------------------------
49 | Batch normalization and residual learning are
50 | beneficial to Gaussian denoising (especially
51 | for a single noise level).
52 | The residual of a noisy image corrupted by additive white
53 | Gaussian noise (AWGN) follows a constant
54 | Gaussian distribution, which stabilizes batch
55 | normalization during training.
56 | # ------------------------------------
57 | """
58 | super(DnCNN, self).__init__()
59 | assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
60 | bias = True
61 |
62 | m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
63 | m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
64 | m_tail = B.conv(nc, out_nc, mode='C', bias=bias)
65 |
66 | self.model = B.sequential(m_head, *m_body, m_tail)
67 |
68 | def forward(self, x):
69 | n = self.model(x)
70 | return x-n
71 |
72 |
73 |
74 | class IRCNN(nn.Module):
75 | def __init__(self, in_nc=1, out_nc=1, nc=64):
76 | """
77 | # ------------------------------------
78 | denoiser of IRCNN
79 | in_nc: channel number of input
80 | out_nc: channel number of output
81 | nc: channel number
82 | (this denoiser is a fixed seven-layer dilated-convolution
83 | network, so nb and act_mode are not parameters here)
84 | # ------------------------------------
85 | Batch normalization and residual learning are
86 | beneficial to Gaussian denoising (especially
87 | for a single noise level).
88 | The residual of a noisy image corrupted by additive white
89 | Gaussian noise (AWGN) follows a constant
90 | Gaussian distribution, which stabilizes batch
91 | normalization during training.
92 | # ------------------------------------
93 | """
94 | super(IRCNN, self).__init__()
95 | L = []
96 | L.append(nn.Conv2d(in_channels=in_nc, out_channels=nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
97 | L.append(nn.ReLU(inplace=True))
98 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
99 | L.append(nn.ReLU(inplace=True))
100 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
101 | L.append(nn.ReLU(inplace=True))
102 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=4, dilation=4, bias=True))
103 | L.append(nn.ReLU(inplace=True))
104 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
105 | L.append(nn.ReLU(inplace=True))
106 | L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
107 | L.append(nn.ReLU(inplace=True))
108 | L.append(nn.Conv2d(in_channels=nc, out_channels=out_nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
109 | self.model = B.sequential(*L)
110 |
111 | def forward(self, x):
112 | n = self.model(x)
113 | return x-n
114 |
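# Usage sketch (mirroring the main scripts, keys assumed from their use of
# model25[str(idx)]): the ircnn_*.pth files hold 25 weight sets keyed
# '0'..'24', one per noise level of roughly 2/255 to 50/255:
#   model = IRCNN(in_nc=3, out_nc=3, nc=64)
#   model25 = torch.load('model_zoo/ircnn_color.pth')
#   model.load_state_dict(model25['7'], strict=True)  # ~ sigma = 16/255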
115 |
116 |
117 |
118 |
119 |
120 | # --------------------------------------------
121 | # FDnCNN
122 | # --------------------------------------------
123 | # Compared with DnCNN, FDnCNN has three modifications:
124 | # 1) add noise level map as input
125 | # 2) remove residual learning and BN
126 | # 3) train with L1 loss
127 | # It may need more training time, but it does not reduce the final PSNR much.
128 | # --------------------------------------------
129 | class FDnCNN(nn.Module):
130 | def __init__(self, in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R'):
131 | """
132 | in_nc: channel number of input
133 | out_nc: channel number of output
134 | nc: channel number
135 | nb: total number of conv layers
136 | act_mode: batch norm + activation function; 'BR' means BN+ReLU.
137 | """
138 | super(FDnCNN, self).__init__()
139 | assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
140 | bias = True
141 |
142 | m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
143 | m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
144 | m_tail = B.conv(nc, out_nc, mode='C', bias=bias)
145 |
146 | self.model = B.sequential(m_head, *m_body, m_tail)
147 |
148 | def forward(self, x):
149 | x = self.model(x)
150 | return x
151 |
152 |
153 | if __name__ == '__main__':
154 | from utils import utils_model
155 | import torch
156 | model1 = DnCNN(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='BR')
157 | print(utils_model.describe_model(model1))
158 |
159 | model2 = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
160 | print(utils_model.describe_model(model2))
161 |
162 | x = torch.randn((1, 1, 240, 240))
163 | x1 = model1(x)
164 | print(x1.shape)
165 |
166 | x = torch.randn((1, 2, 240, 240))
167 | x2 = model2(x)
168 | print(x2.shape)
169 |
170 | # run models/network_dncnn.py
--------------------------------------------------------------------------------
/models/network_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import models.basicblock as B
4 | import numpy as np
5 |
6 | '''
7 | # ====================
8 | # unet
9 | # ====================
10 | '''
11 |
12 |
13 | class UNet(nn.Module):
14 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=2, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
15 | super(UNet, self).__init__()
16 |
17 | self.m_head = B.conv(in_nc, nc[0], mode='C'+act_mode[-1])
18 |
19 | # downsample
20 | if downsample_mode == 'avgpool':
21 | downsample_block = B.downsample_avgpool
22 | elif downsample_mode == 'maxpool':
23 | downsample_block = B.downsample_maxpool
24 | elif downsample_mode == 'strideconv':
25 | downsample_block = B.downsample_strideconv
26 | else:
27 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
28 |
29 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode))
30 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode))
31 | self.m_down3 = B.sequential(*[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode))
32 |
33 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)])
34 |
35 | # upsample
36 | if upsample_mode == 'upconv':
37 | upsample_block = B.upsample_upconv
38 | elif upsample_mode == 'pixelshuffle':
39 | upsample_block = B.upsample_pixelshuffle
40 | elif upsample_mode == 'convtranspose':
41 | upsample_block = B.upsample_convtranspose
42 | else:
43 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
44 |
45 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)])
46 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)])
47 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)])
48 |
49 | self.m_tail = B.conv(nc[0], out_nc, bias=True, mode='C')
50 |
51 | def forward(self, x0):
52 |
53 | x1 = self.m_head(x0)
54 | x2 = self.m_down1(x1)
55 | x3 = self.m_down2(x2)
56 | x4 = self.m_down3(x3)
57 | x = self.m_body(x4)
58 | x = self.m_up3(x+x4)
59 | x = self.m_up2(x+x3)
60 | x = self.m_up1(x+x2)
61 | x = self.m_tail(x+x1) + x0
62 |
63 |
64 | return x
65 |
66 |
67 | class UNetRes(nn.Module):
68 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
69 | super(UNetRes, self).__init__()
70 |
71 | self.m_head = B.conv(in_nc, nc[0], bias=False, mode='C')
72 |
73 | # downsample
74 | if downsample_mode == 'avgpool':
75 | downsample_block = B.downsample_avgpool
76 | elif downsample_mode == 'maxpool':
77 | downsample_block = B.downsample_maxpool
78 | elif downsample_mode == 'strideconv':
79 | downsample_block = B.downsample_strideconv
80 | else:
81 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
82 |
83 | self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2'))
84 | self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2'))
85 | self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2'))
86 |
87 | self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
88 |
89 | # upsample
90 | if upsample_mode == 'upconv':
91 | upsample_block = B.upsample_upconv
92 | elif upsample_mode == 'pixelshuffle':
93 | upsample_block = B.upsample_pixelshuffle
94 | elif upsample_mode == 'convtranspose':
95 | upsample_block = B.upsample_convtranspose
96 | else:
97 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
98 |
99 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[B.ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
100 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[B.ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
101 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[B.ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
102 |
103 | self.m_tail = B.conv(nc[0], out_nc, bias=False, mode='C')
104 |
105 | def forward(self, x0):
106 | x1 = self.m_head(x0)
107 | x2 = self.m_down1(x1)
108 | x3 = self.m_down2(x2)
109 | x4 = self.m_down3(x3)
110 | x = self.m_body(x4)
111 | x = self.m_up3(x+x4)
112 | x = self.m_up2(x+x3)
113 | x = self.m_up1(x+x2)
114 | x = self.m_tail(x+x1)
115 |
116 | return x
117 |
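# Usage sketch (assumed example, mirroring how the main scripts call DRUNet):
# the extra input channel carries the noise level map, and H/W should be
# multiples of 8 because of the three stride-2 stages:
#   net = UNetRes(in_nc=2, out_nc=1)
#   x = torch.randn(1, 1, 128, 128)
#   sigma = torch.full((1, 1, 128, 128), 25/255.)
#   out = net(torch.cat((x, sigma), dim=1))  # -> (1, 1, 128, 128)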
118 |
119 | class ResUNet(nn.Module):
120 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='L', downsample_mode='strideconv', upsample_mode='convtranspose'):
121 | super(ResUNet, self).__init__()
122 |
123 | self.m_head = B.conv(in_nc, nc[0], bias=False, mode='C')
124 |
125 | # downsample
126 | if downsample_mode == 'avgpool':
127 | downsample_block = B.downsample_avgpool
128 | elif downsample_mode == 'maxpool':
129 | downsample_block = B.downsample_maxpool
130 | elif downsample_mode == 'strideconv':
131 | downsample_block = B.downsample_strideconv
132 | else:
133 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
134 |
135 | self.m_down1 = B.sequential(*[B.IMDBlock(nc[0], nc[0], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2'))
136 | self.m_down2 = B.sequential(*[B.IMDBlock(nc[1], nc[1], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2'))
137 | self.m_down3 = B.sequential(*[B.IMDBlock(nc[2], nc[2], bias=False, mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2'))
138 |
139 | self.m_body = B.sequential(*[B.IMDBlock(nc[3], nc[3], bias=False, mode='C'+act_mode) for _ in range(nb)])
140 |
141 | # upsample
142 | if upsample_mode == 'upconv':
143 | upsample_block = B.upsample_upconv
144 | elif upsample_mode == 'pixelshuffle':
145 | upsample_block = B.upsample_pixelshuffle
146 | elif upsample_mode == 'convtranspose':
147 | upsample_block = B.upsample_convtranspose
148 | else:
149 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
150 |
151 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[B.IMDBlock(nc[2], nc[2], bias=False, mode='C'+act_mode) for _ in range(nb)])
152 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[B.IMDBlock(nc[1], nc[1], bias=False, mode='C'+act_mode) for _ in range(nb)])
153 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[B.IMDBlock(nc[0], nc[0], bias=False, mode='C'+act_mode) for _ in range(nb)])
154 |
155 | self.m_tail = B.conv(nc[0], out_nc, bias=False, mode='C')
156 |
157 | def forward(self, x):
158 |
159 | h, w = x.size()[-2:]
160 | paddingBottom = int(np.ceil(h/8)*8-h)
161 | paddingRight = int(np.ceil(w/8)*8-w)
162 | x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x)
163 |
164 | x1 = self.m_head(x)
165 | x2 = self.m_down1(x1)
166 | x3 = self.m_down2(x2)
167 | x4 = self.m_down3(x3)
168 | x = self.m_body(x4)
169 | x = self.m_up3(x+x4)
170 | x = self.m_up2(x+x3)
171 | x = self.m_up1(x+x2)
172 | x = self.m_tail(x+x1)
173 | x = x[..., :h, :w]
174 |
175 | return x
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 | class UNetResSubP(nn.Module):
191 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=2, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
192 | super(UNetResSubP, self).__init__()
193 | sf = 2
194 | self.m_ps_down = B.PixelUnShuffle(sf)
195 | self.m_ps_up = nn.PixelShuffle(sf)
196 | self.m_head = B.conv(in_nc*sf*sf, nc[0], mode='C'+act_mode[-1])
197 |
198 | # downsample
199 | if downsample_mode == 'avgpool':
200 | downsample_block = B.downsample_avgpool
201 | elif downsample_mode == 'maxpool':
202 | downsample_block = B.downsample_maxpool
203 | elif downsample_mode == 'strideconv':
204 | downsample_block = B.downsample_strideconv
205 | else:
206 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
207 |
208 | self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode))
209 | self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode))
210 | self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode))
211 |
212 | self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], mode='C'+act_mode+'C') for _ in range(nb+1)])
213 |
214 | # upsample
215 | if upsample_mode == 'upconv':
216 | upsample_block = B.upsample_upconv
217 | elif upsample_mode == 'pixelshuffle':
218 | upsample_block = B.upsample_pixelshuffle
219 | elif upsample_mode == 'convtranspose':
220 | upsample_block = B.upsample_convtranspose
221 | else:
222 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
223 |
224 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.ResBlock(nc[2], nc[2], mode='C'+act_mode+'C') for _ in range(nb)])
225 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.ResBlock(nc[1], nc[1], mode='C'+act_mode+'C') for _ in range(nb)])
226 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.ResBlock(nc[0], nc[0], mode='C'+act_mode+'C') for _ in range(nb)])
227 |
228 | self.m_tail = B.conv(nc[0], out_nc*sf*sf, bias=False, mode='C')
229 |
230 | def forward(self, x0):
231 | x0_d = self.m_ps_down(x0)
232 | x1 = self.m_head(x0_d)
233 | x2 = self.m_down1(x1)
234 | x3 = self.m_down2(x2)
235 | x4 = self.m_down3(x3)
236 | x = self.m_body(x4)
237 | x = self.m_up3(x+x4)
238 | x = self.m_up2(x+x3)
239 | x = self.m_up1(x+x2)
240 | x = self.m_tail(x+x1)
241 | x = self.m_ps_up(x) + x0
242 |
243 | return x
244 |
245 |
246 | class UNetPlus(nn.Module):
247 | def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=1, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
248 | super(UNetPlus, self).__init__()
249 |
250 | self.m_head = B.conv(in_nc, nc[0], mode='C')
251 |
252 | # downsample
253 | if downsample_mode == 'avgpool':
254 | downsample_block = B.downsample_avgpool
255 | elif downsample_mode == 'maxpool':
256 | downsample_block = B.downsample_maxpool
257 | elif downsample_mode == 'strideconv':
258 | downsample_block = B.downsample_strideconv
259 | else:
260 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
261 |
262 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode[1]))
263 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode[1]))
264 | self.m_down3 = B.sequential(*[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode[1]))
265 |
266 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)])
267 |
268 | # upsample
269 | if upsample_mode == 'upconv':
270 | upsample_block = B.upsample_upconv
271 | elif upsample_mode == 'pixelshuffle':
272 | upsample_block = B.upsample_pixelshuffle
273 | elif upsample_mode == 'convtranspose':
274 | upsample_block = B.upsample_convtranspose
275 | else:
276 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
277 |
278 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[2], nc[2], mode='C'+act_mode[1]))
279 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[1], nc[1], mode='C'+act_mode[1]))
280 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb-1)], B.conv(nc[0], nc[0], mode='C'+act_mode[1]))
281 |
282 | self.m_tail = B.conv(nc[0], out_nc, mode='C')
283 |
284 | def forward(self, x0):
285 | x1 = self.m_head(x0)
286 | x2 = self.m_down1(x1)
287 | x3 = self.m_down2(x2)
288 | x4 = self.m_down3(x3)
289 | x = self.m_body(x4)
290 | x = self.m_up3(x+x4)
291 | x = self.m_up2(x+x3)
292 | x = self.m_up1(x+x2)
293 | x = self.m_tail(x+x1) + x0
294 | return x
295 |
296 | '''
297 | # ====================
298 | # nonlocalunet
299 | # ====================
300 | '''
301 |
302 | class NonLocalUNet(nn.Module):
303 | def __init__(self, in_nc=3, out_nc=3, nc=[64,128,256,512], nb=1, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
304 | super(NonLocalUNet, self).__init__()
305 |
306 | down_nonlocal = B.NonLocalBlock2D(nc[2], kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='strideconv')
307 | up_nonlocal = B.NonLocalBlock2D(nc[2], kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='strideconv')
308 |
309 | self.m_head = B.conv(in_nc, nc[0], mode='C'+act_mode[-1])
310 |
311 | # downsample
312 | if downsample_mode == 'avgpool':
313 | downsample_block = B.downsample_avgpool
314 | elif downsample_mode == 'maxpool':
315 | downsample_block = B.downsample_maxpool
316 | elif downsample_mode == 'strideconv':
317 | downsample_block = B.downsample_strideconv
318 | else:
319 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
320 |
321 |
322 | self.m_down1 = B.sequential(*[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[0], nc[1], mode='2'+act_mode))
323 | self.m_down2 = B.sequential(*[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[1], nc[2], mode='2'+act_mode))
324 | self.m_down3 = B.sequential(down_nonlocal, *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], downsample_block(nc[2], nc[3], mode='2'+act_mode))
325 |
326 | self.m_body = B.sequential(*[B.conv(nc[3], nc[3], mode='C'+act_mode) for _ in range(nb+1)])
327 |
328 | # upsample
329 | if upsample_mode == 'upconv':
330 | upsample_block = B.upsample_upconv
331 | elif upsample_mode == 'pixelshuffle':
332 | upsample_block = B.upsample_pixelshuffle
333 | elif upsample_mode == 'convtranspose':
334 | upsample_block = B.upsample_convtranspose
335 | else:
336 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
337 |
338 |
339 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], mode='2'+act_mode), *[B.conv(nc[2], nc[2], mode='C'+act_mode) for _ in range(nb)], up_nonlocal)
340 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], mode='2'+act_mode), *[B.conv(nc[1], nc[1], mode='C'+act_mode) for _ in range(nb)])
341 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], mode='2'+act_mode), *[B.conv(nc[0], nc[0], mode='C'+act_mode) for _ in range(nb)])
342 |
343 | self.m_tail = B.conv(nc[0], out_nc, mode='C')
344 |
345 | def forward(self, x0):
346 | x1 = self.m_head(x0)
347 | x2 = self.m_down1(x1)
348 | x3 = self.m_down2(x2)
349 | x4 = self.m_down3(x3)
350 | x = self.m_body(x4)
351 | x = self.m_up3(x+x4)
352 | x = self.m_up2(x+x3)
353 | x = self.m_up1(x+x2)
354 | x = self.m_tail(x+x1) + x0
355 | return x
356 |
357 |
358 | if __name__ == '__main__':
359 | x = torch.rand(1,3,256,256)
360 | # net = UNet(act_mode='BR')
361 | net = NonLocalUNet()
362 | net.eval()
363 | with torch.no_grad():
364 | y = net(x)
365 |         print(y.size())
366 |
367 |
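368 |     # Added for illustration (not part of the original file): a quick
369 |     # parameter count for the instantiated network.
370 |     print(sum(p.numel() for p in net.parameters()))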
--------------------------------------------------------------------------------
/testsets/set12/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/01.png
--------------------------------------------------------------------------------
/testsets/set12/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/02.png
--------------------------------------------------------------------------------
/testsets/set12/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/03.png
--------------------------------------------------------------------------------
/testsets/set12/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/04.png
--------------------------------------------------------------------------------
/testsets/set12/05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/05.png
--------------------------------------------------------------------------------
/testsets/set12/06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/06.png
--------------------------------------------------------------------------------
/testsets/set12/07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/07.png
--------------------------------------------------------------------------------
/testsets/set12/08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/08.png
--------------------------------------------------------------------------------
/testsets/set12/09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/09.png
--------------------------------------------------------------------------------
/testsets/set12/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/10.png
--------------------------------------------------------------------------------
/testsets/set12/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/11.png
--------------------------------------------------------------------------------
/testsets/set12/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set12/12.png
--------------------------------------------------------------------------------
/testsets/set3c/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/butterfly.png
--------------------------------------------------------------------------------
/testsets/set3c/leaves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/leaves.png
--------------------------------------------------------------------------------
/testsets/set3c/starfish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set3c/starfish.png
--------------------------------------------------------------------------------
/testsets/set5/baby_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/baby_GT.bmp
--------------------------------------------------------------------------------
/testsets/set5/bird_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/bird_GT.bmp
--------------------------------------------------------------------------------
/testsets/set5/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/butterfly_GT.bmp
--------------------------------------------------------------------------------
/testsets/set5/head_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/head_GT.bmp
--------------------------------------------------------------------------------
/testsets/set5/woman_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/testsets/set5/woman_GT.bmp
--------------------------------------------------------------------------------
/utils/test.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszn/DPIR/15bca3fcc1f3cc51a1f99ccf027691e278c19354/utils/test.bmp
--------------------------------------------------------------------------------
/utils/utils_bnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def deleteLayer(model, layer_type=nn.BatchNorm2d):
6 | for k, m in list(model.named_children()):
7 | if isinstance(m, layer_type):
8 | del model._modules[k]
9 | deleteLayer(m, layer_type)
10 |
11 |
12 | def merge_bn(model):
13 | ''' by Kai Zhang, 11/01/2019.
14 | '''
15 | prev_m = None
16 | for k, m in list(model.named_children()):
17 | if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
18 |
19 | w = prev_m.weight.data
20 |
21 | if prev_m.bias is None:
22 | zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
23 | prev_m.bias = nn.Parameter(zeros)
24 | b = prev_m.bias.data
25 |
26 | invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
27 | if isinstance(prev_m, nn.ConvTranspose2d):
28 | w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
29 | else:
30 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
31 | b.add_(-m.running_mean).mul_(invstd)
32 | if m.affine:
33 | if isinstance(prev_m, nn.ConvTranspose2d):
34 | w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
35 | else:
36 | w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
37 | b.mul_(m.weight.data).add_(m.bias.data)
38 |
39 | del model._modules[k]
40 | prev_m = m
41 | merge_bn(m)
42 |
43 |
44 | def add_bn(model, for_init=True):
45 | ''' by Kai Zhang, 11/01/2019.
46 | '''
47 | for k, m in list(model.named_children()):
48 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
49 | if for_init:
50 | b = nn.BatchNorm2d(m.out_channels, momentum=None, affine=False, eps=1e-04)
51 | else:
52 | b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True, eps=1e-04)
53 |                 b.weight.data.fill_(1)  # only the affine branch has a learnable weight
54 |
55 | new_m = nn.Sequential(model._modules[k], b)
56 | model._modules[k] = new_m
57 | add_bn(m, for_init)
58 |
59 |
60 | def deploy_sequential(model):
61 | ''' by Kai Zhang, 11/01/2019.
62 |     replace each singleton nn.Sequential child by its single module
63 | '''
64 | for k, m in list(model.named_children()):
65 | if isinstance(m, nn.Sequential):
66 | if m.__len__() == 1:
67 | model._modules[k] = m.__getitem__(0)
68 | deploy_sequential(m)
69 |
70 |
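71 | if __name__ == '__main__':
72 |     # A minimal sketch added for illustration (not part of the original file):
73 |     # after merge_bn folds the BN statistics into the preceding conv,
74 |     # the eval-mode output should be unchanged.
75 |     net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
76 |     net.train()
77 |     for _ in range(3):
78 |         net(torch.randn(2, 3, 16, 16))  # populate the BN running statistics
79 |     net.eval()
80 |     x = torch.randn(1, 3, 16, 16)
81 |     y_ref = net(x)
82 |     merge_bn(net)
83 |     print(torch.allclose(net(x), y_ref, atol=1e-5))  # expected: True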
--------------------------------------------------------------------------------
/utils/utils_csmri.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from scipy import fftpack
4 | import torch
5 | from scipy import ndimage
6 | from utils import utils_image as util
7 | from scipy.interpolate import interp2d
8 | from scipy import signal
9 | import scipy.stats as ss
10 | import scipy.io as io
11 | import scipy
12 |
13 | '''
14 | modified by Kai Zhang (github: https://github.com/cszn)
15 | 03/03/2019
16 | '''
17 |
18 |
19 | '''
20 | # =================
21 | # pytorch
22 | # =================
23 | '''
24 |
25 |
26 | def splits(a, sf):
27 | '''split a into sfxsf distinct blocks
28 |
29 | Args:
30 | a: NxCxWxHx2
31 | sf: split factor
32 |
33 | Returns:
34 | b: NxCx(W/sf)x(H/sf)x2x(sf^2)
35 | '''
36 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
37 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
38 | return b
39 |
40 |
41 | def c2c(x):
42 | return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
43 |
44 |
45 | def r2c(x):
46 | # convert real to complex
47 | return torch.stack([x, torch.zeros_like(x)], -1)
48 |
49 |
50 | def cdiv(x, y):
51 | # complex division
52 | a, b = x[..., 0], x[..., 1]
53 | c, d = y[..., 0], y[..., 1]
54 | cd2 = c**2 + d**2
55 | return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
56 |
57 |
58 | def crdiv(x, y):
59 | # complex/real division
60 | a, b = x[..., 0], x[..., 1]
61 | return torch.stack([a/y, b/y], -1)
62 |
63 |
64 | def csum(x, y):
65 | # complex + real
66 | return torch.stack([x[..., 0] + y, x[..., 1]], -1)
67 |
68 |
69 | def cabs(x):
70 | # modulus of a complex number
71 | return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
72 |
73 |
74 | def cabs2(x):
75 | return x[..., 0]**2+x[..., 1]**2
76 |
77 |
78 | def cmul(t1, t2):
79 | '''complex multiplication
80 |
81 | Args:
82 | t1: NxCxHxWx2, complex tensor
83 | t2: NxCxHxWx2
84 |
85 | Returns:
86 | output: NxCxHxWx2
87 | '''
88 | real1, imag1 = t1[..., 0], t1[..., 1]
89 | real2, imag2 = t2[..., 0], t2[..., 1]
90 | return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
91 |
92 |
93 | def cconj(t, inplace=False):
94 | '''complex's conjugation
95 |
96 | Args:
97 | t: NxCxHxWx2
98 |
99 | Returns:
100 | output: NxCxHxWx2
101 | '''
102 | c = t.clone() if not inplace else t
103 | c[..., 1] *= -1
104 | return c
105 |
106 |
107 | def rfft(t):
108 | # Real-to-complex Discrete Fourier Transform
109 | return torch.rfft(t, 2, onesided=False)
110 |
111 |
112 | def irfft(t):
113 | # Complex-to-real Inverse Discrete Fourier Transform
114 | return torch.irfft(t, 2, onesided=False)
115 |
116 |
117 | def fft(t):
118 | # Complex-to-complex Discrete Fourier Transform
119 | return torch.fft(t, 2)
120 |
121 |
122 | def ifft(t):
123 | # Complex-to-complex Inverse Discrete Fourier Transform
124 | return torch.ifft(t, 2)
125 |
126 |
127 | def p2o(psf, shape):
128 | '''
129 | Convert point-spread function to optical transfer function.
130 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
131 | point-spread function (PSF) array and creates the optical transfer
132 | function (OTF) array that is not influenced by the PSF off-centering.
133 |
134 | Args:
135 | psf: NxCxhxw
136 | shape: [H, W]
137 |
138 | Returns:
139 | otf: NxCxHxWx2
140 | '''
141 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
142 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
143 | for axis, axis_size in enumerate(psf.shape[2:]):
144 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
145 | otf = torch.rfft(otf, 2, onesided=False)
146 | n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
147 | otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
148 | return otf
149 |
150 |
151 | def upsample(x, sf=3):
152 | '''s-fold upsampler
153 |
154 | Upsampling the spatial size by filling the new entries with zeros
155 |
156 | x: tensor image, NxCxWxH
157 | '''
158 | st = 0
159 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
160 | z[..., st::sf, st::sf].copy_(x)
161 | return z
162 |
163 |
164 | def downsample(x, sf=3):
165 | '''s-fold downsampler
166 |
167 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others
168 |
169 | x: tensor image, NxCxWxH
170 | '''
171 | st = 0
172 | return x[..., st::sf, st::sf]
173 |
174 |
175 | def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf):
176 | FR = FBFy + torch.rfft(alpha*x, 2, onesided=False)
177 | x1 = cmul(FB, FR)
178 | FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
179 | invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
180 | invWBR = cdiv(FBR, csum(invW, alpha))
181 | FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
182 | FX = (FR-FCBinvWBR)/alpha.unsqueeze(-1)
183 | Xest = torch.irfft(FX, 2, onesided=False)
184 | return Xest
185 |
186 |
187 | def pre_calculate(x, k, sf):
188 | '''
189 | Args:
190 | x: NxCxHxW, LR input
191 | k: NxCxhxw
192 | sf: integer
193 |
194 | Returns:
195 | FB, FBC, F2B, FBFy
196 | will be reused during iterations
197 | '''
198 | w, h = x.shape[-2:]
199 | FB = p2o(k, (w*sf, h*sf))
200 | FBC = cconj(FB, inplace=False)
201 | F2B = r2c(cabs2(FB))
202 | STy = upsample(x, sf=sf)
203 | FBFy = cmul(FBC, torch.rfft(STy, 2, onesided=False))
204 | return FB, FBC, F2B, FBFy
205 |
206 |
207 | '''
208 | # =================
209 | PyTorch
210 | # =================
211 | '''
212 |
213 |
214 | def real2complex(x):
215 | return torch.stack([x, torch.zeros_like(x)], -1)
216 |
217 |
218 | def modcrop(img, sf):
219 | '''
220 | img: tensor image, NxCxWxH or CxWxH or WxH
221 | sf: scale factor
222 | '''
223 | w, h = img.shape[-2:]
224 | im = img.clone()
225 | return im[..., :w - w % sf, :h - h % sf]
226 |
227 |
228 | def circular_pad(x, pad):
229 | '''
230 |     # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
231 | '''
232 | x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
233 | x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
234 | x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
235 | x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
236 | return x
237 |
238 |
239 | def pad_circular(input, padding):
240 | # type: (Tensor, List[int]) -> Tensor
241 | """
242 | Arguments
243 |     :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D])`
244 | :param padding: (tuple): m-elem tuple where m is the degree of convolution
245 | Returns
246 | :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
247 |         H + 2 * padding[1]], W + 2 * padding[2])`
248 | """
249 | offset = 3
250 | for dimension in range(input.dim() - offset + 1):
251 | input = dim_pad_circular(input, padding[dimension], dimension + offset)
252 | return input
253 |
254 |
255 | def dim_pad_circular(input, padding, dimension):
256 | # type: (Tensor, int, int) -> Tensor
257 | input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
258 | [slice(0, padding)]]], dim=dimension - 1)
259 | input = torch.cat([input[[slice(None)] * (dimension - 1) +
260 | [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
261 | return input
262 |
263 |
264 | def imfilter(x, k):
265 | '''
266 | x: image, NxcxHxW
267 | k: kernel, cx1xhxw
268 | '''
269 | x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
270 | x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
271 | return x
272 |
273 |
274 | def G(x, k, sf=3):
275 | '''
276 | x: image, NxcxHxW
277 | k: kernel, cx1xhxw
278 | sf: scale factor
279 |     center: the first one or the middle one
280 |
281 | Matlab function:
282 | tmp = imfilter(x,h,'circular');
283 | y = downsample2(tmp,K);
284 | '''
285 | x = downsample(imfilter(x, k), sf=sf)
286 | return x
287 |
288 |
289 | def Gt(x, k, sf=3):
290 | '''
291 | x: image, NxcxHxW
292 | k: kernel, cx1xhxw
293 | sf: scale factor
294 |     center: the first one or the middle one
295 |
296 | Matlab function:
297 | tmp = upsample2(x,K);
298 | y = imfilter(tmp,h,'circular');
299 | '''
300 | x = imfilter(upsample(x, sf=sf), k)
301 | return x
302 |
303 |
304 | def interpolation_down(x, sf, center=False):
305 | mask = torch.zeros_like(x)
306 | if center:
307 | start = torch.tensor((sf-1)//2)
308 | mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
309 | LR = x[..., start::sf, start::sf]
310 | else:
311 | mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
312 | LR = x[..., ::sf, ::sf]
313 | y = x.mul(mask)
314 |
315 | return LR, y, mask
316 |
317 |
318 | """
319 | # --------------------------------------------
320 | # degradation models
321 | # --------------------------------------------
322 | """
323 |
324 |
325 | def csmri_degradation(x, M):
326 | '''
327 | Args:
328 | x: 1x1xWxH image, [0, 1]
329 | M: mask, WxH
330 | n: noise, WxHx2
331 | '''
332 | x = rfft(x).mul(M.unsqueeze(-1).unsqueeze(0).unsqueeze(0)) # + n.unsqueeze(0).unsqueeze(0)
333 | return x
334 |
335 |
336 | if __name__ == '__main__':
337 |
338 | weight = torch.tensor([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.],[7.,8.,9.],[7.,8.,9.]]).view(1,1,5,3)
339 | input = torch.linspace(1,9,9).view(1,1,3,3)
340 | input = pad_circular(input, (2,1))
341 |
342 |
343 |
344 |
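345 |     # Added for illustration (not part of the original file): downsample is
346 |     # the left inverse of upsample for the same scale factor.
347 |     x = torch.rand(1, 1, 6, 6)
348 |     print(torch.equal(downsample(upsample(x, sf=3), sf=3), x))  # True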
--------------------------------------------------------------------------------
/utils/utils_deblur.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import scipy
4 | from scipy import fftpack
5 | import torch
6 |
7 | from math import cos, sin
8 | from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
9 | from numpy.random import randn, rand
10 | from scipy.signal import convolve2d
11 |
12 | # import utils_image as util
13 |
14 | '''
15 | modified by Kai Zhang (github: https://github.com/cszn)
16 | 03/03/2019
17 | '''
18 |
19 |
20 | def get_uperleft_denominator(img, kernel):
21 | '''
22 | img: HxWxC
23 | kernel: hxw
24 | denominator: HxWx1
25 | upperleft: HxWxC
26 | '''
27 | V = psf2otf(kernel, img.shape[:2])
28 | denominator = np.expand_dims(np.abs(V)**2, axis=2)
29 | upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
30 | return upperleft, denominator
31 |
32 |
33 | def get_uperleft_denominator_pytorch(img, kernel):
34 | '''
35 | img: NxCxHxW
36 | kernel: Nx1xhxw
37 | denominator: Nx1xHxW
38 | upperleft: NxCxHxWx2
39 | '''
40 | V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
41 | denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
42 | upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
43 | return upperleft, denominator
44 |
45 |
46 | def c2c(x):
47 | return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
48 |
49 |
50 | def r2c(x):
51 | return torch.stack([x, torch.zeros_like(x)], -1)
52 |
53 |
54 | def cdiv(x, y):
55 | a, b = x[..., 0], x[..., 1]
56 | c, d = y[..., 0], y[..., 1]
57 | cd2 = c**2 + d**2
58 | return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
59 |
60 |
61 | def cabs(x):
62 | return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
63 |
64 |
65 | def cmul(t1, t2):
66 | '''
67 | complex multiplication
68 | t1: NxCxHxWx2
69 | output: NxCxHxWx2
70 | '''
71 | real1, imag1 = t1[..., 0], t1[..., 1]
72 | real2, imag2 = t2[..., 0], t2[..., 1]
73 | return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
74 |
75 |
76 | def cconj(t, inplace=False):
77 | '''
78 | # complex's conjugation
79 | t: NxCxHxWx2
80 | output: NxCxHxWx2
81 | '''
82 | c = t.clone() if not inplace else t
83 | c[..., 1] *= -1
84 | return c
85 |
86 |
87 | def rfft(t):
88 | return torch.rfft(t, 2, onesided=False)
89 |
90 |
91 | def irfft(t):
92 | return torch.irfft(t, 2, onesided=False)
93 |
94 |
95 | def fft(t):
96 | return torch.fft(t, 2)
97 |
98 |
99 | def ifft(t):
100 | return torch.ifft(t, 2)
101 |
102 |
103 | def p2o(psf, shape):
104 | '''
105 | # psf: NxCxhxw
106 | # shape: [H,W]
107 | # otf: NxCxHxWx2
108 | '''
109 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
110 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
111 | for axis, axis_size in enumerate(psf.shape[2:]):
112 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
113 | otf = torch.rfft(otf, 2, onesided=False)
114 | n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
115 |     otf[...,1][torch.abs(otf[...,1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
116 |     return otf
471 |     maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
472 | maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
473 | minxy = np.zeros(x.shape)
474 | minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
475 | minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
476 | m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
477 | (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
478 | np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
479 | m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
480 | (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
481 | np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
482 | h = None
483 | return h
484 |
485 |
486 | def fspecial_gaussian(hsize, sigma):
487 | hsize = [hsize, hsize]
488 | siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
489 | std = sigma
490 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
491 | arg = -(x*x + y*y)/(2*std*std)
492 | h = np.exp(arg)
493 |     h[h < np.finfo(float).eps * h.max()] = 0
494 | sumh = h.sum()
495 | if sumh != 0:
496 | h = h/sumh
497 | return h
498 |
499 |
500 | def fspecial_laplacian(alpha):
501 | alpha = max([0, min([alpha,1])])
502 | h1 = alpha/(alpha+1)
503 | h2 = (1-alpha)/(alpha+1)
504 | h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
505 | h = np.array(h)
506 | return h
507 |
508 |
509 | def fspecial_log(hsize, sigma):
510 |     raise NotImplementedError
511 |
512 |
513 | def fspecial_motion(motion_len, theta):
514 |     raise NotImplementedError
515 |
516 |
517 | def fspecial_prewitt():
518 | return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
519 |
520 |
521 | def fspecial_sobel():
522 | return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
523 |
524 |
525 | def fspecial(filter_type, *args, **kwargs):
526 | '''
527 | python code from:
528 | https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
529 | '''
530 | if filter_type == 'average':
531 | return fspecial_average(*args, **kwargs)
532 | if filter_type == 'disk':
533 | return fspecial_disk(*args, **kwargs)
534 | if filter_type == 'gaussian':
535 | return fspecial_gaussian(*args, **kwargs)
536 | if filter_type == 'laplacian':
537 | return fspecial_laplacian(*args, **kwargs)
538 | if filter_type == 'log':
539 | return fspecial_log(*args, **kwargs)
540 | if filter_type == 'motion':
541 | return fspecial_motion(*args, **kwargs)
542 | if filter_type == 'prewitt':
543 | return fspecial_prewitt(*args, **kwargs)
544 | if filter_type == 'sobel':
545 | return fspecial_sobel(*args, **kwargs)
546 |
547 |
548 | def fspecial_gauss(size, sigma):
549 | x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
550 | g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
551 | return g / g.sum()
552 |
553 |
554 | def blurkernel_synthesis(h=37, w=None):
555 | w = h if w is None else w
556 | kdims = [h, w]
557 | x = randomTrajectory(150)
558 | k = None
559 | while k is None:
560 | k = kernelFromTrajectory(x)
561 |
562 | # center pad to kdims
563 | pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
564 | pad_width = [(pad_width[0],), (pad_width[1],)]
565 | if pad_width[0][0]<0 or pad_width[1][0]<0:
566 | k = k[0:h, 0:h]
567 | else:
568 | k = pad(k, pad_width, "constant")
569 | # import matplotlib.pyplot as plt
570 | # plt.imshow(k, interpolation="nearest", cmap="gray")
571 | # plt.show()
572 | #print(k.dtype)
573 | return k
574 |
575 |
576 | def kernelFromTrajectory(x):
577 | h = 5 - log(rand()) / 0.15
578 | h = round(min([h, 27])).astype(int)
579 | h = h + 1 - h % 2
580 | w = h
581 | k = zeros((h, w))
582 |
583 | xmin = min(x[0])
584 | xmax = max(x[0])
585 | ymin = min(x[1])
586 | ymax = max(x[1])
587 | xthr = arange(xmin, xmax, (xmax - xmin) / w)
588 | ythr = arange(ymin, ymax, (ymax - ymin) / h)
589 |
590 | for i in range(1, xthr.size):
591 | for j in range(1, ythr.size):
592 | idx = (
593 | (x[0, :] >= xthr[i - 1])
594 | & (x[0, :] < xthr[i])
595 | & (x[1, :] >= ythr[j - 1])
596 | & (x[1, :] < ythr[j])
597 | )
598 | k[i - 1, j - 1] = sum(idx)
599 | if sum(k) == 0:
600 | return
601 | k = k / sum(k)
602 | k = convolve2d(k, fspecial_gauss(3, 1), "same")
603 | k = k / sum(k)
604 | return k
605 |
606 |
607 | def randomTrajectory(T):
608 | x = zeros((3, T))
609 | v = randn(3, T)
610 | r = zeros((3, T))
611 | trv = 1 / 1
612 | trr = 2 * pi / T
613 | for t in range(1, T):
614 | F_rot = randn(3) / (t + 1) + r[:, t - 1]
615 | F_trans = randn(3) / (t + 1)
616 | r[:, t] = r[:, t - 1] + trr * F_rot
617 | v[:, t] = v[:, t - 1] + trv * F_trans
618 | st = v[:, t]
619 | st = rot3D(st, r[:, t])
620 | x[:, t] = x[:, t - 1] + st
621 | return x
622 |
623 |
624 | def rot3D(x, r):
625 | Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
626 | Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
627 | Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
628 | R = Rz @ Ry @ Rx
629 | x = R @ x
630 | return x
631 |
632 |
633 | if __name__ == '__main__':
634 | # a = opt_fft_size([111])
635 | # print(a)
636 | #
637 | # print(fspecial('gaussian', 5, 1))
638 | #
639 | # print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
640 |
641 | k = blurkernel_synthesis(25)
642 | print(k.shape)
643 | print(sum(k))
644 | import matplotlib.pyplot as plt
645 | plt.imshow(k, interpolation="nearest", cmap="gray")
646 | plt.show()
647 |
648 | # kernel = fspecial('gaussian', 3, 1)
649 | # img = np.random.randn(5,5,1)
650 | # a, b = get_uperleft_denominator(img, kernel)
651 | # print(a)
652 | ## print(b)
653 | #
654 | # a, b = get_uperleft_denominator_pytorch(util.single2tensor4(img), util.single2tensor4(kernel[...,np.newaxis]))
655 | # print(a.squeeze())
656 | # print(b.squeeze())
657 |
658 |
659 |
660 |
661 |
662 |
663 |
664 | # get_uperleft_denominator_pytorch(
665 |
--------------------------------------------------------------------------------
/utils/utils_inpaint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from utils import utils_image as util
4 |
5 | '''
6 | modified by Kai Zhang (github: https://github.com/cszn)
7 | 03/03/2019
8 | '''
9 |
10 |
11 | # --------------------------------
12 | # get rho and sigma
13 | # --------------------------------
14 | def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma2=2.55):
15 | '''
16 | Kai Zhang (github: https://github.com/cszn)
17 | 03/03/2019
18 | '''
19 | modelSigma1 = 49.0
20 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num)
21 | sigmas = modelSigmaS/255.
22 | mus = list(map(lambda x: (sigma**2)/(x**2)/3, sigmas))
23 | rhos = mus
24 | return rhos, sigmas
25 |
26 |
27 | def shepard_initialize(image, measurement_mask, window=5, p=2):
28 | wing = np.floor(window/2).astype(int) # Length of each "wing" of the window.
29 | h, w = image.shape[0:2]
30 | ch = 3 if image.ndim == 3 and image.shape[-1] == 3 else 1
31 | x = np.copy(image) # ML initialization
32 | for i in range(h):
33 | i_lower_limit = -np.min([wing, i])
34 | i_upper_limit = np.min([wing, h-i-1])
35 | for j in range(w):
36 | if measurement_mask[i, j] == 0: # checking if there's a need to interpolate
37 | j_lower_limit = -np.min([wing, j])
38 | j_upper_limit = np.min([wing, w-j-1])
39 |
40 |                 count = 0  # keeps track of how many measured pixels are within the neighborhood
41 | sum_IPD = 0
42 | interpolated_value = 0
43 |
44 | num_zeros = window**2
45 | IPD = np.zeros([num_zeros])
46 | pixel = np.zeros([num_zeros,ch])
47 |
48 | for neighborhood_i in range(i+i_lower_limit, i+i_upper_limit):
49 | for neighborhood_j in range(j+j_lower_limit, j+j_upper_limit):
50 | if measurement_mask[neighborhood_i, neighborhood_j] == 1:
51 | # IPD: "inverse pth-power distance".
52 | IPD[count] = 1.0/((neighborhood_i - i)**p + (neighborhood_j - j)**p)
53 | sum_IPD = sum_IPD + IPD[count]
54 | pixel[count] = image[neighborhood_i, neighborhood_j]
55 | count = count + 1
56 |
57 | for c in range(count):
58 | weight = IPD[c]/sum_IPD
59 | interpolated_value = interpolated_value + weight*pixel[c]
60 | x[i, j] = interpolated_value
61 |
62 | return x
63 |
64 |
65 | if __name__ == '__main__':
66 | # image path & sampling ratio
67 | import matplotlib.pyplot as mplot
68 | import matplotlib.image as mpimg
69 | Im = mpimg.imread('test.bmp')
70 | #Im = Im[:,:,1]
71 | Im = np.squeeze(Im)
72 |
73 | SmpRatio = 0.2
74 |     # create mask
75 | mask_Array = np.random.rand(Im.shape[0],Im.shape[1])
76 | mask_Array = (mask_Array < SmpRatio)
77 | print(mask_Array.dtype)
78 |
79 | # sampled image
80 | print('The sampling ratio is', SmpRatio)
81 | Im_sampled = np.multiply(np.expand_dims(mask_Array,2), Im)
82 | util.imshow(Im_sampled)
83 |
84 | a = shepard_initialize(Im_sampled.astype(np.float32), mask_Array, window=9)
85 | a = np.clip(a,0,255)
86 |
87 |
88 | print(a.dtype)
89 |
90 |
91 | util.imshow(np.concatenate((a,Im_sampled),1)/255.0)
92 | util.imsave(np.concatenate((a,Im_sampled),1),'a.png')
93 |
94 |
95 |
96 |
97 |
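98 |     # Added check for illustration (not part of the original file): measured
99 |     # pixels are reproduced exactly; Shepard interpolation only fills holes.
100 |     print(np.array_equal(a[mask_Array], Im_sampled.astype(np.float32)[mask_Array]))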
--------------------------------------------------------------------------------
/utils/utils_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import datetime
4 | import logging
5 |
6 |
7 | '''
8 | modified by Kai Zhang (github: https://github.com/cszn)
9 | 03/03/2019
10 | https://github.com/xinntao/BasicSR
11 | '''
12 |
13 |
14 | def log(*args, **kwargs):
15 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
16 |
17 |
18 | '''
19 | # ===============================
20 | # logger
21 | # logger (note: logger_name=None selects the root logger)
22 | # ===============================
23 | '''
24 |
25 |
26 | def logger_info(logger_name, log_path='default_logger.log'):
27 | ''' set up logger
28 | modified by Kai Zhang (github: https://github.com/cszn)
29 | '''
30 | log = logging.getLogger(logger_name)
31 | if log.hasHandlers():
32 | print('LogHandlers exists!')
33 | else:
34 | print('LogHandlers setup!')
35 | level = logging.INFO
36 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
37 | fh = logging.FileHandler(log_path, mode='a')
38 | fh.setFormatter(formatter)
39 | log.setLevel(level)
40 | log.addHandler(fh)
41 | # print(len(log.handlers))
42 |
43 | sh = logging.StreamHandler()
44 | sh.setFormatter(formatter)
45 | log.addHandler(sh)
46 |
47 |
48 | '''
49 | # ===============================
50 | # print to file and std_out simultaneously
51 | # ===============================
52 | '''
53 |
54 |
55 | class logger_print(object):
56 | def __init__(self, log_path="default.log"):
57 | self.terminal = sys.stdout
58 | self.log = open(log_path, 'a')
59 |
60 | def write(self, message):
61 | self.terminal.write(message)
62 | self.log.write(message) # write the message
63 |
64 | def flush(self):
65 | pass
66 |
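67 |
68 | if __name__ == '__main__':
69 |     # A minimal usage sketch added for illustration (not part of the original file).
70 |     logger_info('base', log_path='default_logger.log')
71 |     logger = logging.getLogger('base')
72 |     logger.info('this line goes to default_logger.log and to stdout')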
--------------------------------------------------------------------------------
/utils/utils_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import torch
4 | from utils import utils_image as util
5 |
6 |
7 | '''
8 | modified by Kai Zhang (github: https://github.com/cszn)
9 | 03/03/2019
10 | '''
11 |
12 |
13 | def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
14 | '''
15 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
16 | # Some testing modes
17 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
18 | # (0) normal: test(model, L)
19 | # (1) pad: test_pad(model, L, modulo=16)
20 | # (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
21 | # (3) x8: test_x8(model, L, modulo=1)
22 | # (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
23 | # (5) split only once: test_onesplit(model, L, refield=32, min_size=256, sf=1, modulo=1)
24 | # ---------------------------------------
25 | '''
26 | if mode == 0:
27 | E = test(model, L)
28 | elif mode == 1:
29 | E = test_pad(model, L, modulo)
30 | elif mode == 2:
31 | E = test_split(model, L, refield, min_size, sf, modulo)
32 | elif mode == 3:
33 | E = test_x8(model, L, modulo)
34 | elif mode == 4:
35 | E = test_split_x8(model, L, refield, min_size, sf, modulo)
36 | elif mode == 5:
37 | E = test_onesplit(model, L, refield, min_size, sf, modulo)
38 | return E
39 |
40 |
41 | '''
42 | # ---------------------------------------
43 | # normal (0)
44 | # ---------------------------------------
45 | '''
46 |
47 |
48 | def test(model, L):
49 | E = model(L)
50 | return E
51 |
52 |
53 | '''
54 | # ---------------------------------------
55 | # pad (1)
56 | # ---------------------------------------
57 | '''
58 |
59 |
60 | def test_pad(model, L, modulo=16):
61 | h, w = L.size()[-2:]
62 | paddingBottom = int(np.ceil(h/modulo)*modulo-h)
63 | paddingRight = int(np.ceil(w/modulo)*modulo-w)
64 | L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
65 | E = model(L)
66 | E = E[..., :h, :w]
67 | return E
68 |
69 |
70 | '''
71 | # ---------------------------------------
72 | # split (function)
73 | # ---------------------------------------
74 | '''
75 |
76 |
77 | def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
78 | '''
79 | model:
80 | L: input Low-quality image
81 |     refield: effective receptive field of the network; 32 is enough
82 |     min_size: images up to min_size x min_size are processed in one pass, e.g., 256x256
83 |     sf: scale factor for super-resolution, otherwise 1
84 |     modulo: pad H and W to multiples of modulo (use 1 when splitting)
85 | '''
86 | h, w = L.size()[-2:]
87 | if h*w <= min_size**2:
88 | L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
89 | E = model(L)
90 | E = E[..., :h*sf, :w*sf]
91 | else:
92 | top = slice(0, (h//2//refield+1)*refield)
93 | bottom = slice(h - (h//2//refield+1)*refield, h)
94 | left = slice(0, (w//2//refield+1)*refield)
95 | right = slice(w - (w//2//refield+1)*refield, w)
96 | Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
97 |
98 | if h * w <= 4*(min_size**2):
99 | Es = [model(Ls[i]) for i in range(4)]
100 | else:
101 | Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
102 |
103 | b, c = Es[0].size()[:2]
104 | E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
105 |
106 | E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
107 | E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
108 | E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
109 | E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
110 | return E
111 |
112 |
113 |
114 | def test_onesplit(model, L, refield=32, min_size=256, sf=1, modulo=1):
115 | '''
116 | model:
117 | L: input Low-quality image
118 |     refield: effective receptive field of the network; 32 is enough
119 |     min_size: images up to min_size x min_size are processed in one pass, e.g., 256x256
120 |     sf: scale factor for super-resolution, otherwise 1
121 |     modulo: pad H and W to multiples of modulo (use 1 when splitting)
122 | '''
123 | h, w = L.size()[-2:]
124 |
125 | top = slice(0, (h//2//refield+1)*refield)
126 | bottom = slice(h - (h//2//refield+1)*refield, h)
127 | left = slice(0, (w//2//refield+1)*refield)
128 | right = slice(w - (w//2//refield+1)*refield, w)
129 | Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
130 | Es = [model(Ls[i]) for i in range(4)]
131 | b, c = Es[0].size()[:2]
132 | E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
133 | E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
134 | E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
135 | E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
136 | E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
137 | return E
138 |
139 |
140 |
141 | '''
142 | # ---------------------------------------
143 | # split (2)
144 | # ---------------------------------------
145 | '''
146 |
147 |
148 | def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
149 | E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
150 | return E
151 |
152 |
153 | '''
154 | # ---------------------------------------
155 | # x8 (3)
156 | # ---------------------------------------
157 | '''
158 |
159 |
160 | def test_x8(model, L, modulo=1):
161 | E_list = [test_pad(model, util.augment_img_tensor(L, mode=i), modulo=modulo) for i in range(8)]
162 | for i in range(len(E_list)):
163 | if i == 3 or i == 5:
164 | E_list[i] = util.augment_img_tensor(E_list[i], mode=8 - i)
165 | else:
166 | E_list[i] = util.augment_img_tensor(E_list[i], mode=i)
167 | output_cat = torch.stack(E_list, dim=0)
168 | E = output_cat.mean(dim=0, keepdim=False)
169 | return E
170 |
171 |
172 | '''
173 | # ---------------------------------------
174 | # split and x8 (4)
175 | # ---------------------------------------
176 | '''
177 |
178 |
179 | def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
180 | E_list = [test_split_fn(model, util.augment_img_tensor(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
181 |     for i in range(len(E_list)):
182 |         if i == 3 or i == 5:
183 |             E_list[i] = util.augment_img_tensor(E_list[i], mode=8-i)
184 |         else:
185 |             E_list[i] = util.augment_img_tensor(E_list[i], mode=i)
186 | output_cat = torch.stack(E_list, dim=0)
187 | E = output_cat.mean(dim=0, keepdim=False)
188 | return E
189 |
190 |
191 | '''
192 | # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
193 | # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_
194 | # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
195 | '''
196 |
197 |
198 | '''
199 | # ---------------------------------------
200 | # print
201 | # ---------------------------------------
202 | '''
203 |
204 |
205 | # -------------------
206 | # print model
207 | # -------------------
208 | def print_model(model):
209 | msg = describe_model(model)
210 | print(msg)
211 |
212 |
213 | # -------------------
214 | # print params
215 | # -------------------
216 | def print_params(model):
217 | msg = describe_params(model)
218 | print(msg)
219 |
220 |
221 | '''
222 | # ---------------------------------------
223 | # information
224 | # ---------------------------------------
225 | '''
226 |
227 |
228 | # -------------------
229 | # model information
230 | # -------------------
231 | def info_model(model):
232 | msg = describe_model(model)
233 | return msg
234 |
235 |
236 | # -------------------
237 | # params information
238 | # -------------------
239 | def info_params(model):
240 | msg = describe_params(model)
241 | return msg
242 |
243 |
244 | '''
245 | # ---------------------------------------
246 | # description
247 | # ---------------------------------------
248 | '''
249 |
250 |
251 | # ----------------------------------------------
252 | # model name and total number of parameters
253 | # ----------------------------------------------
254 | def describe_model(model):
255 | if isinstance(model, torch.nn.DataParallel):
256 | model = model.module
257 | msg = '\n'
258 |     msg += 'model name: {}'.format(model.__class__.__name__) + '\n'
259 | msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
260 | msg += 'Net structure:\n{}'.format(str(model)) + '\n'
261 | return msg
262 |
263 |
264 | # ----------------------------------------------
265 | # parameters description
266 | # ----------------------------------------------
267 | def describe_params(model):
268 | if isinstance(model, torch.nn.DataParallel):
269 | model = model.module
270 | msg = '\n'
271 | msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'param_name') + '\n'
272 | for name, param in model.state_dict().items():
273 | if not 'num_batches_tracked' in name:
274 | v = param.data.clone().float()
275 | msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), name) + '\n'
276 | return msg
277 |
278 |
279 | if __name__ == '__main__':
280 |
281 | class Net(torch.nn.Module):
282 | def __init__(self, in_channels=3, out_channels=3):
283 | super(Net, self).__init__()
284 | self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
285 |
286 | def forward(self, x):
287 | x = self.conv(x)
288 | return x
289 |
290 | start = torch.cuda.Event(enable_timing=True)
291 | end = torch.cuda.Event(enable_timing=True)
292 |
293 | model = Net()
294 | model = model.eval()
295 | print_model(model)
296 | print_params(model)
297 | x = torch.randn((2,3,400,400))
298 | torch.cuda.empty_cache()
299 | with torch.no_grad():
300 | for mode in range(5):
301 | y = test_mode(model, x, mode)
302 | print(y.shape)
303 |
--------------------------------------------------------------------------------
/utils/utils_mosaic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from utils import utils_image as util
4 | #import utils_image as util
5 | import cv2
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | '''
11 | modified by Kai Zhang (github: https://github.com/cszn)
12 | 03/03/2019
13 | '''
14 | def dm(imgs):
15 | """ bilinear demosaicking
16 | Args:
17 | imgs: Nx4xW/2xH/2
18 |
19 | Returns:
20 | output: Nx3xWxH
21 | """
22 | k_r = 1/4 * torch.FloatTensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]]).type_as(imgs)
23 | k_g = 1/4 * torch.FloatTensor([[0, 1, 0], [1, 4, 1], [0, 1, 0]]).type_as(imgs)
24 | k = torch.stack((k_r,k_g,k_r), dim=0).unsqueeze(1)
25 |
26 | rgb = torch.zeros(imgs.size(0), 3, imgs.size(2)*2, imgs.size(3)*2).type_as(imgs)
27 | rgb[:, 0, 0::2, 0::2] = imgs[:, 0, :, :]
28 | rgb[:, 1, 0::2, 1::2] = imgs[:, 1, :, :]
29 | rgb[:, 1, 1::2, 0::2] = imgs[:, 2, :, :]
30 | rgb[:, 2, 1::2, 1::2] = imgs[:, 3, :, :]
31 |
32 | rgb = nn.functional.pad(rgb, (1, 1, 1, 1), mode='circular')
33 | rgb = nn.functional.conv2d(rgb, k, groups=3, padding=0, bias=None)
34 |
35 | return rgb
36 |
37 |
38 | def dm_matlab(imgs):
39 | """ matlab demosaicking
40 | Args:
41 | imgs: Nx4xW/2xH/2
42 |
43 | Returns:
44 | output: Nx3xWxH
45 | """
46 |
47 | kgrb = 1/8*torch.FloatTensor([[0, 0, -1, 0, 0],
48 | [0, 0, 2, 0, 0],
49 | [-1, 2, 4, 2, -1],
50 | [0, 0, 2, 0, 0],
51 | [0, 0, -1, 0, 0]]).type_as(imgs)
52 | krbg0 = 1/8*torch.FloatTensor([[0, 0, 1/2, 0, 0],
53 | [0, -1, 0, -1, 0],
54 | [-1, 4, 5, 4, -1],
55 | [0, -1, 0, -1, 0],
56 | [0, 0, 1/2, 0, 0]]).type_as(imgs)
57 | krbg1 = krbg0.t()
58 | krbbr = 1/8*torch.FloatTensor([[0, 0, -3/2, 0, 0],
59 | [0, 2, 0, 2, 0],
60 | [-3/2, 0, 6, 0, -3/2],
61 | [0, 2, 0, 2, 0],
62 | [0, 0, -3/2, 0, 0]]).type_as(imgs)
63 |
64 | k = torch.stack((kgrb, krbg0, krbg1, krbbr), 0).unsqueeze(1)
65 |
66 | cfa = torch.zeros(imgs.size(0), 1, imgs.size(2)*2, imgs.size(3)*2).type_as(imgs)
67 | cfa[:, 0, 0::2, 0::2] = imgs[:, 0, :, :]
68 | cfa[:, 0, 0::2, 1::2] = imgs[:, 1, :, :]
69 | cfa[:, 0, 1::2, 0::2] = imgs[:, 2, :, :]
70 | cfa[:, 0, 1::2, 1::2] = imgs[:, 3, :, :]
71 | rgb = cfa.repeat(1, 3, 1, 1)
72 |
73 | cfa = nn.functional.pad(cfa, (2, 2, 2, 2), mode='reflect')
74 | conv_cfa = nn.functional.conv2d(cfa, k, padding=0, bias=None)
75 |
76 | # fill G
77 | rgb[:, 1, 0::2, 0::2] = conv_cfa[:, 0, 0::2, 0::2]
78 | rgb[:, 1, 1::2, 1::2] = conv_cfa[:, 0, 1::2, 1::2]
79 |
80 | # fill R
81 | rgb[:, 0, 0::2, 1::2] = conv_cfa[:, 1, 0::2, 1::2]
82 | rgb[:, 0, 1::2, 0::2] = conv_cfa[:, 2, 1::2, 0::2]
83 | rgb[:, 0, 1::2, 1::2] = conv_cfa[:, 3, 1::2, 1::2]
84 |
85 | # fill B
86 | rgb[:, 2, 0::2, 1::2] = conv_cfa[:, 2, 0::2, 1::2]
87 | rgb[:, 2, 1::2, 0::2] = conv_cfa[:, 1, 1::2, 0::2]
88 | rgb[:, 2, 0::2, 0::2] = conv_cfa[:, 3, 0::2, 0::2]
89 |
90 | return rgb
91 |
92 |
93 | def tstack(a): # cv2.merge()
94 | a = np.asarray(a)
95 | return np.concatenate([x[..., np.newaxis] for x in a], axis=-1)
96 |
97 |
98 | def tsplit(a): # cv2.split()
99 | a = np.asarray(a)
100 | return np.array([a[..., x] for x in range(a.shape[-1])])
101 |
102 |
103 | def masks_CFA_Bayer(shape):
104 | pattern = 'RGGB'
105 | channels = dict((channel, np.zeros(shape)) for channel in 'RGB')
106 | for channel, (y, x) in zip(pattern, [(0, 0), (0, 1), (1, 0), (1, 1)]):
107 | channels[channel][y::2, x::2] = 1
108 | return tuple(channels[c].astype(bool) for c in 'RGB')
109 |
110 |
111 | def mosaic_CFA_Bayer(RGB):
112 | R_m, G_m, B_m = masks_CFA_Bayer(RGB.shape[0:2])
113 | mask = np.concatenate((R_m[..., np.newaxis], G_m[..., np.newaxis], B_m[..., np.newaxis]), axis=-1)
114 | # mask = tstack((R_m, G_m, B_m))
115 | mosaic = np.multiply(mask, RGB) # mask*RGB
116 | CFA = mosaic.sum(2).astype(np.uint8)
117 |
118 | CFA4 = np.zeros((RGB.shape[0]//2, RGB.shape[1]//2, 4), dtype=np.uint8)
119 | CFA4[:, :, 0] = CFA[0::2, 0::2]
120 | CFA4[:, :, 1] = CFA[0::2, 1::2]
121 | CFA4[:, :, 2] = CFA[1::2, 0::2]
122 | CFA4[:, :, 3] = CFA[1::2, 1::2]
123 |
124 | return CFA, CFA4, mosaic, mask
125 |
126 |
127 | if __name__ == '__main__':
128 |
129 | Im = util.imread_uint('test.bmp', 3)
130 |
131 | CFA, CFA4, mosaic, mask = mosaic_CFA_Bayer(Im)
132 | convertedImage = cv2.cvtColor(CFA, cv2.COLOR_BAYER_BG2RGB_EA)
133 |
134 | util.imshow(CFA)
135 | util.imshow(mosaic)
136 | util.imshow(mask.astype(np.float32))
137 | util.imshow(convertedImage)
138 |
139 | util.imsave(mask.astype(np.float32)*255,'bayer.png')
140 |
141 |
142 |
143 |
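144 |     # Added sketch for illustration (not part of the original file): demosaic
145 |     # the packed 4-channel CFA with dm_matlab and display the result.
146 |     CFA4_t = torch.from_numpy(CFA4.astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.
147 |     out = dm_matlab(CFA4_t)
148 |     util.imshow(out.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy())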
--------------------------------------------------------------------------------
/utils/utils_pnp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 |
5 | '''
6 | modified by Kai Zhang (github: https://github.com/cszn)
7 | 03/03/2019
8 | '''
9 |
10 |
11 | # --------------------------------
12 | # get rho and sigma
13 | # --------------------------------
14 | #def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55):
15 | # '''
16 | # One can change the sigma to implicitly change the trade-off parameter
17 | # between fidelity term and prior term
18 | # '''
19 | # modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32)
20 | # sigmas = modelSigmaS/255.
21 | # rhos = list(map(lambda x: 0.23*(sigma**2)/(x**2), sigmas))
22 | # return rhos, sigmas
23 |
24 | # --------------------------------
25 | # get rho and sigma
26 | # --------------------------------
27 | def get_rho_sigma(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, w=1.0):
28 | '''
29 | One can change the sigma to implicitly change the trade-off parameter
30 | between fidelity term and prior term
31 | '''
32 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32)
33 | modelSigmaS_lin = np.linspace(modelSigma1, modelSigma2, iter_num).astype(np.float32)
34 | sigmas = (modelSigmaS*w+modelSigmaS_lin*(1-w))/255.
35 | rhos = list(map(lambda x: 0.23*(sigma**2)/(x**2), sigmas))
36 | return rhos, sigmas
37 |
38 |
39 | def get_rho_sigma1(sigma=2.55/255, iter_num=15, modelSigma1=49.0, modelSigma2=2.55, lamda=3.0):
40 | '''
41 | One can change the sigma to implicitly change the trade-off parameter
42 | between fidelity term and prior term
43 | '''
44 | modelSigmaS = np.logspace(np.log10(modelSigma1), np.log10(modelSigma2), iter_num).astype(np.float32)
45 | sigmas = modelSigmaS/255.
46 | rhos = list(map(lambda x: (sigma**2)/(x**2)/lamda, sigmas))
47 | return rhos, sigmas
48 |
49 |
50 | if __name__ == '__main__':
51 | rhos, sigmas = get_rho_sigma(sigma=2.55/255, iter_num=30, modelSigma2=2.55)
52 | print(rhos)
53 | print(sigmas*255)
54 |
--------------------------------------------------------------------------------
/utils/utils_sisr.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch.fft
3 | import torch
4 |
5 |
6 | def splits(a, sf):
7 | '''split a into sfxsf distinct blocks
8 | Args:
9 | a: NxCxWxH
10 | sf: split factor
11 | Returns:
12 | b: NxCx(W/sf)x(H/sf)x(sf^2)
13 | '''
14 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=4)
15 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=4)
16 | return b
17 |
18 |
19 | def p2o(psf, shape):
20 | '''
21 | Convert point-spread function to optical transfer function.
22 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
23 | point-spread function (PSF) array and creates the optical transfer
24 | function (OTF) array that is not influenced by the PSF off-centering.
25 | Args:
26 | psf: NxCxhxw
27 | shape: [H, W]
28 | Returns:
29 | otf: NxCxHxWx2
30 | '''
31 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
32 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
33 | for axis, axis_size in enumerate(psf.shape[2:]):
34 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
35 | otf = torch.fft.fftn(otf, dim=(-2,-1))
36 | #n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
37 | #otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
38 | return otf
39 |
40 |
41 | def upsample(x, sf=3):
42 | '''s-fold upsampler
43 | Upsampling the spatial size by filling the new entries with zeros
44 | x: tensor image, NxCxWxH
45 | '''
46 | st = 0
47 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
48 | z[..., st::sf, st::sf].copy_(x)
49 | return z
50 |
51 |
52 | def downsample(x, sf=3):
53 | '''s-fold downsampler
54 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others
55 | x: tensor image, NxCxWxH
56 | '''
57 | st = 0
58 | return x[..., st::sf, st::sf]
59 |
60 |
61 |
62 | def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf):
63 | FR = FBFy + torch.fft.fftn(alpha*x, dim=(-2,-1))
64 | x1 = FB.mul(FR)
65 | FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
66 | invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
67 | invWBR = FBR.div(invW + alpha)
68 | FCBinvWBR = FBC*invWBR.repeat(1, 1, sf, sf)
69 | FX = (FR-FCBinvWBR)/alpha
70 | Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1)))
71 |
72 | return Xest
73 |
74 |
75 | def pre_calculate(x, k, sf):
76 | '''
77 | Args:
78 | x: NxCxHxW, LR input
79 | k: NxCxhxw
80 | sf: integer
81 |
82 | Returns:
83 | FB, FBC, F2B, FBFy
84 | will be reused during iterations
85 | '''
86 |     h, w = x.shape[-2:]
87 |     FB = p2o(k, (h*sf, w*sf))
88 | FBC = torch.conj(FB)
89 | F2B = torch.pow(torch.abs(FB), 2)
90 | STy = upsample(x, sf=sf)
91 | FBFy = FBC*torch.fft.fftn(STy, dim=(-2, -1))
92 | return FB, FBC, F2B, FBFy
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
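109 | if __name__ == '__main__':
110 |     # Minimal sanity check -- an illustrative sketch, not part of the upstream
111 |     # file. As alpha grows, the proximity term dominates the data subproblem,
112 |     # so data_solution() should return (approximately) its input estimate x.
113 |     torch.manual_seed(0)
114 |     sf = 2
115 |     x = torch.rand(1, 3, 32, 32)      # current HR estimate
116 |     k = torch.ones(1, 1, 3, 3) / 9.   # arbitrary 3x3 box-blur test kernel
117 |     y = downsample(x, sf=sf)          # simulated LR observation, 1x3x16x16
118 |     FB, FBC, F2B, FBFy = pre_calculate(y, k, sf)
119 |     x_est = data_solution(x, FB, FBC, F2B, FBFy, alpha=1e6, sf=sf)
120 |     print((x_est - x).abs().max())    # ~0: the data term is negligible here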
--------------------------------------------------------------------------------
/utils/utils_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import random
4 | import numpy as np
5 | import torch, torch.fft  # the torch.fft module is needed by p2o() below (PyTorch >= 1.7)
6 | import cv2
7 | import utils_image as util
8 |
9 | '''
10 | # =======================================
11 | # image processing on numpy images
12 | # augment_imgs(img_list, hflip=True, rot=True)
13 | # =======================================
14 | '''
15 | # ----------------------------------------
16 | # get uint8 image of size HxWxn_channels (RGB)
17 | # ----------------------------------------
18 | def imread_uint(path, n_channels=3):
19 | # input: path
20 | # output: HxWx3(RGB or GGG), or HxWx1 (G)
21 | if n_channels == 1:
22 | img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
23 | img = np.expand_dims(img, axis=2) # HxWx1
24 | elif n_channels == 3:
25 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
26 | if img.ndim == 2:
27 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
28 | else:
29 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
30 | return img
31 |
32 |
33 | def augment_img(img, mode=0):  # 8 modes: the flip/rotation symmetries of the square
34 | if mode == 0:
35 | return img
36 | elif mode == 1:
37 | return np.flipud(np.rot90(img))
38 | elif mode == 2:
39 | return np.flipud(img)
40 | elif mode == 3:
41 | return np.rot90(img, k=3)
42 | elif mode == 4:
43 | return np.flipud(np.rot90(img, k=2))
44 | elif mode == 5:
45 | return np.rot90(img)
46 | elif mode == 6:
47 | return np.rot90(img, k=2)
48 | elif mode == 7:
49 | return np.flipud(np.rot90(img, k=3))
50 |
51 |
52 | def augment_img_tensor4(img, mode=0):  # torch equivalents of augment_img's 8 modes for NxCxHxW tensors (the numpy calls here were a bug)
53 |     if mode == 0:
54 |         return img
55 |     elif mode == 1:
56 |         return img.rot90(1, [2, 3]).flip([2])
57 |     elif mode == 2:
58 |         return img.flip([2])
59 |     elif mode == 3:
60 |         return img.rot90(3, [2, 3])
61 |     elif mode == 4:
62 |         return img.rot90(2, [2, 3]).flip([2])
63 |     elif mode == 5:
64 |         return img.rot90(1, [2, 3])
65 |     elif mode == 6:
66 |         return img.rot90(2, [2, 3])
67 |     elif mode == 7:
68 |         return img.rot90(3, [2, 3]).flip([2])
69 |
70 |
71 | def augment_img_np3(img, mode=0):
72 | if mode == 0:
73 | return img
74 | elif mode == 1:
75 | return img.transpose(1, 0, 2)
76 | elif mode == 2:
77 | return img[::-1, :, :]
78 | elif mode == 3:
79 | img = img[::-1, :, :]
80 | img = img.transpose(1, 0, 2)
81 | return img
82 | elif mode == 4:
83 | return img[:, ::-1, :]
84 | elif mode == 5:
85 | img = img[:, ::-1, :]
86 | img = img.transpose(1, 0, 2)
87 | return img
88 | elif mode == 6:
89 | img = img[:, ::-1, :]
90 | img = img[::-1, :, :]
91 | return img
92 | elif mode == 7:
93 | img = img[:, ::-1, :]
94 | img = img[::-1, :, :]
95 | img = img.transpose(1, 0, 2)
96 | return img
97 |
98 |
99 | def augment_img_tensor(img, mode=0):
100 | img_size = img.size()
101 | img_np = img.data.cpu().numpy()
102 | if len(img_size) == 3:
103 | img_np = np.transpose(img_np, (1, 2, 0))
104 | elif len(img_size) == 4:
105 | img_np = np.transpose(img_np, (2, 3, 1, 0))
106 | img_np = augment_img(img_np, mode=mode)
107 | img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
108 | if len(img_size) == 3:
109 | img_tensor = img_tensor.permute(2, 0, 1)
110 | elif len(img_size) == 4:
111 | img_tensor = img_tensor.permute(3, 2, 0, 1)
112 |
113 | return img_tensor.type_as(img)
114 |
115 |
116 | def augment_imgs(img_list, hflip=True, rot=True):
117 |     # horizontal flip, vertical flip and transpose, each applied with probability 0.5
118 | hflip = hflip and random.random() < 0.5
119 | vflip = rot and random.random() < 0.5
120 | rot90 = rot and random.random() < 0.5
121 |
122 | def _augment(img):
123 | if hflip:
124 | img = img[:, ::-1, :]
125 | if vflip:
126 | img = img[::-1, :, :]
127 | if rot90:
128 | img = img.transpose(1, 0, 2)
129 | return img
130 |
131 | return [_augment(img) for img in img_list]
132 |
133 |
134 |
135 | def p2o(psf, shape):  # PSF-to-OTF conversion, duplicating p2o() in utils_sisr.py
136 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
137 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
138 | for axis, axis_size in enumerate(psf.shape[2:]):
139 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
140 |     otf = torch.fft.fftn(otf, dim=(-2, -1))  # torch.rfft(otf, 2, onesided=False) was removed in PyTorch 1.8
141 |     # the old numerical-zeroing step indexed the real/imag layout of torch.rfft
142 |     # and is kept disabled here, as in utils_sisr.py:
143 |     #n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
144 |     #otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
145 |     return otf
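146 |
147 |
148 | if __name__ == '__main__':
149 |     # Minimal sanity check -- an illustrative sketch, not part of the upstream
150 |     # file. The 8 modes of augment_img() are the flip/rotation symmetries of the
151 |     # square; geometric self-ensemble testing augments the input, runs the model,
152 |     # then maps each output back with the inverse mode before averaging.
153 |     inv_mode = {0: 0, 1: 1, 2: 2, 3: 5, 4: 4, 5: 3, 6: 6, 7: 7}  # only 3 and 5 are not self-inverse
154 |     img = np.random.rand(8, 8, 3).astype(np.float32)
155 |     for mode in range(8):
156 |         back = augment_img(augment_img(img, mode=mode), mode=inv_mode[mode])
157 |         assert np.array_equal(back, img)
158 |     print('augment_img round-trip OK')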