├── .gitignore ├── README.md ├── data ├── monarch.bmp ├── monarch_jpeg_q40.png ├── monarch_jpeg_q40_DnCNN-3.png ├── monarch_noise_l25.png ├── monarch_noise_l25_DnCNN-3.png ├── monarch_sr_s3.png └── monarch_sr_s3_DnCNN-3.png ├── dataset.py ├── example.py ├── main.py ├── model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DnCNN 2 | 3 | This repository is an implementation of the paper "Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising". 4 | 5 | ## Requirements 6 | - PyTorch 7 | - Tensorflow 8 | - tqdm 9 | - Numpy 10 | - Pillow 11 | 12 | **Tensorflow** is required for quickly loading images during the training phase. 13 | 14 | ## Results 15 | 16 | DnCNN-3 is a single model for three general image denoising tasks, i.e., blind Gaussian denoising, SISR with multiple upscaling factors, and JPEG deblocking with different quality factors. 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 27 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 39 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 51 | 54 | 55 |
JPEG Artifacts (Quality 40)
DnCNN-3
25 |
26 |
28 |
29 |
Gaussian Noise (Level 25)
DnCNN-3
37 |
38 |
40 |
41 |
Super-Resolution (Scale x3)
DnCNN-3
49 |
50 |
52 |
53 |
56 | 57 | ## Usage 58 | 59 | ### Train 60 | 61 | During training, the model weights are saved at the end of every epoch.
62 | If you want to train quickly, use the **--use_fast_loader** option. 63 | 64 | #### DnCNN-S 65 | 66 | ```bash 67 | python main.py --arch "DnCNN-S" \ 68 | --images_dir "" \ 69 | --outputs_dir "" \ 70 | --gaussian_noise_level 25 \ 71 | --patch_size 50 \ 72 | --batch_size 16 \ 73 | --num_epochs 20 \ 74 | --lr 1e-3 \ 75 | --threads 8 \ 76 | --seed 123 \ 77 | --use_fast_loader 78 | ``` 79 | 80 | #### DnCNN-B 81 | 82 | ```bash 83 | python main.py --arch "DnCNN-B" \ 84 | --images_dir "" \ 85 | --outputs_dir "" \ 86 | --gaussian_noise_level 0,55 \ 87 | --patch_size 50 \ 88 | --batch_size 16 \ 89 | --num_epochs 20 \ 90 | --lr 1e-3 \ 91 | --threads 8 \ 92 | --seed 123 \ 93 | --use_fast_loader 94 | ``` 95 | 96 | #### DnCNN-3 97 | 98 | ```bash 99 | python main.py --arch "DnCNN-3" \ 100 | --images_dir "" \ 101 | --outputs_dir "" \ 102 | --gaussian_noise_level 0,55 \ 103 | --downsampling_factor 1,4 \ 104 | --jpeg_quality 5,99 \ 105 | --patch_size 50 \ 106 | --batch_size 16 \ 107 | --num_epochs 20 \ 108 | --lr 1e-3 \ 109 | --threads 8 \ 110 | --seed 123 \ 111 | --use_fast_loader 112 | ``` 113 | 114 | ### Test 115 | 116 | The outputs consist of the degraded input image(s) (e.g. the noisy image) and the restored (denoised) image.
117 | 118 | ```bash 119 | python example.py --arch "DnCNN-S" \ 120 | --weights_path "" \ 121 | --image_path "" \ 122 | --outputs_dir "" \ 123 | --jpeg_quality 25 124 | ``` 125 | -------------------------------------------------------------------------------- /data/monarch.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch.bmp -------------------------------------------------------------------------------- /data/monarch_jpeg_q40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_jpeg_q40.png -------------------------------------------------------------------------------- /data/monarch_jpeg_q40_DnCNN-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_jpeg_q40_DnCNN-3.png -------------------------------------------------------------------------------- /data/monarch_noise_l25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_noise_l25.png -------------------------------------------------------------------------------- /data/monarch_noise_l25_DnCNN-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_noise_l25_DnCNN-3.png -------------------------------------------------------------------------------- /data/monarch_sr_s3.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_sr_s3.png -------------------------------------------------------------------------------- /data/monarch_sr_s3_DnCNN-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_sr_s3_DnCNN-3.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | 4 | import random 5 | import glob 6 | import io 7 | import numpy as np 8 | import PIL.Image as pil_image 9 | 10 | import tensorflow as tf 11 | config = tf.ConfigProto() 12 | config.gpu_options.allow_growth = True 13 | tf.enable_eager_execution(config=config) 14 | 15 | 16 | class Dataset(object): 17 | def __init__(self, images_dir, patch_size, 18 | gaussian_noise_level, downsampling_factor, jpeg_quality, 19 | use_fast_loader=False): 20 | self.image_files = sorted(glob.glob(images_dir + '/*')) 21 | self.patch_size = patch_size 22 | self.gaussian_noise_level = gaussian_noise_level 23 | self.downsampling_factor = downsampling_factor 24 | self.jpeg_quality = jpeg_quality 25 | self.use_fast_loader = use_fast_loader 26 | 27 | def __getitem__(self, idx): 28 | if self.use_fast_loader: 29 | clean_image = tf.read_file(self.image_files[idx]) 30 | clean_image = tf.image.decode_jpeg(clean_image, channels=3) 31 | clean_image = pil_image.fromarray(clean_image.numpy()) 32 | else: 33 | clean_image = pil_image.open(self.image_files[idx]).convert('RGB') 34 | 35 | # randomly crop patch from training set 36 | crop_x = random.randint(0, clean_image.width - self.patch_size) 37 | crop_y = random.randint(0, clean_image.height - self.patch_size) 38 | clean_image = clean_image.crop((crop_x, crop_y, crop_x + self.patch_size, 
crop_y + self.patch_size)) 39 | 40 | noisy_image = clean_image.copy() 41 | gaussian_noise = np.zeros((clean_image.height, clean_image.width, 3), dtype=np.float32) 42 | 43 | # additive gaussian noise 44 | if self.gaussian_noise_level is not None: 45 | if len(self.gaussian_noise_level) == 1: 46 | sigma = self.gaussian_noise_level[0] 47 | else: 48 | sigma = random.randint(self.gaussian_noise_level[0], self.gaussian_noise_level[1]) 49 | gaussian_noise += np.random.normal(0.0, sigma, (clean_image.height, clean_image.width, 3)).astype(np.float32) 50 | 51 | # downsampling 52 | if self.downsampling_factor is not None: 53 | if len(self.downsampling_factor) == 1: 54 | downsampling_factor = self.downsampling_factor[0] 55 | else: 56 | downsampling_factor = random.randint(self.downsampling_factor[0], self.downsampling_factor[1]) 57 | 58 | noisy_image = noisy_image.resize((self.patch_size // downsampling_factor, 59 | self.patch_size // downsampling_factor), 60 | resample=pil_image.BICUBIC) 61 | noisy_image = noisy_image.resize((self.patch_size, self.patch_size), resample=pil_image.BICUBIC) 62 | 63 | # additive jpeg noise 64 | if self.jpeg_quality is not None: 65 | if len(self.jpeg_quality) == 1: 66 | quality = self.jpeg_quality[0] 67 | else: 68 | quality = random.randint(self.jpeg_quality[0], self.jpeg_quality[1]) 69 | buffer = io.BytesIO() 70 | noisy_image.save(buffer, format='jpeg', quality=quality) 71 | noisy_image = pil_image.open(buffer) 72 | 73 | clean_image = np.array(clean_image).astype(np.float32) 74 | noisy_image = np.array(noisy_image).astype(np.float32) 75 | noisy_image += gaussian_noise 76 | 77 | input = np.transpose(noisy_image, axes=[2, 0, 1]) 78 | label = np.transpose(clean_image, axes=[2, 0, 1]) 79 | 80 | # normalization 81 | input /= 255.0 82 | label /= 255.0 83 | 84 | return input, label 85 | 86 | def __len__(self): 87 | return len(self.image_files) 88 | -------------------------------------------------------------------------------- /example.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import io 4 | import numpy as np 5 | import PIL.Image as pil_image 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from torchvision import transforms 9 | from model import DnCNN 10 | 11 | cudnn.benchmark = True 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--arch', type=str, default='DnCNN-S', help='DnCNN-S, DnCNN-B, DnCNN-3') 18 | parser.add_argument('--weights_path', type=str, required=True) 19 | parser.add_argument('--image_path', type=str, required=True) 20 | parser.add_argument('--outputs_dir', type=str, required=True) 21 | parser.add_argument('--gaussian_noise_level', type=int) 22 | parser.add_argument('--jpeg_quality', type=int) 23 | parser.add_argument('--downsampling_factor', type=int) 24 | opt = parser.parse_args() 25 | 26 | if not os.path.exists(opt.outputs_dir): 27 | os.makedirs(opt.outputs_dir) 28 | 29 | if opt.arch == 'DnCNN-S': 30 | model = DnCNN(num_layers=17) 31 | elif opt.arch == 'DnCNN-B': 32 | model = DnCNN(num_layers=20) 33 | elif opt.arch == 'DnCNN-3': 34 | model = DnCNN(num_layers=20) 35 | 36 | state_dict = model.state_dict() 37 | for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items(): 38 | if n in state_dict.keys(): 39 | state_dict[n].copy_(p) 40 | else: 41 | raise KeyError(n) 42 | 43 | model = model.to(device) 44 | model.eval() 45 | 46 | filename = os.path.basename(opt.image_path).split('.')[0] 47 | descriptions = '' 48 | 49 | input = pil_image.open(opt.image_path).convert('RGB') 50 | 51 | if opt.gaussian_noise_level is not None: 52 | noise = np.random.normal(0.0, opt.gaussian_noise_level, (input.height, input.width, 3)).astype(np.float32) 53 | input = np.array(input).astype(np.float32) + noise 54 | descriptions += 
'_noise_l{}'.format(opt.gaussian_noise_level) 55 | pil_image.fromarray(input.clip(0.0, 255.0).astype(np.uint8)).save(os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions))) 56 | input /= 255.0 57 | 58 | if opt.jpeg_quality is not None: 59 | buffer = io.BytesIO() 60 | input.save(buffer, format='jpeg', quality=opt.jpeg_quality) 61 | input = pil_image.open(buffer) 62 | descriptions += '_jpeg_q{}'.format(opt.jpeg_quality) 63 | input.save(os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions))) 64 | input = np.array(input).astype(np.float32) 65 | input /= 255.0 66 | 67 | if opt.downsampling_factor is not None: 68 | original_width = input.width 69 | original_height = input.height 70 | input = input.resize((input.width // opt.downsampling_factor, 71 | input.height // opt.downsampling_factor), 72 | resample=pil_image.BICUBIC) 73 | input = input.resize((original_width, original_height), resample=pil_image.BICUBIC) 74 | descriptions += '_sr_s{}'.format(opt.downsampling_factor) 75 | input.save(os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions))) 76 | input = np.array(input).astype(np.float32) 77 | input /= 255.0 78 | 79 | input = transforms.ToTensor()(input).unsqueeze(0).to(device) 80 | 81 | with torch.no_grad(): 82 | pred = model(input) 83 | 84 | output = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy() 85 | output = pil_image.fromarray(output, mode='RGB') 86 | output.save(os.path.join(opt.outputs_dir, '{}{}_{}.png'.format(filename, descriptions, opt.arch))) 87 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torch import nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | from torch.utils.data.dataloader import DataLoader 8 | from tqdm import tqdm 9 | from model import DnCNN 
10 | from dataset import Dataset 11 | from utils import AverageMeter 12 | 13 | cudnn.benchmark = True 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--arch', type=str, default='DnCNN-S', help='DnCNN-S, DnCNN-B, DnCNN-3') 20 | parser.add_argument('--images_dir', type=str, required=True) 21 | parser.add_argument('--outputs_dir', type=str, required=True) 22 | parser.add_argument('--gaussian_noise_level', type=str) 23 | parser.add_argument('--downsampling_factor', type=str) 24 | parser.add_argument('--jpeg_quality', type=str) 25 | parser.add_argument('--patch_size', type=int, default=50) 26 | parser.add_argument('--batch_size', type=int, default=16) 27 | parser.add_argument('--num_epochs', type=int, default=20) 28 | parser.add_argument('--lr', type=float, default=1e-3) 29 | parser.add_argument('--threads', type=int, default=8) 30 | parser.add_argument('--seed', type=int, default=123) 31 | parser.add_argument('--use_fast_loader', action='store_true') 32 | opt = parser.parse_args() 33 | 34 | if opt.gaussian_noise_level is not None: 35 | opt.gaussian_noise_level = list(map(lambda x: int(x), opt.gaussian_noise_level.split(','))) 36 | 37 | if opt.downsampling_factor is not None: 38 | opt.downsampling_factor = list(map(lambda x: int(x), opt.downsampling_factor.split(','))) 39 | 40 | if opt.jpeg_quality is not None: 41 | opt.jpeg_quality = list(map(lambda x: int(x), opt.jpeg_quality.split(','))) 42 | 43 | if not os.path.exists(opt.outputs_dir): 44 | os.makedirs(opt.outputs_dir) 45 | 46 | torch.manual_seed(opt.seed) 47 | 48 | if opt.arch == 'DnCNN-S': 49 | model = DnCNN(num_layers=17) 50 | elif opt.arch == 'DnCNN-B': 51 | model = DnCNN(num_layers=20) 52 | elif opt.arch == 'DnCNN-3': 53 | model = DnCNN(num_layers=20) 54 | 55 | model = model.to(device) 56 | criterion = nn.MSELoss(reduction='sum') 57 | 58 | optimizer = 
optim.Adam(model.parameters(), lr=opt.lr) 59 | 60 | dataset = Dataset(opt.images_dir, opt.patch_size, 61 | opt.gaussian_noise_level, opt.downsampling_factor, opt.jpeg_quality, 62 | opt.use_fast_loader) 63 | dataloader = DataLoader(dataset=dataset, 64 | batch_size=opt.batch_size, 65 | shuffle=True, 66 | num_workers=opt.threads, 67 | pin_memory=True, 68 | drop_last=True) 69 | 70 | for epoch in range(opt.num_epochs): 71 | epoch_losses = AverageMeter() 72 | 73 | with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm: 74 | _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs)) 75 | for data in dataloader: 76 | inputs, labels = data 77 | inputs = inputs.to(device) 78 | labels = labels.to(device) 79 | 80 | preds = model(inputs) 81 | 82 | loss = criterion(preds, labels) / (2 * len(inputs)) 83 | 84 | epoch_losses.update(loss.item(), len(inputs)) 85 | 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | 90 | _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg)) 91 | _tqdm.update(len(inputs)) 92 | 93 | torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch))) 94 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class DnCNN(nn.Module): 5 | def __init__(self, num_layers=17, num_features=64): 6 | super(DnCNN, self).__init__() 7 | layers = [nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1), 8 | nn.ReLU(inplace=True))] 9 | for i in range(num_layers - 2): 10 | layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), 11 | nn.BatchNorm2d(num_features), 12 | nn.ReLU(inplace=True))) 13 | layers.append(nn.Conv2d(num_features, 3, kernel_size=3, padding=1)) 14 | self.layers = nn.Sequential(*layers) 15 | 16 | self._initialize_weights() 17 | 18 | def 
_initialize_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight) 22 | elif isinstance(m, nn.BatchNorm2d): 23 | nn.init.ones_(m.weight) 24 | nn.init.zeros_(m.bias) 25 | 26 | def forward(self, inputs): 27 | y = inputs 28 | residual = self.layers(y) 29 | return y - residual 30 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | def __init__(self): 3 | self.reset() 4 | 5 | def reset(self): 6 | self.val = 0 7 | self.avg = 0 8 | self.sum = 0 9 | self.count = 0 10 | 11 | def update(self, val, n=1): 12 | self.val = val 13 | self.sum += val * n 14 | self.count += n 15 | self.avg = self.sum / self.count 16 | --------------------------------------------------------------------------------