├── .gitignore
├── LICENSE
├── README.md
├── benchmark.py
├── models
│   ├── __init__.py
│   ├── mge_pretrained.ckp
│   ├── net_mge.py
│   ├── net_torch.py
│   └── torch_pretrained.ckp
├── requirements.txt
├── run_benchmark.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*~

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Practical Mobile Raw Image Denoising (PMRID)

Code and dataset for the ECCV20 paper [Practical Deep Raw Image Denoising on Mobile Devices](https://arxiv.org/abs/2010.06935).

## Dataset

### Downloads
- [OneDrive](https://megvii-my.sharepoint.cn/:f:/g/personal/wangyuzhi_megvii_com/Et4v2Z7CkRxHnbcFUq6RXZMBfXUrlm_Se5OVDvcdujVsMA?e=vcfJWs)
- [Kaggle](https://www.kaggle.com/dataset/1bdc5cd707cfbb3ee842eb3cbfe93495dbba88017d29f295f8edbcb8f8790556)

### Usage

The dataset consists of two 7zip archives:
- `reno10x_noise.7z` contains DNG raw images shot with an _OPPO Reno 10x_ phone for noise parameter estimation (see Secs. 3.1 and 5.1 of the paper)
- `PMRID.7z` is the benchmark dataset described in Sec. 5.2 of the paper

The structure of `PMRID.7z` is:
```
- benchmark.json          # meta info
- Scene1/
  - Bright/
    - exposure-case1/
      - input.raw         # RAW data of the noisy image, uint16
      - gt.raw            # RAW data of the clean image, uint16
    + case2/              # further exposure cases
  + Dark/
+ Scene2/                 # further scenes
```

The metadata for all images are listed in `benchmark.json`, a JSON list with one record per case:
```python
{
    "input": "path/to/noisy_input.raw",
    "gt": "path/to/clean_gt.raw",
    "meta": {
        "name": "case_name",
        "scene_id": "scene_name",
        "light": "light condition",
        "ISO": "ISO",
        "exp_time": "exposure time",
        "bayer_pattern": "BGGR",
        "shape": [3000, 4000],
        "wb_gain": [r_gain, g_gain, b_gain],
        "CCM": [  # 3x3 color correction matrix
            [c11, c12, c13],
            [c21, c22, c23],
            [c31, c32, c33]
        ],
        "ROIs": [  # patch ROIs used to compute PSNR and SSIM; (x0, y0) is the top-left corner
            [topleft_w, topleft_h, bottomright_w, bottomright_h]
        ]
    }
}
```
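
Each `input.raw`/`gt.raw` file is a headerless dump of uint16 values. A minimal sketch of loading one case with plain NumPy, mirroring what `benchmark.py` does (the dataset root and the index `0` below are placeholders):
```python
import json
from pathlib import Path

import numpy as np

root = Path("/path/to/PMRID")  # placeholder: wherever PMRID.7z was extracted
cases = json.loads((root / "benchmark.json").read_text())

case = cases[0]                # one record as described above
H, W = case["meta"]["shape"]   # e.g. [3000, 4000]

# raw files are plain uint16 dumps; normalize to floats in [0, 1]
noisy = np.fromfile(root / case["input"], dtype=np.uint16).reshape(H, W).astype(np.float32) / 65535
clean = np.fromfile(root / case["gt"], dtype=np.uint16).reshape(H, W).astype(np.float32) / 65535
```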

## Pre-trained Models and Benchmark Script

Both [PyTorch](https://pytorch.org/) and [MegEngine](https://megengine.org.cn/) pre-trained models are provided in the `models` directory.
The benchmark script is written for models trained with MegEngine, and `Python >= 3.6` is required to run it:

```
pip install -r requirements.txt
python3 run_benchmark.py --benchmark /path/to/PMRID/benchmark.json models/mge_pretrained.ckp
```


## Citation
```
@inproceedings{wang2020,
    title={Practical Deep Raw Image Denoising on Mobile Devices},
    author={Wang, Yuzhi and Huang, Haibin and Xu, Qin and Liu, Jiaming and Liu, Yiqun and Wang, Jue},
    booktitle={European Conference on Computer Vision (ECCV)},
    year={2020},
    pages={1--16}
}
```

--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np


def read_array(path: Path) -> np.ndarray:
    return np.fromfile(str(path), dtype=np.uint16)


@dataclass
class RawMeta:
    name: str
    scene_id: str
    light: str
    ISO: int
    exp_time: float
    bayer_pattern: str
    shape: Tuple[int, int]
    wb_gain: Tuple[float, float, float]
    CCM: Tuple[
        Tuple[float, float, float],
        Tuple[float, float, float],
        Tuple[float, float, float],
    ]
    ROIs: List[Tuple[int, int, int, int]]


class BenchmarkLoader:

    def __init__(self, dataset_info_json: Path, base_path=None):
        with dataset_info_json.open() as f:
            self._dataset = [
                {
                    'input': x['input'],
                    'gt': x['gt'],
                    'meta': RawMeta(**x['meta'])
                }
                for x in json.load(f)
            ]
        if base_path is None:
            self.base_path = dataset_info_json.parent
        else:
            self.base_path = Path(base_path)

    def __len__(self):
        return len(self._dataset)

    def __iter__(self):
        self._idx = 0
        return self

    def __next__(self) -> Tuple[np.ndarray, np.ndarray, RawMeta]:
        if self._idx >= len(self):
            raise StopIteration

        input_bayer, gt_bayer, meta = self._load_idx(self._idx)

        self._idx += 1
        return input_bayer, gt_bayer, meta

    def _load_idx(self, idx: int):
        item = self._dataset[idx]

        img_files = item['input'], item['gt']
        meta = item['meta']
        bayers = []
        for img_file in img_files:
            if not Path(img_file).is_absolute():
                img_file = self.base_path / img_file
            bayer = read_array(img_file)
            bayer = bayer.reshape(*meta.shape)
            # Reno 10x outputs BGGR order
            assert meta.bayer_pattern == 'BGGR'
            bayer = bayer.astype(np.float32) / 65535  # normalize 16-bit raw to [0, 1]
            bayers.append(bayer)

        input_bayer, gt_bayer = bayers
        return input_bayer, gt_bayer, meta

# vim: ts=4 sw=4 sts=4 expandtab

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/PMRID/8ebb9e8e96559881dee957f34243933c5beb77dd/models/__init__.py

--------------------------------------------------------------------------------
/models/mge_pretrained.ckp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/PMRID/8ebb9e8e96559881dee957f34243933c5beb77dd/models/mge_pretrained.ckp

--------------------------------------------------------------------------------
/models/net_mge.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
from collections import OrderedDict

import megengine as mge
import megengine.module as M
import megengine.functional as F


def Conv2D(
    in_channels: int, out_channels: int,
    kernel_size: int, stride: int, padding: int,
    is_seperable: bool = False, has_relu: bool = False,
):
    modules = OrderedDict()

    if is_seperable:
        modules['depthwise'] = M.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding,
            groups=in_channels, bias=False,
        )
        modules['pointwise'] = M.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=True,
        )
    else:
        modules['conv'] = M.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            bias=True,
        )
    if has_relu:
        modules['relu'] = M.ReLU()

    return M.Sequential(modules)


class EncoderBlock(M.Module):

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        self.conv1 = Conv2D(in_channels, mid_channels, kernel_size=5, stride=stride, padding=2, is_seperable=True, has_relu=True)
        self.conv2 = Conv2D(mid_channels, out_channels, kernel_size=5, stride=1, padding=2, is_seperable=True, has_relu=False)

        self.proj = (
            M.Identity()
            if stride == 1 and in_channels == out_channels else
            Conv2D(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, is_seperable=True, has_relu=False)
        )
        self.relu = M.ReLU()

    def forward(self, x):
        proj = self.proj(x)

        x = self.conv1(x)
        x = self.conv2(x)

        x = x + proj
        return self.relu(x)


def EncoderStage(in_channels: int, out_channels: int, num_blocks: int):

    blocks = [
        # the first block downsamples by 2, the rest keep resolution
        EncoderBlock(
            in_channels=in_channels,
            mid_channels=out_channels//4,
            out_channels=out_channels,
            stride=2,
        )
    ]
    for _ in range(num_blocks-1):
        blocks.append(
            EncoderBlock(
                in_channels=out_channels,
                mid_channels=out_channels//4,
                out_channels=out_channels,
                stride=1,
            )
        )

    return M.Sequential(*blocks)


class DecoderBlock(M.Module):

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        padding = kernel_size // 2
        self.conv0 = Conv2D(
            in_channels, out_channels, kernel_size=kernel_size, padding=padding,
            stride=1, is_seperable=True, has_relu=True,
        )
        self.conv1 = Conv2D(
            out_channels, out_channels, kernel_size=kernel_size, padding=padding,
            stride=1, is_seperable=True, has_relu=False,
        )

    def forward(self, x):
        inp = x
        x = self.conv0(x)
        x = self.conv1(x)
        x = x + inp
        return x


class DecoderStage(M.Module):

    def __init__(self, in_channels: int, skip_in_channels: int, out_channels: int):
        super().__init__()

        self.decode_conv = DecoderBlock(in_channels, in_channels, kernel_size=3)
        self.upsample = M.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.proj_conv = Conv2D(skip_in_channels, out_channels, kernel_size=3, stride=1, padding=1, is_seperable=True, has_relu=True)

    def forward(self, inputs):
        inp, skip = inputs

        x = self.decode_conv(inp)
        x = self.upsample(x)
        y = self.proj_conv(skip)
        return x + y


class Network(M.Module):

    def __init__(self):
        super().__init__()

        self.conv0 = Conv2D(in_channels=4, out_channels=16, kernel_size=3, padding=1, stride=1, is_seperable=False, has_relu=True)
        self.enc1 = EncoderStage(in_channels=16, out_channels=64, num_blocks=2)
        self.enc2 = EncoderStage(in_channels=64, out_channels=128, num_blocks=2)
        self.enc3 = EncoderStage(in_channels=128, out_channels=256, num_blocks=4)
        self.enc4 = EncoderStage(in_channels=256, out_channels=512, num_blocks=4)

        self.encdec = Conv2D(in_channels=512, out_channels=64, kernel_size=3, padding=1, stride=1, is_seperable=True, has_relu=True)
        self.dec1 = DecoderStage(in_channels=64, skip_in_channels=256, out_channels=64)
        self.dec2 = DecoderStage(in_channels=64, skip_in_channels=128, out_channels=32)
        self.dec3 = DecoderStage(in_channels=32, skip_in_channels=64, out_channels=32)
        self.dec4 = DecoderStage(in_channels=32, skip_in_channels=16, out_channels=16)

        self.out0 = DecoderBlock(in_channels=16, out_channels=16, kernel_size=3)
        self.out1 = Conv2D(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1, is_seperable=False, has_relu=False)

    def forward(self, inp):

        conv0 = self.conv0(inp)
        conv1 = self.enc1(conv0)
        conv2 = self.enc2(conv1)
        conv3 = self.enc3(conv2)
        conv4 = self.enc4(conv3)

        conv5 = self.encdec(conv4)

        up3 = self.dec1((conv5, conv3))
        up2 = self.dec2((up3, conv2))
        up1 = self.dec3((up2, conv1))
        x = self.dec4((up1, conv0))

        x = self.out0(x)
        x = self.out1(x)

        # residual prediction: the network estimates a correction to the noisy input
        pred = inp + x
        return pred


def get_loss_l1(pred: mge.Tensor, label: mge.Tensor, norm_k: mge.Tensor):
    # per-sample mean absolute error, rescaled by norm_k
    B = pred.shape[0]
    L1 = F.abs(pred - label).reshape(B, -1).mean(axis=1)
    L1 = L1 / norm_k.flatten()
    return L1.mean()


if __name__ == "__main__":
    import numpy as np

    net = Network()
    img = mge.tensor(np.random.randn(1, 4, 64, 64).astype(np.float32))
    out = net(img)  # smoke test: forward pass on a random RGGB patch

# vim: ts=4 sw=4 sts=4 expandtab

--------------------------------------------------------------------------------
/models/net_torch.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
from collections import OrderedDict

import torch
import torch.nn as nn


def Conv2D(
    in_channels: int, out_channels: int,
    kernel_size: int, stride: int, padding: int,
    is_seperable: bool = False, has_relu: bool = False,
):
    modules = OrderedDict()

    if is_seperable:
        modules['depthwise'] = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding,
            groups=in_channels, bias=False,
        )
        modules['pointwise'] = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=True,
        )
    else:
        modules['conv'] = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            bias=True,
        )
    if has_relu:
        modules['relu'] = nn.ReLU()

    return nn.Sequential(modules)


class EncoderBlock(nn.Module):

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        self.conv1 = Conv2D(in_channels, mid_channels, kernel_size=5, stride=stride, padding=2, is_seperable=True, has_relu=True)
        self.conv2 = Conv2D(mid_channels, out_channels, kernel_size=5, stride=1, padding=2, is_seperable=True, has_relu=False)

        self.proj = (
            nn.Identity()
            if stride == 1 and in_channels == out_channels else
            Conv2D(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, is_seperable=True, has_relu=False)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        proj = self.proj(x)

        x = self.conv1(x)
        x = self.conv2(x)

        x = x + proj
        return self.relu(x)


def EncoderStage(in_channels: int, out_channels: int, num_blocks: int):

    blocks = [
        # the first block downsamples by 2, the rest keep resolution
        EncoderBlock(
            in_channels=in_channels,
            mid_channels=out_channels//4,
            out_channels=out_channels,
            stride=2,
        )
    ]
    for _ in range(num_blocks-1):
        blocks.append(
            EncoderBlock(
                in_channels=out_channels,
                mid_channels=out_channels//4,
                out_channels=out_channels,
                stride=1,
            )
        )

    return nn.Sequential(*blocks)


class DecoderBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        padding = kernel_size // 2
        self.conv0 = Conv2D(
            in_channels, out_channels, kernel_size=kernel_size, padding=padding,
            stride=1, is_seperable=True, has_relu=True,
        )
        self.conv1 = Conv2D(
            out_channels, out_channels, kernel_size=kernel_size, padding=padding,
            stride=1, is_seperable=True, has_relu=False,
        )

    def forward(self, x):
        inp = x
        x = self.conv0(x)
        x = self.conv1(x)
        x = x + inp
        return x


class DecoderStage(nn.Module):

    def __init__(self, in_channels: int, skip_in_channels: int, out_channels: int):
        super().__init__()

        self.decode_conv = DecoderBlock(in_channels, in_channels, kernel_size=3)
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.proj_conv = Conv2D(skip_in_channels, out_channels, kernel_size=3, stride=1, padding=1, is_seperable=True, has_relu=True)
        # the MegEngine version initializes the upsample weight with
        # M.init.msra_normal_(self.upsample.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, inputs):
        inp, skip = inputs

        x = self.decode_conv(inp)
        x = self.upsample(x)
        y = self.proj_conv(skip)
        return x + y


class Network(nn.Module):

    def __init__(self):
        super().__init__()

        self.conv0 = Conv2D(in_channels=4, out_channels=16, kernel_size=3, padding=1, stride=1, is_seperable=False, has_relu=True)
        self.enc1 = EncoderStage(in_channels=16, out_channels=64, num_blocks=2)
        self.enc2 = EncoderStage(in_channels=64, out_channels=128, num_blocks=2)
        self.enc3 = EncoderStage(in_channels=128, out_channels=256, num_blocks=4)
        self.enc4 = EncoderStage(in_channels=256, out_channels=512, num_blocks=4)

        self.encdec = Conv2D(in_channels=512, out_channels=64, kernel_size=3, padding=1, stride=1, is_seperable=True, has_relu=True)
        self.dec1 = DecoderStage(in_channels=64, skip_in_channels=256, out_channels=64)
        self.dec2 = DecoderStage(in_channels=64, skip_in_channels=128, out_channels=32)
        self.dec3 = DecoderStage(in_channels=32, skip_in_channels=64, out_channels=32)
        self.dec4 = DecoderStage(in_channels=32, skip_in_channels=16, out_channels=16)

        self.out0 = DecoderBlock(in_channels=16, out_channels=16, kernel_size=3)
        self.out1 = Conv2D(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1, is_seperable=False, has_relu=False)

    def forward(self, inp):

        conv0 = self.conv0(inp)
        conv1 = self.enc1(conv0)
        conv2 = self.enc2(conv1)
        conv3 = self.enc3(conv2)
        conv4 = self.enc4(conv3)

        conv5 = self.encdec(conv4)

        up3 = self.dec1((conv5, conv3))
        up2 = self.dec2((up3, conv2))
        up1 = self.dec3((up2, conv1))
        x = self.dec4((up1, conv0))

        x = self.out0(x)
        x = self.out1(x)

        # residual prediction: the network estimates a correction to the noisy input
        pred = inp + x
        return pred


if __name__ == "__main__":
    net = Network()
    img = torch.randn(1, 4, 64, 64, device=torch.device('cpu'), dtype=torch.float32)
    out = net(img)  # smoke test: forward pass on a random RGGB patch

# vim: ts=4 sw=4 sts=4 expandtab

--------------------------------------------------------------------------------
/models/torch_pretrained.ckp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/PMRID/8ebb9e8e96559881dee957f34243933c5beb77dd/models/torch_pretrained.ckp

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
MegEngine>=1.1.0
numpy
opencv-python
scikit-image
tqdm
dataclasses; python_version < "3.7"

--------------------------------------------------------------------------------
/run_benchmark.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import argparse
import pickle
from pathlib import Path
from typing import Tuple

import megengine as mge
import numpy as np
import skimage.metrics
from tqdm import tqdm

from models.net_mge import Network
from utils import RawUtils
from benchmark import BenchmarkLoader


class KSigma:
    """k-sigma transform: maps a raw image taken at a given ISO into the
    noise space of a fixed anchor ISO, so one network can handle all ISOs."""

    def __init__(self, K_coeff: Tuple[float, float], B_coeff: Tuple[float, float, float], anchor: float, V: float = 959.0):
        self.K = np.poly1d(K_coeff)
        self.Sigma = np.poly1d(B_coeff)
        self.anchor = anchor
        self.V = V

    def __call__(self, img_01, iso: float, inverse=False):
        k, sigma = self.K(iso), self.Sigma(iso)
        k_a, sigma_a = self.K(self.anchor), self.Sigma(self.anchor)

        cvt_k = k_a / k
        cvt_b = (sigma / (k ** 2) - sigma_a / (k_a ** 2)) * k_a

        img = img_01 * self.V

        if not inverse:
            img = img * cvt_k + cvt_b
        else:
            img = (img - cvt_b) / cvt_k

        return img / self.V


class Denoiser:

    def __init__(self, model_path: Path, ksigma: KSigma, inp_scale=256.0):
        net = Network()
        with model_path.open('rb') as f:
            states = pickle.load(f)
        net.load_state_dict(states)
        net.eval()

        self.net = net
        self.ksigma = ksigma
        self.inp_scale = inp_scale

    def pre_process(self, bayer_01: np.ndarray):
        rggb = RawUtils.bayer2rggb(bayer_01)
        rggb = rggb.clip(0, 1)

        # pad the half-resolution RGGB image so each side becomes a multiple
        # of 32; assumes (32 - H % 32) and (32 - W % 32) are even and non-zero,
        # which holds for the 3000x4000 benchmark raws (H, W = 1500, 2000 here)
        H, W = rggb.shape[:2]
        ph, pw = (32-(H % 32))//2, (32-(W % 32))//2
        rggb = np.pad(rggb, [(ph, ph), (pw, pw), (0, 0)], 'constant')
        inp_rggb = rggb.transpose(2, 0, 1)[np.newaxis]
        self.ph, self.pw = ph, pw
        return inp_rggb

    def run(self, bayer_01: np.ndarray, iso: float):
        inp_rggb_01 = self.pre_process(bayer_01)
        inp_rggb = self.ksigma(inp_rggb_01, iso) * self.inp_scale

        inp = mge.tensor(np.ascontiguousarray(inp_rggb))  # wrap as a MegEngine tensor
        pred = self.net(inp)[0] / self.inp_scale

        pred = pred.numpy().transpose(1, 2, 0)
        pred = self.ksigma(pred, iso, inverse=True)

        # strip the padding added in pre_process
        ph, pw = self.ph, self.pw
        pred = pred[ph:-ph, pw:-pw]
        return RawUtils.rggb2bayer(pred)


def run_benchmark(model_path, bm_loader: BenchmarkLoader):

    ksigma = KSigma(
        K_coeff=[0.0005995267, 0.00868861],
        B_coeff=[7.11772e-7, 6.514934e-4, 0.11492713],
        anchor=1600,
    )
    denoiser = Denoiser(model_path, ksigma)

    PSNRs, SSIMs = [], []

    bar = tqdm(bm_loader)
    for input_bayer, gt_bayer, meta in bar:
        bar.set_description(meta.name)
        assert meta.bayer_pattern == 'BGGR'
        # the network expects RGGB order, so flip the BGGR raws
        input_bayer, gt_bayer = RawUtils.bggr2rggb(input_bayer, gt_bayer)

        pred_bayer = denoiser.run(input_bayer, iso=meta.ISO)

        inp_rgb, pred_rgb, gt_rgb = RawUtils.bayer2rgb(
            input_bayer, pred_bayer, gt_bayer,
            wb_gain=meta.wb_gain, CCM=meta.CCM,
        )
        # flip back so the ROIs (defined on the original orientation) line up
        inp_rgb, pred_rgb, gt_rgb = RawUtils.bggr2rggb(inp_rgb, pred_rgb, gt_rgb)
        bar.set_description(meta.name + ' ✓')

        psnrs = []
        ssims = []

        for x0, y0, x1, y1 in meta.ROIs:
            pred_patch = pred_rgb[y0:y1, x0:x1]
            gt_patch = gt_rgb[y0:y1, x0:x1]

            psnr = skimage.metrics.peak_signal_noise_ratio(gt_patch, pred_patch)
            ssim = skimage.metrics.structural_similarity(gt_patch, pred_patch, multichannel=True)
            psnrs.append(float(psnr))
            ssims.append(float(ssim))

        bar.set_description(meta.name + ' ✓✓')

        PSNRs += psnrs  # collect per-ROI scores across all images
        SSIMs += ssims

    mean_psnr = np.mean(PSNRs)
    mean_ssim = np.mean(SSIMs)
    print("mean PSNR:", mean_psnr)
    print("mean SSIM:", mean_ssim)


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('model', type=Path)
    parser.add_argument('--benchmark', type=Path)

    args = parser.parse_args()

    bm_loader = BenchmarkLoader(args.benchmark.resolve())
    run_benchmark(args.model, bm_loader)


if __name__ == "__main__":
    main()

# vim: ts=4 sw=4 sts=4 expandtab

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import cv2
import numpy as np


class RawUtils:

    @classmethod
    def bggr2rggb(cls, *bayers):
        # flipping both spatial axes turns a BGGR mosaic into RGGB (and back)
        res = []
        for bayer in bayers:
            res.append(bayer[::-1, ::-1])
        if len(res) == 1:
            return res[0]
        return res

    @classmethod
    def rggb2bggr(cls, *bayers):
        return cls.bggr2rggb(*bayers)

    @classmethod
    def bayer2rggb(cls, *bayers):
        # pack each 2x2 mosaic tile into a 4-channel pixel: (H, W) -> (H/2, W/2, 4)
        res = []
        for bayer in bayers:
            H, W = bayer.shape
            res.append(
                bayer.reshape(H//2, 2, W//2, 2)
                .transpose(0, 2, 1, 3)
                .reshape(H//2, W//2, 4)
            )
        if len(res) == 1:
            return res[0]
        return res

    @classmethod
    def rggb2bayer(cls, *rggbs):
        # inverse of bayer2rggb: (H, W, 4) -> (2H, 2W)
        res = []
        for rggb in rggbs:
            H, W, _ = rggb.shape
            res.append(
                rggb.reshape(H, W, 2, 2)
                .transpose(0, 2, 1, 3)
                .reshape(H*2, W*2)
            )
        if len(res) == 1:
            return res[0]
        return res

    @classmethod
    def bayer2rgb(cls, *bayer_01s, wb_gain, CCM, gamma=2.2):

        wb_gain = np.array(wb_gain)[[0, 1, 1, 2]]  # expand (r, g, b) gains to the R, G, G, B channels
        res = []
        for bayer_01 in bayer_01s:
            bayer = cls.rggb2bayer(
                (cls.bayer2rggb(bayer_01) * wb_gain).clip(0, 1)
            ).astype(np.float32)
            bayer = np.round(np.ascontiguousarray(bayer) * 65535).clip(0, 65535).astype(np.uint16)
            # note: OpenCV's BAYER_BG code corresponds to an RGGB tile at the origin
            rgb = cv2.cvtColor(bayer, cv2.COLOR_BAYER_BG2RGB_EA).astype(np.float32) / 65535
            rgb = rgb.dot(np.array(CCM).T).clip(0, 1)
            rgb = rgb ** (1/gamma)
            res.append(rgb.astype(np.float32))

        if len(res) == 1:
            return res[0]
        return res

# vim: ts=4 sw=4 sts=4 expandtab
--------------------------------------------------------------------------------
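
A quick round-trip sanity check of the packing helpers in `utils.py` (a hypothetical snippet, not part of the repository):

```python
import numpy as np
from utils import RawUtils

bayer = np.random.rand(8, 8).astype(np.float32)   # toy mosaic

rggb = RawUtils.bayer2rggb(bayer)                 # (8, 8) -> (4, 4, 4)
assert rggb.shape == (4, 4, 4)
assert np.array_equal(RawUtils.rggb2bayer(rggb), bayer)    # lossless round trip

flipped = RawUtils.bggr2rggb(bayer)               # flip both axes: BGGR <-> RGGB
assert np.array_equal(RawUtils.bggr2rggb(flipped), bayer)  # involution
```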