├── .gitignore
├── Fig
│   ├── .DS_Store
│   ├── fig1.png
│   └── fig2.png
├── LICENSE
├── README.md
├── dataset.py
├── model.py
├── prepare_data.py
├── test.py
├── train_mge.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Fig/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/NBNet/73112b185e022d0920f2f45c34c5bcf7c581d983/Fig/.DS_Store
--------------------------------------------------------------------------------
/Fig/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/NBNet/73112b185e022d0920f2f45c34c5bcf7c581d983/Fig/fig1.png
--------------------------------------------------------------------------------
/Fig/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/NBNet/73112b185e022d0920f2f45c34c5bcf7c581d983/Fig/fig2.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NBNet: Noise Basis Learning for Image Denoising with Subspace Projection
2 |
3 | Code for the CVPR 2021 paper [NBNet: Noise Basis Learning for Image Denoising with Subspace Projection](https://arxiv.org/abs/2012.15028).
4 |
5 | *The illustration of our key insight:*
6 |
7 | ![Key insight of NBNet](Fig/fig1.png)
8 |
9 | ## Dependencies
10 |
11 | - MegEngine >= 1.3.1 (required for distributed data-parallel training)
12 |
13 |
14 |
15 | ## Training
16 |
17 | ### Preparation
18 |
19 | ```
20 | python prepare_data.py --data_dir your_sidd_data_path --tar_dir your_output_path
21 | ```
22 |
23 |
24 |
25 | ### Begin training:
26 |
27 | For the SIDD benchmark, use:
28 |
29 | ```
30 | python train_mge.py -d prepared_data_path -n num_gpus
31 | ```
32 |
33 |
34 |
35 | For the DnD benchmark, we additionally use MixUp:
36 |
37 | ```
38 | python train_mge.py -d prepared_data_path -n num_gpus --dnd
39 | ```
40 |
41 | ### Begin testing:
42 | Download the pretrained checkpoint (see below) and run:
43 |
44 | ```
45 | python test.py -d prepared_data_path -c checkpoint_path
46 | ```
47 | The expected result is **PSNR 39.765** on the SIDD validation set.
48 |
49 |
50 |
51 | ## Pretrained model
52 |
53 | The MegEngine checkpoint for the SIDD benchmark can be downloaded via
54 | [Google Drive](https://drive.google.com/file/d/1RPAf9ZJqqq9ePPVTtJRlixX4-h3HJTCc/view?usp=sharing)
55 | or
56 | [GitHub Release](https://github.com/megvii-research/NBNet/releases).
57 |
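58 | ## Data layout
59 |
60 | The layout below is inferred from `dataset.py` and `prepare_data.py`: `SIDDData` reads
61 | training patches from `<prepared_data_path>/train/{input,groundtruth}` (so point
62 | `prepare_data.py --tar_dir` at the `train` subdirectory), and `SIDDValData` reads the
63 | SIDD validation blocks from `.mat` files placed directly in `<prepared_data_path>`:
64 |
65 | ```
66 | prepared_data_path/
67 | ├── train/
68 | │   ├── input/                        # noisy patches written by prepare_data.py
69 | │   └── groundtruth/                  # clean patches written by prepare_data.py
70 | ├── ValidationNoisyBlocksSrgb.mat     # from the official SIDD validation set
71 | └── ValidationGtBlocksSrgb.mat
72 | ```
73 |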
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import megengine as mge
4 | from megengine.data.dataset import Dataset
5 |
6 | import os
7 | import json
8 | from pathlib import Path
9 | from typing import Iterator, Sequence
10 | from tqdm import tqdm
11 |
12 | import cv2
13 | import numpy as np
14 | import pickle as pkl
15 | from skimage import img_as_float32 as img_as_float
16 | import random
17 | from scipy.io import loadmat
18 | def is_png_file(filename):
19 |     return filename.endswith(".png")
20 |
21 | def data_augmentation(image, mode):
22 | """
23 | Performs data augmentation of the input image
24 | Input:
25 | image: a cv2 (OpenCV) image
26 | mode: int. Choice of transformation to apply to the image
27 | 0 - no transformation
28 | 1 - flip up and down
29 |             2 - rotate counterclockwise 90 degrees
30 | 3 - rotate 90 degree and flip up and down
31 | 4 - rotate 180 degree
32 | 5 - rotate 180 degree and flip
33 | 6 - rotate 270 degree
34 | 7 - rotate 270 degree and flip
35 | """
36 | if mode == 0:
37 | # original
38 | out = image
39 | elif mode == 1:
40 | # flip up and down
41 | out = np.flipud(image)
42 | elif mode == 2:
43 |         # rotate counterclockwise 90 degrees
44 | out = np.rot90(image)
45 | elif mode == 3:
46 | # rotate 90 degree and flip up and down
47 | out = np.rot90(image)
48 | out = np.flipud(out)
49 | elif mode == 4:
50 | # rotate 180 degree
51 | out = np.rot90(image, k=2)
52 | elif mode == 5:
53 | # rotate 180 degree and flip
54 | out = np.rot90(image, k=2)
55 | out = np.flipud(out)
56 | elif mode == 6:
57 | # rotate 270 degree
58 | out = np.rot90(image, k=3)
59 | elif mode == 7:
60 | # rotate 270 degree and flip
61 | out = np.rot90(image, k=3)
62 | out = np.flipud(out)
63 | else:
64 | raise Exception('Invalid choice of image transformation')
65 |
66 | return out
67 |
68 | def random_augmentation(*args):
69 | out = []
70 | if random.randint(0,1) == 1:
71 | flag_aug = random.randint(1,7)
72 | for data in args:
73 | out.append(data_augmentation(data, flag_aug).copy())
74 | else:
75 | for data in args:
76 | out.append(data)
77 | return out
78 |
79 | def load_img(filepath):
80 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
81 | return img
82 |
83 | class SIDDData(Dataset):
84 | def __init__(self, path, length=None):
85 | '''
86 | Args:
87 |             path (str): root of the prepared SIDD data (expects train/input and train/groundtruth)
88 |             length (int): virtual dataset length (defaults to the number of image pairs)
89 | '''
90 | super(SIDDData, self).__init__()
91 | self.length = length
92 | clean_files = sorted(os.listdir(os.path.join(path, 'train', 'groundtruth')))
93 | noisy_files = sorted(os.listdir(os.path.join(path, 'train', 'input')))
94 |
95 | self.clean_filenames = [os.path.join(path, 'train', 'groundtruth', x) for x in clean_files if is_png_file(x)]
96 | self.noisy_filenames = [os.path.join(path, 'train', 'input', x) for x in noisy_files if is_png_file(x)]
97 | self.pch_size = 128
98 | def __len__(self):
99 |         if self.length is None:
100 |             # default to the actual number of image pairs on disk
101 |             return len(self.clean_filenames)
102 |         return self.length
103 |
104 | def crop_patch(self, n_img, gt_img):
105 | H, W, C = n_img.shape
106 |         # randomly choose the top-left corner of an aligned patch pair
107 | ind_H = random.randint(0, H-self.pch_size)
108 | ind_W = random.randint(0, W-self.pch_size)
109 | im_noisy = n_img[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :]
110 | im_gt = gt_img[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :]
111 | return im_noisy, im_gt
112 |
113 | def __getitem__(self, index):
114 | index = index % len(self.clean_filenames)
115 | # cv2.setNumThreads(0)
116 | noisy_img = load_img(self.noisy_filenames[index])
117 | gt_img = load_img(self.clean_filenames[index])
118 |
119 | # noisy_img = np.ascontiguousarray(noisy_img, dtype=np.float32)
120 | # gt_img = np.ascontiguousarray(gt_img, dtype=np.float32)
121 |
122 | noisy_img, gt_img = self.crop_patch(noisy_img, gt_img)
123 | gt_img = img_as_float(gt_img)
124 | noisy_img = img_as_float(noisy_img)
125 | noisy_img, gt_img = random_augmentation(noisy_img, gt_img)
126 |
127 | gt_img = gt_img.transpose((2, 0, 1))
128 | noisy_img = noisy_img.transpose((2, 0, 1))
129 | return noisy_img, gt_img
130 |
131 |
132 | class SIDDValData(Dataset):
133 | def __init__(self, path):
134 |
135 | val_data_dict = loadmat(os.path.join(path, 'ValidationNoisyBlocksSrgb.mat'))
136 | val_data_noisy = val_data_dict['ValidationNoisyBlocksSrgb']
137 | val_data_dict = loadmat(os.path.join(path,'ValidationGtBlocksSrgb.mat'))
138 | val_data_gt = val_data_dict['ValidationGtBlocksSrgb']
139 | self.num_img, self.num_block, h_, w_, c_ = val_data_gt.shape
140 | self.val_data_noisy = np.reshape(val_data_noisy, (-1, h_, w_, c_))
141 | self.val_data_gt = np.reshape(val_data_gt, (-1, h_, w_, c_))
142 |
143 |
144 | def __len__(self):
145 | return self.num_img*self.num_block
146 |
147 | def __getitem__(self, index):
148 |
149 | noisy_img, gt_img = self.val_data_noisy[index], self.val_data_gt[index]
150 | gt_img = img_as_float(gt_img)
151 | noisy_img = img_as_float(noisy_img)
152 | gt_img = gt_img.transpose((2, 0, 1))
153 | noisy_img = noisy_img.transpose((2, 0, 1))
154 | return noisy_img, gt_img
155 |
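156 | if __name__ == "__main__":
157 |     # Minimal smoke test, a sketch: it assumes the SIDD validation .mat
158 |     # files live under /data/sidd (the default path used by test.py).
159 |     ds = SIDDValData("/data/sidd")
160 |     noisy, gt = ds[0]
161 |     print(len(ds), noisy.shape, gt.shape)
162 |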
163 | # vim: ts=4 sw=4 sts=4 expandtab
164 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import megengine as mge
3 | import megengine.module as nn
4 | import megengine.functional as F
5 |
6 | def conv3x3(in_chn, out_chn, bias=True):
7 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
8 | return layer
9 |
10 |
11 | def conv_down(in_chn, out_chn, bias=False):
12 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
13 | return layer
14 |
15 |
16 | class UNetD(nn.Module):
17 |
18 | def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.2, subspace_dim=16):
19 | super(UNetD, self).__init__()
20 | self.depth = depth
21 | self.down_path = []
22 | prev_channels = self.get_input_chn(in_chn)
23 | for i in range(depth):
24 |             downsample = (i + 1) < depth  # downsample on all but the last level
25 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, downsample, relu_slope))
26 | prev_channels = (2**i) * wf
27 |
28 | # self.ema = EMAU(prev_channels, prev_channels//8)
29 | self.up_path = []
30 | subnet_repeat_num = 1
31 | for i in reversed(range(depth - 1)):
32 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope, subnet_repeat_num, subspace_dim))
33 | prev_channels = (2**i)*wf
34 | subnet_repeat_num += 1
35 |
36 | self.last = conv3x3(prev_channels, in_chn, bias=True)
37 | #self._initialize()
38 |
39 | def forward(self, x1):
40 | blocks = []
41 | for i, down in enumerate(self.down_path):
42 | # print(x1.shape)
43 | if (i+1) < self.depth:
44 | x1, x1_up = down(x1)
45 | blocks.append(x1_up)
46 | else:
47 | x1 = down(x1)
48 | # print(x1.shape)
49 | # x1 = self.ema(x1)
50 | for i, up in enumerate(self.up_path):
51 | # print(x1.shape, blocks[-i-1].shape)
52 | x1 = up(x1, blocks[-i-1])
53 |
54 | pred = self.last(x1)
55 | return pred
56 |
57 | def get_input_chn(self, in_chn):
58 | return in_chn
59 |
60 | def _initialize(self):
61 | gain = nn.init.calculate_gain('leaky_relu', 0.20)
62 | for m in self.modules():
63 | if isinstance(m, nn.Conv2d):
64 | print("weight")
65 | nn.init.xavier_uniform_(m.weight)
66 | if m.bias is not None:
67 | print("bias")
68 | nn.init.zeros_(m.bias)
69 |
70 |
71 | class UNetConvBlock(nn.Module):
72 |
73 | def __init__(self, in_size, out_size, downsample, relu_slope):
74 | super(UNetConvBlock, self).__init__()
75 | self.block = nn.Sequential(
76 | nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True),
77 | nn.LeakyReLU(relu_slope),
78 | nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True),
79 | nn.LeakyReLU(relu_slope))
80 |
81 | self.downsample = downsample
82 | if downsample:
83 | self.downsample = conv_down(out_size, out_size, bias=False)
84 |
85 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
86 |
87 | def forward(self, x):
88 | out = self.block(x)
89 | sc = self.shortcut(x)
90 | out = out + sc
91 | if self.downsample:
92 | out_down = self.downsample(out)
93 | return out_down, out
94 | else:
95 | return out
96 |
97 |
98 | class UNetUpBlock(nn.Module):
99 |
100 | def __init__(self, in_size, out_size, relu_slope, subnet_repeat_num, subspace_dim=16):
101 | super(UNetUpBlock, self).__init__()
102 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
103 | self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)
104 | self.num_subspace = subspace_dim
105 | print(self.num_subspace, subnet_repeat_num)
106 | self.subnet = Subspace(in_size, self.num_subspace)
107 | self.skip_m = skip_blocks(out_size, out_size, subnet_repeat_num)
108 |
109 |     def forward(self, x, bridge):
110 |         up = self.up(x)
111 |         bridge = self.skip_m(bridge)
112 |         out = F.concat([up, bridge], 1)
113 |         if self.subnet:
114 |             b_, c_, h_, w_ = bridge.shape
115 |             sub = self.subnet(out)  # predict num_subspace basis feature maps
116 |             V_t = sub.reshape(b_, self.num_subspace, h_*w_)  # basis matrix V^T: K x (h*w)
117 |             V_t = V_t / (1e-6 + F.abs(V_t).sum(axis=2, keepdims=True))  # L1-normalize each basis vector
118 |             V = V_t.transpose(0, 2, 1)
119 |             mat = F.matmul(V_t, V)  # Gram matrix V^T V: K x K
120 |             mat_inv = F.matinv(mat)
121 |             project_mat = F.matmul(mat_inv, V_t)  # (V^T V)^-1 V^T
122 |             bridge_ = bridge.reshape(b_, c_, h_*w_)
123 |             project_feature = F.matmul(project_mat, bridge_.transpose(0, 2, 1))  # subspace coordinates of the skip features
124 |             bridge = F.matmul(V, project_feature).transpose(0, 2, 1).reshape(b_, c_, h_, w_)  # project bridge onto span(V)
125 |             out = F.concat([up, bridge], 1)
126 |         out = self.conv_block(out)
127 |         return out
128 |
129 |
130 | class Subspace(nn.Module):
131 |
132 | def __init__(self, in_size, out_size):
133 | super(Subspace, self).__init__()
134 | self.blocks = []
135 | self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2))
136 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
137 |
138 | def forward(self, x):
139 | sc = self.shortcut(x)
140 | for i in range(len(self.blocks)):
141 | x = self.blocks[i](x)
142 | return x + sc
143 |
144 |
145 | class skip_blocks(nn.Module):
146 |
147 | def __init__(self, in_size, out_size, repeat_num=1):
148 | super(skip_blocks, self).__init__()
149 | self.blocks = []
150 | self.re_num = repeat_num
151 | mid_c = 128
152 | self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2))
153 | for i in range(self.re_num - 2):
154 | self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2))
155 | self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2))
156 | self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)
157 |
158 | def forward(self, x):
159 | sc = self.shortcut(x)
160 | for m in self.blocks:
161 | x = m(x)
162 | return x + sc
163 |
164 |
165 | if __name__ == "__main__":
166 | import numpy as np
167 | a = UNetD(3)
168 |
169 | #print(a)
170 | im = mge.tensor(np.random.randn(1, 3, 128, 128).astype(np.float32))
171 | print(a(im))
172 |
173 |
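174 |     # NBNet predicts the noise residual, so the denoised image is the input
175 |     # minus the network output (this matches train_mge.py and test.py).
176 |     denoised = im - a(im)
177 |     print(denoised.shape)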
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import cv2
4 | import numpy as np
5 | import h5py as h5
6 | import argparse
7 | from joblib import Parallel, delayed
8 | import multiprocessing
9 | from tqdm import tqdm
10 | parser = argparse.ArgumentParser(prog='SIDD Train dataset Generation')
11 | # The orignal SIDD images: /ssd1t/SIDD/
12 | parser.add_argument('--data_dir', default='/data/sidd/SIDD_Medium_Srgb/Data', type=str, metavar='PATH',
13 |                     help="path to the original SIDD Medium sRGB images")
14 | parser.add_argument('--tar_dir', default='/data/sidd/train',type=str, help='Directory for image patches')
15 | args = parser.parse_args()
16 | tar = args.tar_dir
17 |
18 | noisy_patchDir = os.path.join(tar, 'input')
19 | clean_patchDir = os.path.join(tar, 'groundtruth')
20 |
21 | if os.path.exists(tar):
22 | os.system("rm -r {}".format(tar))
23 |
24 | os.makedirs(noisy_patchDir)
25 | os.makedirs(clean_patchDir)
26 | path_all_noisy = glob(os.path.join(args.data_dir, '**/*NOISY*.PNG'), recursive=True)
27 | path_all_noisy = sorted(path_all_noisy)
28 | path_all_gt = [x.replace('NOISY', 'GT') for x in path_all_noisy]
29 | print('Number of big images: {:d}'.format(len(path_all_gt)))
30 |
31 | print('Training: Splitting the original images into small patches!')
32 | path_h5 = os.path.join(args.data_dir, 'small_imgs_train.hdf5')
33 | if os.path.exists(path_h5):
34 | os.remove(path_h5)
35 | pch_size = 512
36 | stride = 512-128
37 | num_patch = 0
38 | C = 3
39 |
40 | def save_files(ii):
41 | im_noisy_int8 = cv2.imread(path_all_noisy[ii])
42 | H, W, _ = im_noisy_int8.shape
43 | im_gt_int8 = cv2.imread(path_all_gt[ii])
44 | ind_H = list(range(0, H-pch_size+1, stride))
45 | if ind_H[-1] < H-pch_size:
46 | ind_H.append(H-pch_size)
47 | ind_W = list(range(0, W-pch_size+1, stride))
48 | if ind_W[-1] < W-pch_size:
49 | ind_W.append(W-pch_size)
50 | count = 1
51 | for start_H in ind_H:
52 | for start_W in ind_W:
53 | pch_noisy = im_noisy_int8[start_H:start_H+pch_size, start_W:start_W+pch_size, ]
54 | pch_gt = im_gt_int8[start_H:start_H+pch_size, start_W:start_W+pch_size, ]
55 | cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(ii+1,count+1)), pch_noisy)
56 | cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(ii+1,count+1)), pch_gt)
57 | count += 1
58 | Parallel(n_jobs=10)(delayed(save_files)(i) for i in tqdm(range(len(path_all_gt))))
59 | print('Finished!\n')
60 |
61 |
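62 | # Example usage (illustrative paths matching the argparse defaults above):
63 | #   python prepare_data.py --data_dir /data/sidd/SIDD_Medium_Srgb/Data --tar_dir /data/sidd/train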
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from dataset import SIDDValData
3 | from model import UNetD
4 | import megengine.data as data
5 | from utils import batch_PSNR
6 | from tqdm import tqdm
7 | import argparse
8 | import pickle
9 | import megengine
10 |
11 |
12 | def test(args):
13 | valid_dataset = SIDDValData(args.data)
14 | valid_sampler = data.SequentialSampler(
15 | valid_dataset, batch_size=1, drop_last=False
16 | )
17 | valid_dataloader = data.DataLoader(
18 | valid_dataset,
19 | sampler=valid_sampler,
20 | num_workers=8,
21 | )
22 | model = UNetD(3)
23 | with open(args.checkpoint, "rb") as f:
24 | state = pickle.load(f)
25 | model.load_state_dict(state["state_dict"])
26 | model.eval()
27 |
28 | def valid_step(image, label):
29 | pred = model(image)
30 | pred = image - pred
31 | psnr_it = batch_PSNR(pred, label)
32 | return psnr_it
33 |
34 | def valid(func, data_queue):
35 | psnr_v = 0.
36 | for step, (image, label) in tqdm(enumerate(data_queue)):
37 | image = megengine.tensor(image)
38 | label = megengine.tensor(label)
39 | psnr_it = func(image, label)
40 | psnr_v += psnr_it
41 | psnr_v /= step + 1
42 | return psnr_v
43 |
44 | psnr_v = valid(valid_step, valid_dataloader)
45 |     print("PSNR: {:.3f}".format(psnr_v.item()))
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser(description="MegEngine NBNet")
49 |     parser.add_argument("-d", "--data", default="/data/sidd", metavar="DIR", help="path to the prepared SIDD dataset")
50 | parser.add_argument("-c", "--checkpoint", help="path to checkpoint")
51 | args = parser.parse_args()
52 | test(args)
53 |
54 |
55 |
56 | # vim: ts=4 sw=4 sts=4 expandtab
57 |
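58 | # Example usage (see README; paths are placeholders):
59 | #   python test.py -d prepared_data_path -c checkpoint_path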
--------------------------------------------------------------------------------
/train_mge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import bisect
4 | import multiprocessing
5 | import os
6 | import time
7 | import numpy as np
8 | # pylint: disable=import-error
9 | from model import UNetD
10 |
11 | import megengine
12 | import megengine.autodiff as autodiff
13 | import megengine.data as data
14 | import megengine.data.transform as T
15 | import megengine.distributed as dist
16 | import megengine.functional as F
17 | import megengine.optimizer as optim
18 |
19 | from dataset import SIDDData, SIDDValData
20 | from utils import batch_PSNR, MixUp_AUG
21 | logging = megengine.logger.get_logger()
22 |
23 |
24 | def main():
25 | parser = argparse.ArgumentParser(description="MegEngine NBNet")
26 | parser.add_argument("-d", "--data", default="/data/sidd", metavar="DIR", help="path to sidd dataset")
27 | parser.add_argument("--dnd", action='store_true', help="training for dnd benchmark")
28 | parser.add_argument(
29 | "-a",
30 | "--arch",
31 | default="NBNet",
32 | )
33 | parser.add_argument(
34 | "-n",
35 | "--ngpus",
36 | default=None,
37 | type=int,
38 | help="number of GPUs per node (default: None, use all available GPUs)",
39 | )
40 | parser.add_argument(
41 | "--save",
42 | metavar="DIR",
43 | default="output",
44 | help="path to save checkpoint and log",
45 | )
46 | parser.add_argument(
47 | "--epochs",
48 | default=70,
49 | type=int,
50 | help="number of total epochs to run (default: 70)",
51 | )
52 |
53 | parser.add_argument(
54 | "--steps_per_epoch",
55 | default=10000,
56 | type=int,
57 | help="number of steps for one epoch (default: 10000)",
58 | )
59 |
60 | parser.add_argument(
61 | "-b",
62 | "--batch-size",
63 | metavar="SIZE",
64 | default=32,
65 | type=int,
66 | help="total batch size (default: 32)",
67 | )
68 | parser.add_argument(
69 | "--lr",
70 | "--learning-rate",
71 | metavar="LR",
72 | default=2e-4,
73 | type=float,
74 | help="learning rate for single GPU (default: 0.0002)",
75 | )
76 |
77 | parser.add_argument(
78 | "--weight-decay", default=1e-8, type=float, help="weight decay"
79 | )
80 |
81 | parser.add_argument("-j", "--workers", default=8, type=int)
82 |
83 | parser.add_argument(
84 | "-p",
85 | "--print-freq",
86 | default=10,
87 | type=int,
88 | metavar="N",
89 | help="print frequency (default: 10)",
90 | )
91 |
92 |
93 |
94 | args = parser.parse_args()
95 | # pylint: disable=unused-variable # noqa: F841
96 |
97 |     # get device count; default to all available GPUs (build_dataset also
98 |     # relies on args.ngpus being set, so fill it in here)
99 |     if args.ngpus is None:
100 |         args.ngpus = megengine.device.get_device_count("gpu")
101 |     ngpus_per_node = args.ngpus
102 |     train_proc = dist.launcher(worker) if ngpus_per_node > 1 else worker
103 | train_proc(args)
104 |
105 | def worker(args):
106 | # pylint: disable=too-many-statements
107 | rank = dist.get_rank()
108 | world_size = dist.get_world_size()
109 | if rank == 0:
110 | os.makedirs(os.path.join(args.save, args.arch), exist_ok=True)
111 | megengine.logger.set_log_file(os.path.join(args.save, args.arch, "log.txt"))
112 | # init process group
113 |
114 | # build dataset
115 | train_dataloader, valid_dataloader = build_dataset(args)
116 | train_queue = iter(train_dataloader) # infinite
117 | steps_per_epoch = args.steps_per_epoch
118 |
119 | # build model
120 | model = UNetD(3)
121 | # Sync parameters
122 | if world_size > 1:
123 | dist.bcast_list_(model.parameters(), dist.WORLD)
124 |
125 | # Autodiff gradient manager
126 | gm = autodiff.GradManager().attach(
127 | model.parameters(),
128 | callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None,
129 | )
130 |
131 | # Optimizer
132 | opt = optim.Adam(
133 | model.parameters(),
134 | lr=args.lr,
135 | weight_decay=args.weight_decay * world_size, # scale weight decay in "SUM" mode
136 | )
137 |
138 | # mixup
139 | def preprocess(image, label):
140 | if args.dnd:
141 | image, label = MixUp_AUG(image, label)
142 | return image, label
143 |
144 | # train and valid func
145 | def train_step(image, label):
146 | with gm:
147 | logits = model(image)
148 | logits = image - logits
149 | loss = F.nn.l1_loss(logits, label)
150 | gm.backward(loss)
151 | opt.step().clear_grad()
152 | return loss
153 |
154 | def valid_step(image, label):
155 | pred = model(image)
156 | pred = image - pred
157 | mae_iter = F.nn.l1_loss(pred, label)
158 | psnr_it = batch_PSNR(pred, label)
159 | #print(psnr_it.item())
160 | if world_size > 1:
161 | mae_iter = F.distributed.all_reduce_sum(mae_iter) / world_size
162 | psnr_it = F.distributed.all_reduce_sum(psnr_it) / world_size
163 |
164 | return mae_iter, psnr_it
165 |
166 |     # cosine annealing learning rate schedule (no warmup)
167 | def adjust_learning_rate(step):
168 | #lr = 1e-6 + 0.5 * (args.lr - 1e-6)*(1 + np.cos(step/(args.epochs*steps_per_epoch) * np.pi))
169 | lr = args.lr * (np.cos(step / (steps_per_epoch * args.epochs) * np.pi) + 1) / 2
170 | for param_group in opt.param_groups:
171 | param_group["lr"] = lr
172 | return lr
173 |
174 | # start training
175 | for step in range(0, int(args.epochs * steps_per_epoch)):
176 | #print(step)
177 | lr = adjust_learning_rate(step)
178 |
179 | t_step = time.time()
180 |
181 | image, label = next(train_queue)
182 |         if step > steps_per_epoch:  # apply MixUp only after the first epoch
183 | image, label = preprocess(image, label)
184 | image = megengine.tensor(image)
185 | label = megengine.tensor(label)
186 | t_data = time.time() - t_step
187 | loss = train_step(image, label)
188 | t_train = time.time() - t_step
189 | speed = 1. / t_train
190 | if step % args.print_freq == 0 and dist.get_rank() == 0:
191 | logging.info(
192 | "Epoch {} Step {}, Speed={:.2g} mb/s, dp_cost={:.2g}, Loss={:5.2e}, lr={:.2e}".format(
193 | step // int(steps_per_epoch),
194 | step,
195 | speed,
196 | t_data/t_train,
197 | loss.item(),
198 | lr
199 | ))
200 | #print(steps_per_epoch)
201 | if (step + 1) % steps_per_epoch == 0:
202 | model.eval()
203 | loss, psnr_v = valid(valid_step, valid_dataloader)
204 | model.train()
205 | logging.info(
206 | "Epoch {} Test mae {:.3f}, psnr {:.3f}".format(
207 | (step + 1) // steps_per_epoch,
208 | loss.item(),
209 | psnr_v.item(),
210 | ))
211 |             if rank == 0:
212 |                 megengine.save(
213 |                     {
214 |                         "epoch": (step + 1) // steps_per_epoch,
215 |                         "state_dict": model.state_dict(),
216 |                     },
217 |                     os.path.join(args.save, args.arch, "checkpoint.pkl"))
218 |
219 | def valid(func, data_queue):
220 | loss = 0.
221 | psnr_v = 0.
222 | for step, (image, label) in enumerate(data_queue):
223 | image = megengine.tensor(image)
224 | label = megengine.tensor(label)
225 | mae_iter, psnr_it = func(image, label)
226 | loss += mae_iter
227 | psnr_v += psnr_it
228 | loss /= step + 1
229 | psnr_v /= step + 1
230 | return loss, psnr_v
231 |
232 |
233 | def build_dataset(args):
234 |     assert args.batch_size // args.ngpus > 0 and 4 // args.ngpus > 0, "need at least one sample per GPU in train and valid batches"
235 | train_dataset = SIDDData(args.data, length=args.batch_size*args.steps_per_epoch)
236 | train_sampler = data.Infinite(
237 | data.RandomSampler(train_dataset, batch_size=args.batch_size//args.ngpus, drop_last=True)
238 | )
239 | train_dataloader = data.DataLoader(
240 | train_dataset,
241 | sampler=train_sampler,
242 | num_workers=args.workers,
243 | )
244 | valid_dataset = SIDDValData(args.data)
245 | valid_sampler = data.SequentialSampler(
246 | valid_dataset, batch_size=4//args.ngpus, drop_last=False
247 | )
248 | valid_dataloader = data.DataLoader(
249 | valid_dataset,
250 | sampler=valid_sampler,
251 | num_workers=args.workers,
252 | )
253 | return train_dataloader, valid_dataloader
254 |
255 |
256 |
257 | if __name__ == "__main__":
258 | main()
259 |
260 | # vim: ts=4 sw=4 sts=4 expandtab
261 |
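262 | # Example usage (see README; paths are placeholders):
263 | #   python train_mge.py -d prepared_data_path -n num_gpus [--dnd]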
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Power by Zongsheng Yue 2019-01-22 22:07:08
4 |
5 | import math
6 | import megengine as mge
7 | import megengine.functional as F
8 | from skimage import img_as_ubyte
9 | import numpy as np
10 | import cv2
11 |
12 |
13 | def get_gausskernel(p, chn=3):
14 | '''
15 | Build a 2-dimensional Gaussian filter with size p
16 | '''
17 | x = cv2.getGaussianKernel(p, sigma=-1) # p x 1
18 |     y = np.matmul(x, x.T)[np.newaxis, np.newaxis,]  # 1 x 1 x p x p
19 | out = np.tile(y, (chn, 1, 1, 1)) # chn x 1 x p x p
20 |
21 |     return mge.Tensor(out.astype(np.float32))
22 |
23 | def gaussblur(x, kernel, p=5, chn=3):
24 | x_pad = F.pad(x, pad=[int((p-1)/2),]*4, mode='reflect')
25 | y = F.conv2d(x_pad, kernel, padding=0, stride=1, groups=chn)
26 |
27 | return y
28 |
29 |
30 | def calculate_psnr(im1, im2, border=0):
31 | if not im1.shape == im2.shape:
32 | raise ValueError('Input images must have the same dimensions.')
33 | h, w = im1.shape[:2]
34 | im1 = im1[border:h-border, border:w-border]
35 | im2 = im2[border:h-border, border:w-border]
36 |
37 | mse = F.mean((im1 - im2)**2)
38 | if mse == 0:
39 | return float('inf')
40 | return 10 * F.log(1.0 / mse) / F.log(10.)
41 |
42 | def batch_PSNR(img, imclean, border=0):
43 |     # average PSNR over the batch (leading) dimension
44 |     PSNR = 0
45 |     for i in range(img.shape[0]):
46 |         PSNR += calculate_psnr(imclean[i], img[i], border)
47 |     return PSNR / img.shape[0]
48 |
49 |
50 |
51 |
52 |
53 | def MixUp_AUG(rgb_gt, rgb_noisy):  # symmetric: the same lam and permutation mix both inputs
54 | bs = rgb_gt.shape[0]
55 | indices = np.arange(bs)
56 | np.random.shuffle(indices)
57 | rgb_gt2 = rgb_gt[indices]
58 | rgb_noisy2 = rgb_noisy[indices]
59 |
60 | lam = np.random.beta(1.2, 1.2, (bs, 1, 1, 1))
61 |
62 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
63 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
64 |
65 | return rgb_gt, rgb_noisy
66 |
67 |
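68 | if __name__ == "__main__":
69 |     # Minimal sanity check on random data, a sketch: real inputs are
70 |     # NCHW float arrays in [0, 1], as produced by dataset.py.
71 |     a = np.random.rand(2, 3, 16, 16).astype(np.float32)
72 |     b = np.clip(a + 0.01, 0.0, 1.0)
73 |     print("PSNR:", batch_PSNR(mge.tensor(a), mge.tensor(b)).item())
74 |     gt, noisy = MixUp_AUG(a, b)
75 |     print("MixUp shapes:", gt.shape, noisy.shape)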
--------------------------------------------------------------------------------