├── .gitignore ├── LICENSE ├── README.md ├── ffdnet.py ├── model.py ├── models └── .gitkeep ├── test_data ├── color.png └── gray.jpg ├── test_run.sh ├── train_data ├── gray │ ├── train │ │ └── .gitkeep │ └── val │ │ └── .gitkeep └── rgb │ ├── train │ └── .gitkeep │ └── val │ └── .gitkeep └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | models/*.* 4 | train_data/gray/train/*.* 5 | train_data/gray/val/*.* 6 | train_data/rgb/train/*.* 7 | train_data/rgb/val/*.* 8 | !.gitkeep 9 | 10 | *.zip 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 青いほしぞら 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FFDNet_pytorch 2 | + A PyTorch implementation of a denoising network called [FFDNet](https://github.com/cszn/FFDNet) 3 | + Paper: FFDNet: Toward a Fast and Flexible Solution for CNN based Image Denoising - [arxiv](https://arxiv.org/abs/1710.04026) / [IEEE](https://ieeexplore.ieee.org/abstract/document/8365806/) 4 | 5 | ### Dataset 6 | 7 | + [Waterloo Exploration Database](https://ece.uwaterloo.ca/~k29ma/exploration/) 8 | 9 | ### Usage 10 | 11 | + Train 12 | 13 | ```bash 14 | python3 ffdnet.py \ 15 | --use_gpu \ 16 | --is_train \ 17 | --train_path './train_data/' \ 18 | --model_path './models/' \ 19 | --batch_size 768 \ 20 | --epoches 80 \ 21 | --val_epoch 5 22 | --patch_size 32 \ 23 | --save_checkpoints 20 \ 24 | --train_noise_interval 15 75 15 \ 25 | --val_noise_interval 30 60 30 \ 26 | ``` 27 | 28 | + Test 29 | 30 | ```bash 31 | python3 ffdnet.py \ 32 | --use_gpu \ 33 | --is_test \ 34 | --test_path './test_data/color.png' \ 35 | --model_path './models/' \ 36 | --add_noise 37 | --noise_sigma 30 38 | ``` 39 | 40 | ### References 41 | 42 | + Some codes are copied from [An Analysis and Implementation of the FFDNet Image Denoising Method](http://www.ipol.im/pub/pre/231/) -------------------------------------------------------------------------------- /ffdnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import cv2 4 | import os 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | from torch.utils.data import DataLoader 13 | 14 | from model import FFDNet 15 | import utils 16 | 17 | def read_image(image_path, is_gray): 18 | """ 19 | :return: Normalized Image (C * W * H) 20 | """ 21 | if is_gray: 22 | image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 23 | image = np.expand_dims(image.T, 0) # 1 * W * H 24 | else: 25 | image = cv2.imread(image_path) 26 | image = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).transpose(2, 1, 0) # 3 * W * H 27 | 28 | return utils.normalize(image) 29 | 30 | def load_images(is_train, is_gray, base_path): 31 | """ 32 | :param base_path: ./train_data/ 33 | :return: List[Patches] (C * W * H) 34 | """ 35 | if is_gray: 36 | train_dir = 'gray/train/' 37 | val_dir = 'gray/val/' 38 | else: 39 | train_dir = 'rgb/train/' 40 | val_dir = 'rgb/val/' 41 | 42 | image_dir = base_path.replace('\'', '').replace('"', '') + (train_dir if is_train else val_dir) 43 | print('> Loading images in ' + image_dir) 44 | images = [] 45 | for fn in next(os.walk(image_dir))[2]: 46 | image = read_image(image_dir + fn, is_gray) 47 | images.append(image) 48 | return images 49 | 50 | def images_to_patches(images, patch_size): 51 | """ 52 | :param images: List[Image (C * W * H)] 53 | :param patch_size: int 54 | :return: (n * C * W * H) 55 | """ 56 | patches_list = [] 57 | for image in images: 58 | patches = utils.image_to_patches(image, patch_size=patch_size) 59 | if len(patches) != 0: 60 | patches_list.append(patches) 61 | del images 62 | return np.vstack(patches_list) 63 | 64 | def train(args): 65 | print('> Loading dataset...') 66 | # Images 67 | train_dataset = load_images(is_train=True, is_gray=args.is_gray, base_path=args.train_path) 68 | val_dataset = load_images(is_train=False, is_gray=args.is_gray, base_path=args.train_path) 69 | print(f'\tTrain image datasets: {len(train_dataset)}') 70 | print(f'\tVal image datasets: {len(val_dataset)}') 71 | 72 | # Patches 73 | train_dataset = images_to_patches(train_dataset, patch_size=args.patch_size) 74 | val_dataset = images_to_patches(val_dataset, patch_size=args.patch_size) 75 | print(f'\tTrain patch datasets: {train_dataset.shape}') 76 | print(f'\tVal patch datasets: {val_dataset.shape}') 77 | 78 | # DataLoader 79 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6) 80 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=6) 81 | print(f'\tTrain batch number: {len(train_dataloader)}') 82 | print(f'\tVal batch number: {len(val_dataloader)}') 83 | 84 | # Noise list 85 | train_noises = args.train_noise_interval # [0, 75, 15] 86 | val_noises = args.val_noise_interval # [0, 60, 30] 87 | train_noises = list(range(train_noises[0], train_noises[1], train_noises[2])) 88 | val_noises = list(range(val_noises[0], val_noises[1], val_noises[2])) 89 | print(f'\tTrain noise internal: {train_noises}') 90 | print(f'\tVal noise internal: {val_noises}') 91 | print('\n') 92 | 93 | # Model & Optim 94 | model = FFDNet(is_gray=args.is_gray) 95 | model.apply(utils.weights_init_kaiming) 96 | if args.cuda: 97 | model = model.cuda() 98 | loss_fn = nn.MSELoss(reduction='sum') 99 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 100 | 101 | print('> Start training...') 102 | for epoch_idx in range(args.epoches): 103 | # Train 104 | loss_idx = 0 105 | train_losses = 0 106 | model.train() 107 | 108 | start_time = time.time() 109 | for batch_idx, batch_data in enumerate(train_dataloader): 110 | # According to internal, add noise 111 | for int_noise_sigma in train_noises: 112 | noise_sigma = int_noise_sigma / 255 113 | new_images = utils.add_batch_noise(batch_data, noise_sigma) 114 | noise_sigma = torch.FloatTensor(np.array([noise_sigma for idx in range(new_images.shape[0])])) 115 | new_images = Variable(new_images) 116 | noise_sigma = Variable(noise_sigma) 117 | if args.cuda: 118 | new_images = new_images.cuda() 119 | noise_sigma = noise_sigma.cuda() 120 | 121 | # Predict 122 | images_pred = model(new_images, noise_sigma) 123 | train_loss = loss_fn(images_pred, batch_data.to(images_pred.device)) 124 | train_losses += train_loss 125 | loss_idx += 1 126 | 127 | optimizer.zero_grad() 128 | train_loss.backward() 129 | optimizer.step() 130 | 131 | # Log Progress 132 | stop_time = time.time() 133 | all_num = len(train_dataloader) * len(train_noises) 134 | done_num = batch_idx * len(train_noises) + train_noises.index(int_noise_sigma) + 1 135 | rest_time = int((stop_time - start_time) / done_num * (all_num - done_num)) 136 | percent = int(done_num / all_num * 100) 137 | print(f'\rEpoch: {epoch_idx + 1} / {args.epoches}, ' + 138 | f'Batch: {batch_idx + 1} / {len(train_dataloader)}, ' + 139 | f'Noise_Sigma: {int_noise_sigma} / {train_noises[-1]}, ' + 140 | f'Train_Loss: {train_loss}, ' + 141 | f'=> {rest_time}s, {percent}%', end='') 142 | 143 | train_losses /= loss_idx 144 | print(f', Avg_Train_Loss: {train_losses}, All: {int(stop_time - start_time)}s') 145 | 146 | # Evaluate 147 | loss_idx = 0 148 | val_losses = 0 149 | if (epoch_idx + 1) % args.val_epoch != 0: 150 | continue 151 | model.eval() 152 | 153 | start_time = time.time() 154 | for batch_idx, batch_data in enumerate(val_dataloader): 155 | # According to internal, add noise 156 | for int_noise_sigma in val_noises: 157 | noise_sigma = int_noise_sigma / 255 158 | new_images = utils.add_batch_noise(batch_data, noise_sigma) 159 | noise_sigma = torch.FloatTensor(np.array([noise_sigma for idx in range(new_images.shape[0])])) 160 | new_images = Variable(new_images) 161 | noise_sigma = Variable(noise_sigma) 162 | if args.cuda: 163 | new_images = new_images.cuda() 164 | noise_sigma = noise_sigma.cuda() 165 | 166 | # Predict 167 | images_pred = model(new_images, noise_sigma) 168 | val_loss = loss_fn(images_pred, batch_data.to(images_pred.device)) 169 | val_losses += val_loss 170 | loss_idx += 1 171 | 172 | # Log Progress 173 | stop_time = time.time() 174 | all_num = len(val_dataloader) * len(val_noises) 175 | done_num = batch_idx * len(val_noises) + val_noises.index(int_noise_sigma) + 1 176 | rest_time = int((stop_time - start_time) / done_num * (all_num - done_num)) 177 | percent = int(done_num / all_num * 100) 178 | print(f'\rEpoch: {epoch_idx + 1} / {args.epoches}, ' + 179 | f'Batch: {batch_idx + 1} / {len(val_dataloader)}, ' + 180 | f'Noise_Sigma: {int_noise_sigma} / {val_noises[-1]}, ' + 181 | f'Val_Loss: {val_loss}, ' + 182 | f'=> {rest_time}s, {percent}%', end='') 183 | 184 | val_losses /= loss_idx 185 | print(f', Avg_Val_Loss: {val_losses}, All: {int(stop_time - start_time)}s') 186 | 187 | # Save Checkpoint 188 | if (epoch_idx + 1) % args.save_checkpoints == 0: 189 | model_path = args.model_path + ('net_gray_checkpoint.pth' if args.is_gray else 'net_rgb_checkpoint.pth') 190 | torch.save(model.state_dict(), model_path) 191 | print(f'| Saved Checkpoint at Epoch {epoch_idx + 1} to {model_path}') 192 | 193 | # Final Save Model Dict 194 | model.eval() 195 | model_path = args.model_path + ('net_gray.pth' if args.is_gray else 'net_rgb.pth') 196 | torch.save(model.state_dict(), model_path) 197 | print(f'Saved State Dict in {model_path}') 198 | print('\n') 199 | 200 | def test(args): 201 | # Image 202 | image = cv2.imread(args.test_path) 203 | if image is None: 204 | raise Exception(f'File {args.test_path} not found or error') 205 | is_gray = utils.is_image_gray(image) 206 | image = read_image(args.test_path, is_gray) 207 | print("{} image shape: {}".format("Gray" if is_gray else "RGB", image.shape)) 208 | 209 | # Expand odd shape to even 210 | expend_W = False 211 | expend_H = False 212 | if image.shape[1] % 2 != 0: 213 | expend_W = True 214 | image = np.concatenate((image, image[:, -1, :][:, np.newaxis, :]), axis=1) 215 | if image.shape[2] % 2 != 0: 216 | expend_H = True 217 | image = np.concatenate((image, image[:, :, -1][:, :, np.newaxis]), axis=2) 218 | 219 | # Noise 220 | image = torch.FloatTensor([image]) # 1 * C(1 / 3) * W * H 221 | if args.add_noise: 222 | image = utils.add_batch_noise(image, args.noise_sigma) 223 | noise_sigma = torch.FloatTensor([args.noise_sigma]) 224 | 225 | # Model & GPU 226 | model = FFDNet(is_gray=is_gray) 227 | if args.cuda: 228 | image = image.cuda() 229 | noise_sigma = noise_sigma.cuda() 230 | model = model.cuda() 231 | 232 | # Dict 233 | model_path = args.model_path + ('net_gray.pth' if is_gray else 'net_rgb.pth') 234 | print(f"> Loading model param in {model_path}...") 235 | state_dict = torch.load(model_path) 236 | model.load_state_dict(state_dict) 237 | model.eval() 238 | print('\n') 239 | 240 | # Test 241 | with torch.no_grad(): 242 | start_time = time.time() 243 | image_pred = model(image, noise_sigma) 244 | stop_time = time.time() 245 | print("Test time: {0:.4f}s".format(stop_time - start_time)) 246 | 247 | # PSNR 248 | psnr = utils.batch_psnr(img=image_pred, imclean=image, data_range=1) 249 | print("PSNR denoised {0:.2f}dB".format(psnr)) 250 | 251 | # UnExpand odd 252 | if expend_W: 253 | image_pred = image_pred[:, :, :-1, :] 254 | if expend_H: 255 | image_pred = image_pred[:, :, :, :-1] 256 | 257 | # Save 258 | cv2.imwrite("ffdnet.png", utils.variable_to_cv2_image(image_pred)) 259 | if args.add_noise: 260 | cv2.imwrite("noisy.png", utils.variable_to_cv2_image(image)) 261 | 262 | def main(): 263 | parser = argparse.ArgumentParser() 264 | 265 | # Train 266 | parser.add_argument("--train_path", type=str, default='./train_data/', help='Train dataset dir.') 267 | parser.add_argument("--is_gray", action='store_true', help='Train gray/rgb model.') 268 | parser.add_argument("--patch_size", type=int, default=32, help='Uniform size of training images patches.') 269 | parser.add_argument("--train_noise_interval", nargs=3, type=int, default=[0, 75, 15], help='Train dataset noise sigma set interval.') 270 | parser.add_argument("--val_noise_interval", nargs=3, type=int, default=[0, 60, 30], help='Validation dataset noise sigma set interval.') 271 | parser.add_argument("--batch_size", type=int, default=256, help='Batch size for training.') 272 | parser.add_argument("--epoches", type=int, default=80, help='Total number of training epoches.') 273 | parser.add_argument("--val_epoch", type=int, default=5, help='Total number of validation epoches.') 274 | parser.add_argument("--learning_rate", type=float, default=1e-3, help='The initial learning rate for Adam.') 275 | parser.add_argument("--save_checkpoints", type=int, default=5, help='Save checkpoint every epoch.') 276 | 277 | # Test 278 | parser.add_argument("--test_path", type=str, default='./test_data/color.png', help='Test image path.') 279 | parser.add_argument("--noise_sigma", type=float, default=25, help='Input uniform noise sigma for test.') 280 | parser.add_argument('--add_noise', action='store_true', help='Add noise_sigma to input or not.') 281 | 282 | # Global 283 | parser.add_argument("--model_path", type=str, default='./models/', help='Model loading and saving path.') 284 | parser.add_argument("--use_gpu", action='store_true', help='Train and test using GPU.') 285 | parser.add_argument("--is_train", action='store_true', help='Do train.') 286 | parser.add_argument("--is_test", action='store_true', help='Do test.') 287 | 288 | args = parser.parse_args() 289 | assert (args.is_train or args.is_test), 'is_train 和 is_test 至少有一个为 True' 290 | 291 | args.cuda = args.use_gpu and torch.cuda.is_available() 292 | print("> Parameters: ") 293 | for k, v in zip(args.__dict__.keys(), args.__dict__.values()): 294 | print(f'\t{k}: {v}') 295 | print('\n') 296 | 297 | # Normalize noise level 298 | args.noise_sigma /= 255 299 | args.train_noise_interval[1] += 1 300 | args.val_noise_interval[1] += 1 301 | 302 | if args.is_train: 303 | train(args) 304 | 305 | if args.is_test: 306 | test(args) 307 | 308 | if __name__ == "__main__": 309 | main() 310 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | 7 | import utils 8 | 9 | class FFDNet(nn.Module): 10 | 11 | def __init__(self, is_gray): 12 | super(FFDNet, self).__init__() 13 | 14 | if is_gray: 15 | self.num_conv_layers = 15 # all layers number 16 | self.downsampled_channels = 5 # Conv_Relu in 17 | self.num_feature_maps = 64 # Conv_Bn_Relu in 18 | self.output_features = 4 # Conv out 19 | else: 20 | self.num_conv_layers = 12 21 | self.downsampled_channels = 15 22 | self.num_feature_maps = 96 23 | self.output_features = 12 24 | 25 | self.kernel_size = 3 26 | self.padding = 1 27 | 28 | layers = [] 29 | # Conv + Relu 30 | layers.append(nn.Conv2d(in_channels=self.downsampled_channels, out_channels=self.num_feature_maps, \ 31 | kernel_size=self.kernel_size, padding=self.padding, bias=False)) 32 | layers.append(nn.ReLU(inplace=True)) 33 | 34 | # Conv + BN + Relu 35 | for _ in range(self.num_conv_layers - 2): 36 | layers.append(nn.Conv2d(in_channels=self.num_feature_maps, out_channels=self.num_feature_maps, \ 37 | kernel_size=self.kernel_size, padding=self.padding, bias=False)) 38 | layers.append(nn.BatchNorm2d(self.num_feature_maps)) 39 | layers.append(nn.ReLU(inplace=True)) 40 | 41 | # Conv 42 | layers.append(nn.Conv2d(in_channels=self.num_feature_maps, out_channels=self.output_features, \ 43 | kernel_size=self.kernel_size, padding=self.padding, bias=False)) 44 | 45 | self.intermediate_dncnn = nn.Sequential(*layers) 46 | 47 | def forward(self, x, noise_sigma): 48 | noise_map = noise_sigma.view(x.shape[0], 1, 1, 1).repeat(1, x.shape[1], x.shape[2] // 2, x.shape[3] // 2) 49 | 50 | x_up = utils.downsample(x.data) # 4 * C * H/2 * W/2 51 | x_cat = torch.cat((noise_map.data, x_up), 1) # 4 * (C + 1) * H/2 * W/2 52 | x_cat = Variable(x_cat) 53 | 54 | h_dncnn = self.intermediate_dncnn(x_cat) 55 | y_pred = utils.upsample(h_dncnn) 56 | return y_pred 57 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/models/.gitkeep -------------------------------------------------------------------------------- /test_data/color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/test_data/color.png -------------------------------------------------------------------------------- /test_data/gray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/test_data/gray.jpg -------------------------------------------------------------------------------- /test_run.sh: -------------------------------------------------------------------------------- 1 | python3 ffdnet.py \ 2 | --use_gpu \ 3 | --is_train \ 4 | --train_path './train_data/' \ 5 | --model_path './models/' \ 6 | --batch_size 1024 \ 7 | --epoches 150 \ 8 | --val_epoch 5 \ 9 | --patch_size 32 \ 10 | --save_checkpoints 20 \ 11 | --train_noise_interval 15 75 15 \ 12 | --val_noise_interval 30 60 30 \ 13 | --is_test \ 14 | --test_path './test_data/color.png' \ 15 | --add_noise \ 16 | --noise_sigma 30 -------------------------------------------------------------------------------- /train_data/gray/train/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/train_data/gray/train/.gitkeep -------------------------------------------------------------------------------- /train_data/gray/val/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/train_data/gray/val/.gitkeep -------------------------------------------------------------------------------- /train_data/rgb/train/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/train_data/rgb/train/.gitkeep -------------------------------------------------------------------------------- /train_data/rgb/val/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aoi-hosizora/FFDNet_pytorch/80b23129a2316139a9550c1b85f6546528b531f6/train_data/rgb/val/.gitkeep -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from skimage.measure.simple_metrics import compare_psnr 11 | from skimage.util import random_noise 12 | 13 | def is_image_gray(image): 14 | """ 15 | :param image: cv2 16 | """ 17 | # a[..., 0] == a.T[0].T 18 | return not(len(image.shape) == 3 and not(np.allclose(image[...,0], image[...,1]) and np.allclose(image[...,2], image[...,1]))) 19 | 20 | def downsample(x): 21 | """ 22 | :param x: (C, H, W) 23 | :param noise_sigma: (C, H/2, W/2) 24 | :return: (4, C, H/2, W/2) 25 | """ 26 | # x = x[:, :, :x.shape[2] // 2 * 2, :x.shape[3] // 2 * 2] 27 | N, C, W, H = x.size() 28 | idxL = [[0, 0], [0, 1], [1, 0], [1, 1]] 29 | 30 | Cout = 4 * C 31 | Wout = W // 2 32 | Hout = H // 2 33 | 34 | if 'cuda' in x.type(): 35 | down_features = torch.cuda.FloatTensor(N, Cout, Wout, Hout).fill_(0) 36 | else: 37 | down_features = torch.FloatTensor(N, Cout, Wout, Hout).fill_(0) 38 | 39 | for idx in range(4): 40 | down_features[:, idx:Cout:4, :, :] = x[:, :, idxL[idx][0]::2, idxL[idx][1]::2] 41 | 42 | return down_features 43 | 44 | def upsample(x): 45 | """ 46 | :param x: (n, C, W, H) 47 | :return: (n, C/4, W*2, H*2) 48 | """ 49 | N, Cin, Win, Hin = x.size() 50 | idxL = [[0, 0], [0, 1], [1, 0], [1, 1]] 51 | 52 | Cout = Cin // 4 53 | Wout = Win * 2 54 | Hout = Hin * 2 55 | 56 | up_feature = torch.zeros((N, Cout, Wout, Hout)).type(x.type()) 57 | for idx in range(4): 58 | up_feature[:, :, idxL[idx][0]::2, idxL[idx][1]::2] = x[:, idx:Cin:4, :, :] 59 | 60 | return up_feature 61 | 62 | def normalize(data): 63 | """ 64 | // variable_to_cv2_image will reshape to *255 65 | """ 66 | return np.float32(data / 255) 67 | 68 | def image_to_patches(image, patch_size): 69 | """ 70 | :param image: Image (C * W * H) Numpy 71 | :param patch_size: int 72 | :return: (patch_num, C, win, win) 73 | """ 74 | W = image.shape[1] 75 | H = image.shape[2] 76 | if W < patch_size or H < patch_size: 77 | return [] 78 | 79 | ret = [] 80 | for ws in range(0, W // patch_size): 81 | for hs in range(0, H // patch_size): 82 | patch = image[:, ws * patch_size : (ws + 1) * patch_size, hs * patch_size : (hs + 1) * patch_size] 83 | ret.append(patch) 84 | return np.array(ret, dtype=np.float32) 85 | 86 | def add_batch_noise(images, noise_sigma): 87 | """ 88 | :param images: Image (n, C, W, H) Tensor 89 | :return: Image (n, C, W, H) 90 | """ 91 | images = random_noise(images.numpy(), mode='gaussian', var=noise_sigma) 92 | return torch.FloatTensor(images) 93 | 94 | def batch_psnr(img, imclean, data_range): 95 | """ 96 | add the whole batch's PSNR 97 | """ 98 | img_cpu = img.data.cpu().numpy().astype(np.float32) 99 | imgclean = imclean.data.cpu().numpy().astype(np.float32) 100 | psnr = 0 101 | for i in range(img_cpu.shape[0]): 102 | psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range) 103 | return psnr / img_cpu.shape[0] 104 | 105 | def variable_to_cv2_image(varim): 106 | """ 107 | Norm Variable -> Cv2 108 | """ 109 | nchannels = varim.size()[1] 110 | if nchannels == 1: 111 | res = (varim.data.cpu().numpy()[0, 0, :] * 255.).clip(0, 255).astype(np.uint8) 112 | elif nchannels == 3: 113 | res = varim.data.cpu().numpy()[0] 114 | res = cv2.cvtColor(res.transpose(2, 1, 0), cv2.COLOR_RGB2BGR) 115 | res = (res*255.).clip(0, 255).astype(np.uint8) 116 | else: 117 | raise Exception('Number of color channels not supported') 118 | return res 119 | 120 | def weights_init_kaiming(lyr): 121 | """ 122 | Initializes weights of the model according to the "He" initialization 123 | method described in "Delving deep into rectifiers: Surpassing human-level 124 | performance on ImageNet classification" - He, K. et al. (2015), using a 125 | normal distribution. 126 | This function is to be called by the torch.nn.Module.apply() method, 127 | which applies weights_init_kaiming() to every layer of the model. 128 | """ 129 | classname = lyr.__class__.__name__ 130 | if classname.find('Conv') != -1: 131 | nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in') 132 | elif classname.find('Linear') != -1: 133 | nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in') 134 | elif classname.find('BatchNorm') != -1: 135 | lyr.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).\ 136 | clamp_(-0.025, 0.025) 137 | nn.init.constant_(lyr.bias.data, 0.0) --------------------------------------------------------------------------------