├── .gitignore
├── Fig
│   ├── fig1.png
│   └── fig2.png
├── LICENSE
├── README.md
├── dataset.py
├── model.py
├── prepare_data.py
├── test.py
├── train_mge.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Fig/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/NBNet/73112b185e022d0920f2f45c34c5bcf7c581d983/Fig/fig1.png -------------------------------------------------------------------------------- /Fig/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/NBNet/73112b185e022d0920f2f45c34c5bcf7c581d983/Fig/fig2.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NBNet: Noise Basis Learning for Image Denoising with Subspace Projection 2 | 3 | Code for CVPR21 paper [NBNet](https://arxiv.org/abs/2012.15028). 
4 | 5 | *The illustration of our key insight:* 6 | 7 | ![projection_concept](./Fig/fig1.png) 8 | 9 | ## Dependencies 10 | 11 | - MegEngine >= 1.3.1 (for DistributedDataParallel) 12 | 13 | 14 | 15 | ## Training 16 | 17 | ### Preparation 18 | 19 | ``` 20 | python prepare_data.py --data_dir your_sidd_data_path 21 | ``` 22 | 23 | 24 | 25 | ### Begin training: 26 | 27 | For the SIDD benchmark, use: 28 | 29 | ``` 30 | python train_mge.py -d prepared_data_path -n num_gpus 31 | ``` 32 | 33 | 34 | 35 | For the DnD benchmark, we additionally use MixUp: 36 | 37 | ``` 38 | python train_mge.py -d prepared_data_path -n num_gpus --dnd 39 | ``` 40 | 41 | ### Begin testing: 42 | Download the pretrained checkpoint and use: 43 | 44 | ``` 45 | python test.py -d prepared_data_path -c checkpoint_path 46 | ``` 47 | The result on the SIDD validation set should be **PSNR 39.765 dB**. 48 | 49 | 50 | 51 | ## Pretrained model 52 | 53 | The MegEngine checkpoint for the SIDD benchmark can be downloaded via 54 | [Google Drive](https://drive.google.com/file/d/1RPAf9ZJqqq9ePPVTtJRlixX4-h3HJTCc/view?usp=sharing) 55 | or 56 | [GitHub Release](https://github.com/megvii-research/NBNet/releases). 57 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import megengine as mge 4 | from megengine.data.dataset import Dataset 5 | 6 | import os 7 | import json 8 | from pathlib import Path 9 | from typing import Iterator, Sequence 10 | from tqdm import tqdm 11 | 12 | import cv2 13 | import numpy as np 14 | import pickle as pkl 15 | from skimage import img_as_float32 as img_as_float 16 | import random 17 | from scipy.io import loadmat 18 | def is_png_file(filename): 19 | return any(filename.endswith(extension) for extension in [".png"]) 20 | 21 | def data_augmentation(image, mode): 22 | """ 23 | Performs data augmentation of the input image 24 | Input: 25 | image: a cv2 (OpenCV) image 26 | mode: int. Choice of transformation to apply to the image
27 | 0 - no transformation 28 | 1 - flip up and down 29 | 2 - rotate counterclockwise 90 degrees 30 | 3 - rotate 90 degrees and flip up and down 31 | 4 - rotate 180 degrees 32 | 5 - rotate 180 degrees and flip 33 | 6 - rotate 270 degrees 34 | 7 - rotate 270 degrees and flip 35 | """ 36 | if mode == 0: 37 | # original 38 | out = image 39 | elif mode == 1: 40 | # flip up and down 41 | out = np.flipud(image) 42 | elif mode == 2: 43 | # rotate counterclockwise 90 degrees 44 | out = np.rot90(image) 45 | elif mode == 3: 46 | # rotate 90 degrees and flip up and down 47 | out = np.rot90(image) 48 | out = np.flipud(out) 49 | elif mode == 4: 50 | # rotate 180 degrees 51 | out = np.rot90(image, k=2) 52 | elif mode == 5: 53 | # rotate 180 degrees and flip 54 | out = np.rot90(image, k=2) 55 | out = np.flipud(out) 56 | elif mode == 6: 57 | # rotate 270 degrees 58 | out = np.rot90(image, k=3) 59 | elif mode == 7: 60 | # rotate 270 degrees and flip 61 | out = np.rot90(image, k=3) 62 | out = np.flipud(out) 63 | else: 64 | raise Exception('Invalid choice of image transformation') 65 | 66 | return out 67 | 68 | def random_augmentation(*args): 69 | out = [] 70 | if random.randint(0,1) == 1: 71 | flag_aug = random.randint(1,7) 72 | for data in args: 73 | out.append(data_augmentation(data, flag_aug).copy()) 74 | else: 75 | for data in args: 76 | out.append(data) 77 | return out 78 | 79 | def load_img(filepath): 80 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 81 | return img 82 | 83 | class SIDDData(Dataset): 84 | def __init__(self, path, length=None): 85 | ''' 86 | Args: 87 | path (str): path of the prepared SIDD training data 88 | length (int): nominal length of the dataset (samples drawn per epoch) 89 | ''' 90 | super(SIDDData, self).__init__() 91 | self.length = length 92 | clean_files = sorted(os.listdir(os.path.join(path, 'train', 'groundtruth'))) 93 | noisy_files = sorted(os.listdir(os.path.join(path, 'train', 'input'))) 94 | 95 | self.clean_filenames = [os.path.join(path, 'train', 'groundtruth', x) for x in clean_files if is_png_file(x)] 96 | self.noisy_filenames = [os.path.join(path, 'train', 'input', x) for x in noisy_files if is_png_file(x)] 97 | self.pch_size = 128 98 | def __len__(self): 99 | if self.length is None: 100 | return len(self.clean_filenames) 101 | else: 102 | return self.length 103 | 104 | def crop_patch(self, n_img, gt_img): 105 | H, W, C = n_img.shape 106 | # randomly pick the top-left corner of a pch_size x pch_size patch 107 | ind_H = random.randint(0, H-self.pch_size) 108 | ind_W = random.randint(0, W-self.pch_size) 109 | im_noisy = n_img[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :] 110 | im_gt = gt_img[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :] 111 | return im_noisy, im_gt 112 | 113 | def __getitem__(self, index): 114 | index = index % len(self.clean_filenames) 115 | # cv2.setNumThreads(0) 116 | noisy_img = load_img(self.noisy_filenames[index]) 117 | gt_img = load_img(self.clean_filenames[index]) 118 | 119 | # noisy_img = np.ascontiguousarray(noisy_img, dtype=np.float32) 120 | # gt_img = np.ascontiguousarray(gt_img, dtype=np.float32) 121 | 122 | noisy_img, gt_img = self.crop_patch(noisy_img, gt_img) 123 | gt_img = img_as_float(gt_img) 124 | noisy_img = img_as_float(noisy_img) 125 | noisy_img, gt_img = random_augmentation(noisy_img, gt_img) 126 | 127 | gt_img = gt_img.transpose((2, 0, 1)) 128 | noisy_img = noisy_img.transpose((2, 0, 1)) 129 | return noisy_img, gt_img 130 | 131 | 132 | class SIDDValData(Dataset): 133 | def __init__(self, path): 134 | 135 | val_data_dict = loadmat(os.path.join(path, 'ValidationNoisyBlocksSrgb.mat'))
136 | val_data_noisy = val_data_dict['ValidationNoisyBlocksSrgb'] 137 | val_data_dict = loadmat(os.path.join(path,'ValidationGtBlocksSrgb.mat')) 138 | val_data_gt = val_data_dict['ValidationGtBlocksSrgb'] 139 | self.num_img, self.num_block, h_, w_, c_ = val_data_gt.shape 140 | self.val_data_noisy = np.reshape(val_data_noisy, (-1, h_, w_, c_)) 141 | self.val_data_gt = np.reshape(val_data_gt, (-1, h_, w_, c_)) 142 | 143 | 144 | def __len__(self): 145 | return self.num_img*self.num_block 146 | 147 | def __getitem__(self, index): 148 | 149 | noisy_img, gt_img = self.val_data_noisy[index], self.val_data_gt[index] 150 | gt_img = img_as_float(gt_img) 151 | noisy_img = img_as_float(noisy_img) 152 | gt_img = gt_img.transpose((2, 0, 1)) 153 | noisy_img = noisy_img.transpose((2, 0, 1)) 154 | return noisy_img, gt_img 155 | 156 | # vim: ts=4 sw=4 sts=4 expandtab 157 | 158 | 159 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import megengine as mge 3 | import megengine.module as nn 4 | import megengine.functional as F 5 | 6 | def conv3x3(in_chn, out_chn, bias=True): 7 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias) 8 | return layer 9 | 10 | 11 | def conv_down(in_chn, out_chn, bias=False): 12 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias) 13 | return layer 14 | 15 | 16 | class UNetD(nn.Module): 17 | 18 | def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.2, subspace_dim=16): 19 | super(UNetD, self).__init__() 20 | self.depth = depth 21 | self.down_path = [] 22 | prev_channels = self.get_input_chn(in_chn) 23 | for i in range(depth): 24 | downsample = (i+1) < depth 25 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, downsample, relu_slope)) 26 | prev_channels = (2**i) * wf 27 | 28 | # self.ema = EMAU(prev_channels, prev_channels//8) 29 | self.up_path = [] 30 | subnet_repeat_num = 1 31 | for i in reversed(range(depth - 1)): 32 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope, subnet_repeat_num, subspace_dim)) 33 | prev_channels = (2**i)*wf 34 | subnet_repeat_num += 1 35 | 36 | self.last = conv3x3(prev_channels, in_chn, bias=True) 37 | #self._initialize() 38 | 39 | def forward(self, x1): 40 | blocks = [] 41 | for i, down in enumerate(self.down_path): 42 | # print(x1.shape) 43 | if (i+1) < self.depth: 44 | x1, x1_up = down(x1) 45 | blocks.append(x1_up) 46 | else: 47 | x1 = down(x1) 48 | # print(x1.shape) 49 | # x1 = self.ema(x1) 50 | for i, up in enumerate(self.up_path): 51 | # print(x1.shape, blocks[-i-1].shape) 52 | x1 = up(x1, blocks[-i-1]) 53 | 54 | pred = self.last(x1) 55 | return pred 56 | 57 | def get_input_chn(self, in_chn): 58 | return in_chn 59 | 60 | def _initialize(self): 61 | gain = nn.init.calculate_gain('leaky_relu', 0.20) 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | print("weight") 65 | nn.init.xavier_uniform_(m.weight, gain) 66 | if m.bias is not None: 67 | print("bias") 68 | nn.init.zeros_(m.bias) 69 | 70 | 71 | class UNetConvBlock(nn.Module): 72 | 73 | def __init__(self, in_size, out_size, downsample, relu_slope): 74 | super(UNetConvBlock, self).__init__() 75 | self.block = nn.Sequential( 76 | nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True), 77 | nn.LeakyReLU(relu_slope), 78 | nn.Conv2d(out_size,
out_size, kernel_size=3, padding=1, bias=True), 79 | nn.LeakyReLU(relu_slope)) 80 | 81 | self.downsample = downsample 82 | if downsample: 83 | self.downsample = conv_down(out_size, out_size, bias=False) 84 | 85 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True) 86 | 87 | def forward(self, x): 88 | out = self.block(x) 89 | sc = self.shortcut(x) 90 | out = out + sc 91 | if self.downsample: 92 | out_down = self.downsample(out) 93 | return out_down, out 94 | else: 95 | return out 96 | 97 | 98 | class UNetUpBlock(nn.Module): 99 | 100 | def __init__(self, in_size, out_size, relu_slope, subnet_repeat_num, subspace_dim=16): 101 | super(UNetUpBlock, self).__init__() 102 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True) 103 | self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope) 104 | self.num_subspace = subspace_dim 105 | print(self.num_subspace, subnet_repeat_num) 106 | self.subnet = Subspace(in_size, self.num_subspace) 107 | self.skip_m = skip_blocks(out_size, out_size, subnet_repeat_num) 108 | 109 | def forward(self, x, bridge): 110 | up = self.up(x) 111 | bridge = self.skip_m(bridge) 112 | out = F.concat([up, bridge], 1) 113 | if self.subnet: 114 | b_, c_, h_, w_ = bridge.shape 115 | sub = self.subnet(out) 116 | V_t = sub.reshape(b_, self.num_subspace, h_*w_) 117 | V_t = V_t / (1e-6 + F.abs(V_t).sum(axis=2, keepdims=True)) 118 | V = V_t.transpose(0, 2, 1) 119 | mat = F.matmul(V_t, V) 120 | mat_inv = F.matinv(mat) 121 | project_mat = F.matmul(mat_inv, V_t) 122 | bridge_ = bridge.reshape(b_, c_, h_*w_) 123 | project_feature = F.matmul(project_mat, bridge_.transpose(0, 2, 1)) 124 | bridge = F.matmul(V, project_feature).transpose(0, 2, 1).reshape(b_, c_, h_, w_) 125 | out = F.concat([up, bridge], 1) 126 | out = self.conv_block(out) 127 | return out 128 | 129 | 130 | class Subspace(nn.Module): 131 | 132 | def __init__(self, in_size, out_size): 133 | super(Subspace, self).__init__() 134 | self.blocks = [] 135 | self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2)) 136 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True) 137 | 138 | def forward(self, x): 139 | sc = self.shortcut(x) 140 | for i in range(len(self.blocks)): 141 | x = self.blocks[i](x) 142 | return x + sc 143 | 144 | 145 | class skip_blocks(nn.Module): 146 | 147 | def __init__(self, in_size, out_size, repeat_num=1): 148 | super(skip_blocks, self).__init__() 149 | self.blocks = [] 150 | self.re_num = repeat_num 151 | mid_c = 128 152 | self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2)) 153 | for i in range(self.re_num - 2): 154 | self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2)) 155 | self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2)) 156 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True) 157 | 158 | def forward(self, x): 159 | sc = self.shortcut(x) 160 | for m in self.blocks: 161 | x = m(x) 162 | return x + sc 163 | 164 | 165 | if __name__ == "__main__": 166 | import numpy as np 167 | a = UNetD(3) 168 | 169 | #print(a) 170 | im = mge.tensor(np.random.randn(1, 3, 128, 128).astype(np.float32)) 171 | print(a(im)) 172 | 173 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import cv2 4 | import numpy as np 5 | import h5py as h5 6 | import argparse 7 | from joblib import Parallel, delayed 8 | import multiprocessing 
9 | from tqdm import tqdm 10 | parser = argparse.ArgumentParser(prog='SIDD Train dataset Generation') 11 | # The original SIDD images: /ssd1t/SIDD/ 12 | parser.add_argument('--data_dir', default='/data/sidd/SIDD_Medium_Srgb/Data', type=str, metavar='PATH', 13 | help="path to the original SIDD Medium sRGB training data") 14 | parser.add_argument('--tar_dir', default='/data/sidd/train',type=str, help='Directory for image patches') 15 | args = parser.parse_args() 16 | tar = args.tar_dir 17 | 18 | noisy_patchDir = os.path.join(tar, 'input') 19 | clean_patchDir = os.path.join(tar, 'groundtruth') 20 | 21 | if os.path.exists(tar): 22 | os.system("rm -r {}".format(tar)) 23 | 24 | os.makedirs(noisy_patchDir) 25 | os.makedirs(clean_patchDir) 26 | path_all_noisy = glob(os.path.join(args.data_dir, '**/*NOISY*.PNG'), recursive=True) 27 | path_all_noisy = sorted(path_all_noisy) 28 | path_all_gt = [x.replace('NOISY', 'GT') for x in path_all_noisy] 29 | print('Number of big images: {:d}'.format(len(path_all_gt))) 30 | 31 | print('Training: splitting the original images into small patches!') 32 | path_h5 = os.path.join(args.data_dir, 'small_imgs_train.hdf5') 33 | if os.path.exists(path_h5): 34 | os.remove(path_h5) 35 | pch_size = 512 36 | stride = 512-128 37 | num_patch = 0 38 | C = 3 39 | 40 | def save_files(ii): 41 | im_noisy_int8 = cv2.imread(path_all_noisy[ii]) 42 | H, W, _ = im_noisy_int8.shape 43 | im_gt_int8 = cv2.imread(path_all_gt[ii]) 44 | ind_H = list(range(0, H-pch_size+1, stride)) 45 | if ind_H[-1] < H-pch_size: 46 | ind_H.append(H-pch_size) 47 | ind_W = list(range(0, W-pch_size+1, stride)) 48 | if ind_W[-1] < W-pch_size: 49 | ind_W.append(W-pch_size) 50 | count = 1 51 | for start_H in ind_H: 52 | for start_W in ind_W: 53 | pch_noisy = im_noisy_int8[start_H:start_H+pch_size, start_W:start_W+pch_size, ] 54 | pch_gt = im_gt_int8[start_H:start_H+pch_size, start_W:start_W+pch_size, ] 55 | cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(ii+1,count+1)), pch_noisy) 56 | cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(ii+1,count+1)), pch_gt) 57 | count += 1 58 | Parallel(n_jobs=10)(delayed(save_files)(i) for i in tqdm(range(len(path_all_gt)))) 59 | print('Finish!\n') 60 | 61 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from dataset import SIDDValData 3 | from model import UNetD 4 | import megengine.data as data 5 | from utils import batch_PSNR 6 | from tqdm import tqdm 7 | import argparse 8 | import pickle 9 | import megengine 10 | 11 | 12 | def test(args): 13 | valid_dataset = SIDDValData(args.data) 14 | valid_sampler = data.SequentialSampler( 15 | valid_dataset, batch_size=1, drop_last=False 16 | ) 17 | valid_dataloader = data.DataLoader( 18 | valid_dataset, 19 | sampler=valid_sampler, 20 | num_workers=8, 21 | ) 22 | model = UNetD(3) 23 | with open(args.checkpoint, "rb") as f: 24 | state = pickle.load(f) 25 | model.load_state_dict(state["state_dict"]) 26 | model.eval() 27 | 28 | def valid_step(image, label): 29 | pred = model(image) 30 | pred = image - pred 31 | psnr_it = batch_PSNR(pred, label) 32 | return psnr_it 33 | 34 | def valid(func, data_queue): 35 | psnr_v = 0.
36 | for step, (image, label) in tqdm(enumerate(data_queue)): 37 | image = megengine.tensor(image) 38 | label = megengine.tensor(label) 39 | psnr_it = func(image, label) 40 | psnr_v += psnr_it 41 | psnr_v /= step + 1 42 | return psnr_v 43 | 44 | psnr_v = valid(valid_step, valid_dataloader) 45 | print("PSNR: {:.3f}".format(psnr_v.item())) 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser(description="MegEngine NBNet") 49 | parser.add_argument("-d", "--data", default="/data/sidd", metavar="DIR", help="path to the SIDD dataset") 50 | parser.add_argument("-c", "--checkpoint", help="path to checkpoint") 51 | args = parser.parse_args() 52 | test(args) 53 | 54 | 55 | 56 | # vim: ts=4 sw=4 sts=4 expandtab 57 | -------------------------------------------------------------------------------- /train_mge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import bisect 4 | import multiprocessing 5 | import os 6 | import time 7 | import numpy as np 8 | # pylint: disable=import-error 9 | from model import UNetD 10 | 11 | import megengine 12 | import megengine.autodiff as autodiff 13 | import megengine.data as data 14 | import megengine.data.transform as T 15 | import megengine.distributed as dist 16 | import megengine.functional as F 17 | import megengine.optimizer as optim 18 | 19 | from dataset import SIDDData, SIDDValData 20 | from utils import batch_PSNR, MixUp_AUG 21 | logging = megengine.logger.get_logger() 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser(description="MegEngine NBNet") 26 | parser.add_argument("-d", "--data", default="/data/sidd", metavar="DIR", help="path to the SIDD dataset") 27 | parser.add_argument("--dnd", action='store_true', help="training for the DnD benchmark") 28 | parser.add_argument( 29 | "-a", 30 | "--arch", 31 | default="NBNet", 32 | ) 33 | parser.add_argument( 34 | "-n", 35 | "--ngpus", 36 | default=None, 37 | type=int, 38 | help="number of GPUs per node (default: None, use all available GPUs)", 39 | ) 40 | parser.add_argument( 41 | "--save", 42 | metavar="DIR", 43 | default="output", 44 | help="path to save checkpoint and log", 45 | ) 46 | parser.add_argument( 47 | "--epochs", 48 | default=70, 49 | type=int, 50 | help="number of total epochs to run (default: 70)", 51 | ) 52 | 53 | parser.add_argument( 54 | "--steps_per_epoch", 55 | default=10000, 56 | type=int, 57 | help="number of steps for one epoch (default: 10000)", 58 | ) 59 | 60 | parser.add_argument( 61 | "-b", 62 | "--batch-size", 63 | metavar="SIZE", 64 | default=32, 65 | type=int, 66 | help="total batch size (default: 32)", 67 | ) 68 | parser.add_argument( 69 | "--lr", 70 | "--learning-rate", 71 | metavar="LR", 72 | default=2e-4, 73 | type=float, 74 | help="learning rate for single GPU (default: 0.0002)", 75 | ) 76 | 77 | parser.add_argument( 78 | "--weight-decay", default=1e-8, type=float, help="weight decay" 79 | ) 80 | 81 | parser.add_argument("-j", "--workers", default=8, type=int) 82 | 83 | parser.add_argument( 84 | "-p", 85 | "--print-freq", 86 | default=10, 87 | type=int, 88 | metavar="N", 89 | help="print frequency (default: 10)", 90 | ) 91 | 92 | 93 | 94 | args = parser.parse_args() 95 | # pylint: disable=unused-variable # noqa: F841 96 | 97 | # get device count; fall back to all available GPUs when -n is not given 98 | if args.ngpus is None: 99 | args.ngpus = megengine.device.get_device_count("gpu") 100 | ngpus_per_node = args.ngpus 101 | # launch processes 102 | train_proc = dist.launcher(worker) if ngpus_per_node > 1 else worker 103 | train_proc(args) 104 | 105 | def worker(args): 106 | # pylint: disable=too-many-statements
107 | rank = dist.get_rank() 108 | world_size = dist.get_world_size() 109 | if rank == 0: 110 | os.makedirs(os.path.join(args.save, args.arch), exist_ok=True) 111 | megengine.logger.set_log_file(os.path.join(args.save, args.arch, "log.txt")) 112 | # init process group 113 | 114 | # build dataset 115 | train_dataloader, valid_dataloader = build_dataset(args) 116 | train_queue = iter(train_dataloader) # infinite 117 | steps_per_epoch = args.steps_per_epoch 118 | 119 | # build model 120 | model = UNetD(3) 121 | # Sync parameters 122 | if world_size > 1: 123 | dist.bcast_list_(model.parameters(), dist.WORLD) 124 | 125 | # Autodiff gradient manager 126 | gm = autodiff.GradManager().attach( 127 | model.parameters(), 128 | callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None, 129 | ) 130 | 131 | # Optimizer 132 | opt = optim.Adam( 133 | model.parameters(), 134 | lr=args.lr, 135 | weight_decay=args.weight_decay * world_size, # scale weight decay in "SUM" mode 136 | ) 137 | 138 | # mixup 139 | def preprocess(image, label): 140 | if args.dnd: 141 | image, label = MixUp_AUG(image, label) 142 | return image, label 143 | 144 | # train and valid func 145 | def train_step(image, label): 146 | with gm: 147 | logits = model(image) 148 | logits = image - logits 149 | loss = F.nn.l1_loss(logits, label) 150 | gm.backward(loss) 151 | opt.step().clear_grad() 152 | return loss 153 | 154 | def valid_step(image, label): 155 | pred = model(image) 156 | pred = image - pred 157 | mae_iter = F.nn.l1_loss(pred, label) 158 | psnr_it = batch_PSNR(pred, label) 159 | #print(psnr_it.item()) 160 | if world_size > 1: 161 | mae_iter = F.distributed.all_reduce_sum(mae_iter) / world_size 162 | psnr_it = F.distributed.all_reduce_sum(psnr_it) / world_size 163 | 164 | return mae_iter, psnr_it 165 | 166 | # cosine learning rate decay (no warmup) 167 | def adjust_learning_rate(step): 168 | #lr = 1e-6 + 0.5 * (args.lr - 1e-6)*(1 + np.cos(step/(args.epochs*steps_per_epoch) * np.pi)) 169 | lr = args.lr * (np.cos(step / (steps_per_epoch * args.epochs) * np.pi) + 1) / 2 170 | for param_group in opt.param_groups: 171 | param_group["lr"] = lr 172 | return lr 173 | 174 | # start training 175 | for step in range(0, int(args.epochs * steps_per_epoch)): 176 | #print(step) 177 | lr = adjust_learning_rate(step) 178 | 179 | t_step = time.time() 180 | 181 | image, label = next(train_queue) 182 | if step > steps_per_epoch: # apply MixUp only after the first epoch 183 | image, label = preprocess(image, label) 184 | image = megengine.tensor(image) 185 | label = megengine.tensor(label) 186 | t_data = time.time() - t_step 187 | loss = train_step(image, label) 188 | t_train = time.time() - t_step 189 | speed = 1. / t_train
190 | if step % args.print_freq == 0 and dist.get_rank() == 0: 191 | logging.info( 192 | "Epoch {} Step {}, Speed={:.2g} mb/s, dp_cost={:.2g}, Loss={:5.2e}, lr={:.2e}".format( 193 | step // int(steps_per_epoch), 194 | step, 195 | speed, 196 | t_data/t_train, 197 | loss.item(), 198 | lr 199 | )) 200 | #print(steps_per_epoch) 201 | if (step + 1) % steps_per_epoch == 0: 202 | model.eval() 203 | loss, psnr_v = valid(valid_step, valid_dataloader) 204 | model.train() 205 | logging.info( 206 | "Epoch {} Test mae {:.3f}, psnr {:.3f}".format( 207 | (step + 1) // steps_per_epoch, 208 | loss.item(), 209 | psnr_v.item(), 210 | )) 211 | megengine.save( 212 | { 213 | "epoch": (step + 1) // steps_per_epoch, 214 | "state_dict": model.state_dict(), 215 | }, 216 | os.path.join(args.save, args.arch, "checkpoint.pkl"), 217 | ) if rank == 0 else None 218 | 219 | def valid(func, data_queue): 220 | loss = 0. 221 | psnr_v = 0. 222 | for step, (image, label) in enumerate(data_queue): 223 | image = megengine.tensor(image) 224 | label = megengine.tensor(label) 225 | mae_iter, psnr_it = func(image, label) 226 | loss += mae_iter 227 | psnr_v += psnr_it 228 | loss /= step + 1 229 | psnr_v /= step + 1 230 | return loss, psnr_v 231 | 232 | 233 | def build_dataset(args): 234 | assert args.batch_size // args.ngpus > 0 and 4 // args.ngpus > 0, "need at least one training and validation sample per GPU" 235 | train_dataset = SIDDData(args.data, length=args.batch_size*args.steps_per_epoch) 236 | train_sampler = data.Infinite( 237 | data.RandomSampler(train_dataset, batch_size=args.batch_size//args.ngpus, drop_last=True) 238 | ) 239 | train_dataloader = data.DataLoader( 240 | train_dataset, 241 | sampler=train_sampler, 242 | num_workers=args.workers, 243 | ) 244 | valid_dataset = SIDDValData(args.data) 245 | valid_sampler = data.SequentialSampler( 246 | valid_dataset, batch_size=4//args.ngpus, drop_last=False 247 | ) 248 | valid_dataloader = data.DataLoader( 249 | valid_dataset, 250 | sampler=valid_sampler, 251 | num_workers=args.workers, 252 | ) 253 | return train_dataloader, valid_dataloader 254 | 255 | 256 | 257 | if __name__ == "__main__": 258 | main() 259 | 260 | # vim: ts=4 sw=4 sts=4 expandtab 261 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Powered by Zongsheng Yue 2019-01-22 22:07:08 4 | 5 | import math 6 | import megengine as mge 7 | import megengine.functional as F 8 | from skimage import img_as_ubyte 9 | import numpy as np 10 | import cv2 11 | 12 | 13 | def get_gausskernel(p, chn=3): 14 | ''' 15 | Build a 2-dimensional Gaussian filter with size p 16 | ''' 17 | x = cv2.getGaussianKernel(p, sigma=-1) # p x 1 18 | y = np.matmul(x, x.T)[np.newaxis, np.newaxis,] # 1 x 1 x p x p 19 | out = np.tile(y, (chn, 1, 1, 1)) # chn x 1 x p x p 20 | 21 | return mge.tensor(out.astype(np.float32)) # MegEngine tensor, so gaussblur's F.conv2d can consume it 22 | 23 | def gaussblur(x, kernel, p=5, chn=3): 24 | x_pad = F.pad(x, pad=[int((p-1)/2),]*4, mode='reflect') 25 | y = F.conv2d(x_pad, kernel, padding=0, stride=1, groups=chn) 26 | 27 | return y 28 | 29 | 30 | def calculate_psnr(im1, im2, border=0): 31 | if not im1.shape == im2.shape: 32 | raise ValueError('Input images must have the same dimensions.') 33 | h, w = im1.shape[:2] 34 | im1 = im1[border:h-border, border:w-border] 35 | im2 = im2[border:h-border, border:w-border] 36 | 37 | mse = F.mean((im1 - im2)**2) 38 | if mse == 0: 39 | return float('inf') 40 | return 10 * F.log(1.0 / mse) / math.log(10.)
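# For images scaled to [0, 1] (so MAX = 1), the return value above is the
# standard identity PSNR = 10 * log10(MAX^2 / MSE) = -10 * log10(MSE), in dB.
# A quick sanity check with plain floats (illustrative numbers only):
#   10 * math.log(1.0 / 1e-4) / math.log(10.) == 40.0, i.e. MSE 1e-4 -> 40 dB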
41 | 42 | def batch_PSNR(img, imclean, border=0): 43 | Img = img 44 | Iclean = imclean 45 | PSNR = 0 46 | for i in range(Img.shape[0]): 47 | PSNR += calculate_psnr(Iclean[i,:,], Img[i,:,], border) 48 | return (PSNR/Img.shape[0]) 49 | 50 | 51 | 52 | 53 | def MixUp_AUG(rgb_gt, rgb_noisy): 54 | bs = rgb_gt.shape[0] 55 | indices = np.arange(bs) 56 | np.random.shuffle(indices) 57 | rgb_gt2 = rgb_gt[indices] 58 | rgb_noisy2 = rgb_noisy[indices] 59 | 60 | lam = np.random.beta(1.2, 1.2, (bs, 1, 1, 1)) 61 | 62 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 63 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 64 | 65 | return rgb_gt, rgb_noisy 66 | 67 | --------------------------------------------------------------------------------
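Putting the pieces together, here is a minimal single-image inference sketch distilled from `test.py` above. The pickle checkpoint format and the `denoised = image - model(image)` residual convention come straight from this repo; the file names (`checkpoint.pkl`, `noisy.png`, `denoised.png`) are placeholders, not files the repo ships.

```python
#!/usr/bin/env python3
# Sketch only: assumes a downloaded NBNet checkpoint and a noisy sRGB image
# whose height and width are divisible by 16 (UNetD downsamples 4 times).
import pickle

import cv2
import numpy as np
import megengine
from skimage import img_as_float32

from model import UNetD

model = UNetD(3)
with open("checkpoint.pkl", "rb") as f:  # placeholder path to the released checkpoint
    state = pickle.load(f)
model.load_state_dict(state["state_dict"])
model.eval()

# HWC BGR uint8 -> CHW RGB float32 in [0, 1], plus a leading batch axis
img = cv2.cvtColor(cv2.imread("noisy.png"), cv2.COLOR_BGR2RGB)
x = megengine.tensor(img_as_float32(img).transpose(2, 0, 1)[None])

denoised = x - model(x)  # the network predicts the noise residual
out = np.clip(denoised.numpy()[0].transpose(1, 2, 0), 0.0, 1.0)
cv2.imwrite("denoised.png", cv2.cvtColor((out * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
```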