├── .gitignore
├── README.md
├── data
├── monarch.bmp
├── monarch_jpeg_q40.png
├── monarch_jpeg_q40_DnCNN-3.png
├── monarch_noise_l25.png
├── monarch_noise_l25_DnCNN-3.png
├── monarch_sr_s3.png
└── monarch_sr_s3_DnCNN-3.png
├── dataset.py
├── example.py
├── main.py
├── model.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DnCNN
2 |
3 | This repository is an implementation of the paper "Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising".
4 |
5 | ## Requirements
6 | - PyTorch
7 | - Tensorflow
8 | - tqdm
9 | - Numpy
10 | - Pillow
11 |
12 | **Tensorflow** is required for quickly fetching images in the training phase.
13 |
14 | ## Results
15 |
16 | The DnCNN-3 is only a single model for three general image denoising tasks, i.e., blind Gaussian denoising, SISR with multiple upscaling factors, and JPEG deblocking with different quality factors.
17 |
18 |
19 |
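20 | | JPEG Artifacts (Quality 40) | DnCNN-3 |
21 | |:---:|:---:|
22 | | ![JPEG artifacts (quality 40)](./data/monarch_jpeg_q40.png) | ![DnCNN-3 result](./data/monarch_jpeg_q40_DnCNN-3.png) |
23 |
24 | | Gaussian Noise (Level 25) | DnCNN-3 |
25 | |:---:|:---:|
26 | | ![Gaussian noise (level 25)](./data/monarch_noise_l25.png) | ![DnCNN-3 result](./data/monarch_noise_l25_DnCNN-3.png) |
27 |
28 | | Super-Resolution (Scale x3) | DnCNN-3 |
29 | |:---:|:---:|
30 | | ![Super-resolution (scale x3)](./data/monarch_sr_s3.png) | ![DnCNN-3 result](./data/monarch_sr_s3_DnCNN-3.png) |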
54 |
55 |
56 |
57 | ## Usages
58 |
59 | ### Train
60 |
61 | When training begins, the model weights will be saved every epoch.
62 | If you want to train quickly, you should use the **--use_fast_loader** option.
63 |
64 | #### DnCNN-S
65 |
66 | ```bash
67 | python main.py --arch "DnCNN-S" \
68 | --images_dir "" \
69 | --outputs_dir "" \
70 | --gaussian_noise_level 25 \
71 | --patch_size 50 \
72 | --batch_size 16 \
73 | --num_epochs 20 \
74 | --lr 1e-3 \
75 | --threads 8 \
76 | --seed 123 \
77 | --use_fast_loader
78 | ```
79 |
80 | #### DnCNN-B
81 |
82 | ```bash
83 | python main.py --arch "DnCNN-B" \
84 | --images_dir "" \
85 | --outputs_dir "" \
86 | --gaussian_noise_level 0,55 \
87 | --patch_size 50 \
88 | --batch_size 16 \
89 | --num_epochs 20 \
90 | --lr 1e-3 \
91 | --threads 8 \
92 | --seed 123 \
93 | --use_fast_loader
94 | ```
95 |
96 | #### DnCNN-3
97 |
98 | ```bash
99 | python main.py --arch "DnCNN-3" \
100 | --images_dir "" \
101 | --outputs_dir "" \
102 | --gaussian_noise_level 0,55 \
103 | --downsampling_factor 1,4 \
104 | --jpeg_quality 5,99 \
105 | --patch_size 50 \
106 | --batch_size 16 \
107 | --num_epochs 20 \
108 | --lr 1e-3 \
109 | --threads 8 \
110 | --seed 123 \
111 | --use_fast_loader
112 | ```
113 |
114 | ### Test
115 |
116 | The output consists of the noisy image and the denoised image.
117 |
118 | ```bash
119 | python example.py --arch "DnCNN-S" \
120 | --weights_path "" \
121 | --image_path "" \
122 | --outputs_dir "" \
123 | --jpeg_quality 25
124 | ```
125 |
--------------------------------------------------------------------------------
/data/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch.bmp
--------------------------------------------------------------------------------
/data/monarch_jpeg_q40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_jpeg_q40.png
--------------------------------------------------------------------------------
/data/monarch_jpeg_q40_DnCNN-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_jpeg_q40_DnCNN-3.png
--------------------------------------------------------------------------------
/data/monarch_noise_l25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_noise_l25.png
--------------------------------------------------------------------------------
/data/monarch_noise_l25_DnCNN-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_noise_l25_DnCNN-3.png
--------------------------------------------------------------------------------
/data/monarch_sr_s3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_sr_s3.png
--------------------------------------------------------------------------------
/data/monarch_sr_s3_DnCNN-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/DnCNN-pytorch/91ffa5b6028dde2eb9f5ff2e85ede6c18b32f118/data/monarch_sr_s3_DnCNN-3.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
3 |
4 | import random
5 | import glob
6 | import io
7 | import numpy as np
8 | import PIL.Image as pil_image
9 |
import tensorflow as tf

# TensorFlow is used only to decode images for the fast loader. Stop it from
# reserving all GPU memory up front, and make sure eager execution is on so
# that decoded tensors expose ``.numpy()``.
if hasattr(tf, 'enable_eager_execution'):
    # TensorFlow 1.x: eager execution has to be enabled explicitly.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.enable_eager_execution(config=config)
else:
    # TensorFlow 2.x: eager execution is the default and the 1.x session
    # configuration API is no longer available at the top level.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
14 |
15 |
class Dataset(object):
    """Yields (degraded patch, clean patch) pairs cropped from image files.

    Every item is a random ``patch_size`` x ``patch_size`` crop of one image.
    The degraded copy is produced by, in this order, bicubic down/up-sampling,
    a JPEG round trip and additive Gaussian noise; each step is skipped when
    its setting is ``None``.

    Args:
        images_dir: directory whose entries (``images_dir + '/*'``) are used
            as training images.
        patch_size: side length in pixels of the square crop.
        gaussian_noise_level: ``None``, ``[sigma]`` or ``[low, high]``.
        downsampling_factor: ``None``, ``[factor]`` or ``[low, high]``.
        jpeg_quality: ``None``, ``[quality]`` or ``[low, high]``.
        use_fast_loader: decode files with TensorFlow instead of Pillow.

    For the ``[low, high]`` form an integer is drawn uniformly from the
    inclusive range for every item.
    """

    def __init__(self, images_dir, patch_size,
                 gaussian_noise_level, downsampling_factor, jpeg_quality,
                 use_fast_loader=False):
        self.image_files = sorted(glob.glob(images_dir + '/*'))
        self.patch_size = patch_size
        self.gaussian_noise_level = gaussian_noise_level
        self.downsampling_factor = downsampling_factor
        self.jpeg_quality = jpeg_quality
        self.use_fast_loader = use_fast_loader

    @staticmethod
    def _sample_level(level_range):
        """Return the only configured value, or a random int in [low, high]."""
        if len(level_range) == 1:
            return level_range[0]
        return random.randint(level_range[0], level_range[1])

    def _load_image(self, idx):
        """Load image ``idx`` as an RGB Pillow image."""
        path = self.image_files[idx]
        if not self.use_fast_loader:
            return pil_image.open(path).convert('RGB')

        # TensorFlow 1.x exposes ``tf.read_file``; 2.x only ``tf.io.read_file``.
        read_file = getattr(tf, 'read_file', None) or tf.io.read_file
        decoded = tf.image.decode_jpeg(read_file(path), channels=3)
        return pil_image.fromarray(decoded.numpy())

    def __getitem__(self, idx):
        """Return ``(input, label)`` float32 arrays in CHW layout, divided by 255.

        Raises:
            ValueError: if the image is smaller than ``patch_size``.
        """
        clean_image = self._load_image(idx)

        # randomly crop patch from training set
        max_x = clean_image.width - self.patch_size
        max_y = clean_image.height - self.patch_size
        if max_x < 0 or max_y < 0:
            raise ValueError('image "{}" ({}x{}) is smaller than patch_size {}'.format(
                self.image_files[idx], clean_image.width, clean_image.height, self.patch_size))
        crop_x = random.randint(0, max_x)
        crop_y = random.randint(0, max_y)
        clean_image = clean_image.crop((crop_x, crop_y,
                                        crop_x + self.patch_size, crop_y + self.patch_size))

        noisy_image = clean_image.copy()

        # additive gaussian noise (sampled now, added after the other steps)
        gaussian_noise = None
        if self.gaussian_noise_level is not None:
            sigma = self._sample_level(self.gaussian_noise_level)
            noise_shape = (clean_image.height, clean_image.width, 3)
            gaussian_noise = np.random.normal(0.0, sigma, noise_shape).astype(np.float32)

        # downsampling
        if self.downsampling_factor is not None:
            factor = self._sample_level(self.downsampling_factor)
            # Never ask Pillow for a zero-sized image when factor > patch_size.
            small_size = max(1, self.patch_size // factor)
            noisy_image = noisy_image.resize((small_size, small_size), resample=pil_image.BICUBIC)
            noisy_image = noisy_image.resize((self.patch_size, self.patch_size),
                                             resample=pil_image.BICUBIC)

        # additive jpeg noise
        if self.jpeg_quality is not None:
            quality = self._sample_level(self.jpeg_quality)
            buffer = io.BytesIO()
            noisy_image.save(buffer, format='jpeg', quality=quality)
            noisy_image = pil_image.open(buffer)

        clean_pixels = np.array(clean_image).astype(np.float32)
        noisy_pixels = np.array(noisy_image).astype(np.float32)
        if gaussian_noise is not None:
            noisy_pixels += gaussian_noise

        # HWC -> CHW
        inputs = np.transpose(noisy_pixels, axes=[2, 0, 1])
        label = np.transpose(clean_pixels, axes=[2, 0, 1])

        # normalization
        inputs /= 255.0
        label /= 255.0

        return inputs, label

    def __len__(self):
        """Number of image files found in ``images_dir``."""
        return len(self.image_files)
88 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import io
4 | import numpy as np
5 | import PIL.Image as pil_image
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | from torchvision import transforms
9 | from model import DnCNN
10 |
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network depth for every architecture name accepted by --arch.
NUM_LAYERS_BY_ARCH = {'DnCNN-S': 17, 'DnCNN-B': 20, 'DnCNN-3': 20}


if __name__ == '__main__':
    # Degrade one image with the requested corruptions, run it through a
    # trained DnCNN and save both the degraded and the restored image.
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='DnCNN-S', choices=sorted(NUM_LAYERS_BY_ARCH),
                        help='DnCNN-S, DnCNN-B, DnCNN-3')
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--gaussian_noise_level', type=int)
    parser.add_argument('--jpeg_quality', type=int)
    parser.add_argument('--downsampling_factor', type=int)
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    model = DnCNN(num_layers=NUM_LAYERS_BY_ARCH[opt.arch])

    # Copy the checkpoint into the model tensor by tensor so that an
    # unexpected key is reported instead of being ignored.
    state_dict = model.state_dict()
    checkpoint = torch.load(opt.weights_path, map_location=lambda storage, loc: storage)
    for name, tensor in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        state_dict[name].copy_(tensor)

    model = model.to(device)
    model.eval()

    filename = os.path.basename(opt.image_path).split('.')[0]
    descriptions = ''

    def degraded_path():
        """Path of the degraded image for the corruptions applied so far."""
        return os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions))

    image = pil_image.open(opt.image_path).convert('RGB')

    # The corruptions are applied in the same order as in dataset.py
    # (down/up-sampling, JPEG, Gaussian noise). The first two work on a Pillow
    # image and the last on a float array, so any combination of options is
    # valid and the image is converted to an array only once.
    if opt.downsampling_factor is not None:
        original_size = image.size
        image = image.resize((image.width // opt.downsampling_factor,
                              image.height // opt.downsampling_factor),
                             resample=pil_image.BICUBIC)
        image = image.resize(original_size, resample=pil_image.BICUBIC)
        descriptions += '_sr_s{}'.format(opt.downsampling_factor)
        image.save(degraded_path())

    if opt.jpeg_quality is not None:
        buffer = io.BytesIO()
        image.save(buffer, format='jpeg', quality=opt.jpeg_quality)
        image = pil_image.open(buffer)
        descriptions += '_jpeg_q{}'.format(opt.jpeg_quality)
        image.save(degraded_path())

    pixels = np.array(image).astype(np.float32)

    if opt.gaussian_noise_level is not None:
        noise = np.random.normal(0.0, opt.gaussian_noise_level, pixels.shape).astype(np.float32)
        pixels += noise
        descriptions += '_noise_l{}'.format(opt.gaussian_noise_level)
        # Only the saved preview is clipped; the network receives the raw sum.
        pil_image.fromarray(pixels.clip(0.0, 255.0).astype(np.uint8)).save(degraded_path())

    pixels /= 255.0

    # ToTensor does not rescale a float array, it only reorders HWC -> CHW.
    inputs = transforms.ToTensor()(pixels).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(inputs)

    output = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    output = pil_image.fromarray(output)
    output.save(os.path.join(opt.outputs_dir, '{}{}_{}.png'.format(filename, descriptions, opt.arch)))
87 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torch import nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | from torch.utils.data.dataloader import DataLoader
8 | from tqdm import tqdm
9 | from model import DnCNN
10 | from dataset import Dataset
11 | from utils import AverageMeter
12 |
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network depth for every architecture name accepted by --arch.
NUM_LAYERS_BY_ARCH = {'DnCNN-S': 17, 'DnCNN-B': 20, 'DnCNN-3': 20}


def _parse_int_list(text):
    """Convert a comma-separated string such as ``"0,55"`` into ``[0, 55]``.

    Used as an argparse ``type`` so that a malformed value is reported as a
    command-line error.
    """
    return [int(item) for item in text.split(',')]


if __name__ == '__main__':
    # Train a DnCNN on random patches and save the weights after every epoch.
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='DnCNN-S', choices=sorted(NUM_LAYERS_BY_ARCH),
                        help='DnCNN-S, DnCNN-B, DnCNN-3')
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    # Each of the next three accepts "value" or "low,high".
    parser.add_argument('--gaussian_noise_level', type=_parse_int_list)
    parser.add_argument('--downsampling_factor', type=_parse_int_list)
    parser.add_argument('--jpeg_quality', type=_parse_int_list)
    parser.add_argument('--patch_size', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--threads', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--use_fast_loader', action='store_true')
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    torch.manual_seed(opt.seed)

    model = DnCNN(num_layers=NUM_LAYERS_BY_ARCH[opt.arch]).to(device)
    criterion = nn.MSELoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    dataset = Dataset(opt.images_dir, opt.patch_size,
                      opt.gaussian_noise_level, opt.downsampling_factor, opt.jpeg_quality,
                      opt.use_fast_loader)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            pin_memory=True,
                            drop_last=True)

    # drop_last=True discards the final partial batch, so the progress bar
    # counts only the samples that are actually used.
    samples_per_epoch = len(dataset) - len(dataset) % opt.batch_size

    for epoch in range(opt.num_epochs):
        epoch_losses = AverageMeter()

        with tqdm(total=samples_per_epoch) as progress:
            progress.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                # Summed squared error divided by twice the batch size.
                loss = criterion(preds, labels) / (2 * len(inputs))

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                progress.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                progress.update(len(inputs))

        # Checkpoint file names use the zero-based epoch index.
        torch.save(model.state_dict(),
                   os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch)))
94 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
class DnCNN(nn.Module):
    """Stack of 3x3 convolutions whose output is subtracted from the input.

    Layout: one Conv+ReLU group, ``num_layers - 2`` Conv+BatchNorm+ReLU
    groups, and a final Conv that maps back to 3 channels. All convolutions
    use padding 1, so the spatial size is preserved.
    """

    def __init__(self, num_layers=17, num_features=64):
        super(DnCNN, self).__init__()

        head = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        body = [
            nn.Sequential(
                nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                nn.BatchNorm2d(num_features),
                nn.ReLU(inplace=True),
            )
            for _ in range(num_layers - 2)
        ]
        tail = nn.Conv2d(num_features, 3, kernel_size=3, padding=1)

        stack = [head] + body + [tail]
        self.layers = nn.Sequential(*stack)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights; batch-norm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, inputs):
        """Return ``inputs`` minus the output of the convolution stack."""
        return inputs - self.layers(inputs)
30 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Keeps the latest value and the weighted running mean of a series.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: sum of ``val * n`` over all updates since the last reset.
        count: sum of ``n`` over all updates since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
16 |
--------------------------------------------------------------------------------