├── .gitignore
├── README.md
├── data
├── monarch.bmp
├── monarch_ARCNN.png
├── monarch_REDNet10.png
├── monarch_REDNet20.png
├── monarch_REDNet30.png
└── monarch_jpeg_q10.png
├── dataset.py
├── example.py
├── main.py
├── model.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RED-Net
2 |
This repository is a PyTorch implementation of "Image Restoration Using Very Deep Convolutional Encoder-Decoder Networks with Symmetric Skip Connections" (RED-Net).
To reduce the computational cost, the first convolution layer and the last transposed convolution layer use a stride of 2.
5 |
6 | ## Requirements
- PyTorch
- torchvision
- TensorFlow
- tqdm
- NumPy
- Pillow
12 |
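**TensorFlow** is used by the `--use_fast_loader` option to read and decode training images quickly. Note that `dataset.py` imports it unconditionally, so it must be installed even when that option is not used.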
14 |
15 | ## Results
16 |
| Input | JPEG (Quality 10) |
|:---:|:---:|
| ![Input](./data/monarch.bmp) | ![JPEG (Quality 10)](./data/monarch_jpeg_q10.png) |
| **AR-CNN** | **RED-Net 10** |
| ![AR-CNN](./data/monarch_ARCNN.png) | ![RED-Net 10](./data/monarch_REDNet10.png) |
| **RED-Net 20** | **RED-Net 30** |
| ![RED-Net 20](./data/monarch_REDNet20.png) | ![RED-Net 30](./data/monarch_REDNet30.png) |

56 | ## Usages
57 |
58 | ### Train
59 |
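During training, the model weights are saved to `--outputs_dir` at the end of every epoch.
To speed up image loading, use the `--use_fast_loader` option. The fast loader decodes images with TensorFlow's JPEG decoder, so the training images must be JPEG files.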
62 |
```bash
# --arch: REDNet10, REDNet20 or REDNet30
python main.py --arch "REDNet30" \
               --images_dir "" \
               --outputs_dir "" \
               --jpeg_quality 10 \
               --patch_size 50 \
               --batch_size 16 \
               --num_epochs 20 \
               --lr 1e-4 \
               --threads 8 \
               --seed 123 \
               --use_fast_loader
```
76 |
77 | ### Test
78 |
The outputs are the JPEG-compressed input image and the artifact-reduced image produced by the model.
80 |
```bash
# --arch: REDNet10, REDNet20 or REDNet30
python example.py --arch "REDNet30" \
                  --weights_path "" \
                  --image_path "" \
                  --outputs_dir "" \
                  --jpeg_quality 10
```
88 |
--------------------------------------------------------------------------------
/data/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch.bmp
--------------------------------------------------------------------------------
/data/monarch_ARCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_ARCNN.png
--------------------------------------------------------------------------------
/data/monarch_REDNet10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet10.png
--------------------------------------------------------------------------------
/data/monarch_REDNet20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet20.png
--------------------------------------------------------------------------------
/data/monarch_REDNet30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet30.png
--------------------------------------------------------------------------------
/data/monarch_jpeg_q10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_jpeg_q10.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
3 |
4 | import random
5 | import glob
6 | import io
7 | import numpy as np
8 | import PIL.Image as pil_image
9 |
10 | import tensorflow as tf
# Configure TensorFlow, which is only used here to decode training images.
# presumably GPU memory growth is enabled so TensorFlow does not reserve the whole
# GPU up front and leaves room for PyTorch training -- TODO confirm intent.
if hasattr(tf, 'ConfigProto'):
    # TensorFlow 1.x: eager execution must be switched on explicitly, and the
    # session config carries the allow_growth option.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    tf.enable_eager_execution(config=config)
else:
    # TensorFlow 2.x: ConfigProto / enable_eager_execution are gone from the
    # top-level namespace and eager mode is the default; only memory growth
    # needs to be requested, per physical GPU.
    for _gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(_gpu, True)
14 |
15 |
class Dataset(object):
    """Random training patches paired with their JPEG-compressed versions.

    Each item is a randomly positioned ``patch_size`` x ``patch_size`` RGB crop
    of one training image (the label) together with the same crop after a JPEG
    round trip at ``jpeg_quality`` (the input). Both are returned as float32
    arrays of shape (3, patch_size, patch_size) divided by 255.

    Args:
        images_dir: directory whose non-hidden regular files are the training
            images (sub-directories are ignored).
        patch_size: side length, in pixels, of the square crop.
        jpeg_quality: Pillow JPEG quality used to create the degraded input.
        use_fast_loader: decode with TensorFlow instead of Pillow. That path
            uses ``decode_jpeg``, so it expects JPEG files.
    """

    def __init__(self, images_dir, patch_size, jpeg_quality, use_fast_loader=False):
        # glob's '*' already skips dotfiles such as .DS_Store; also drop
        # sub-directories, which cannot be opened as images.
        self.image_files = sorted(
            path for path in glob.glob(os.path.join(images_dir, '*')) if os.path.isfile(path)
        )
        self.patch_size = patch_size
        self.jpeg_quality = jpeg_quality
        self.use_fast_loader = use_fast_loader

    def _load_image(self, path):
        """Return the image stored at *path* as an RGB Pillow image."""
        if self.use_fast_loader:
            # tf.io.read_file works on TF >= 1.x late releases and TF 2.x,
            # unlike the TF1-only tf.read_file.
            raw = tf.io.read_file(path)
            decoded = tf.image.decode_jpeg(raw, channels=3)
            return pil_image.fromarray(decoded.numpy())
        # convert() returns a fully loaded copy, so the file can be closed here.
        with pil_image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(input, label)`` for a fresh random crop of image *idx*.

        Raises:
            ValueError: if the image is smaller than ``patch_size`` in either
                dimension (no crop position exists).
        """
        path = self.image_files[idx]
        label = self._load_image(path)

        if label.width < self.patch_size or label.height < self.patch_size:
            raise ValueError('image {} ({}x{}) is smaller than patch_size={}'.format(
                path, label.width, label.height, self.patch_size))

        # randomly crop a patch; randint is inclusive, so the crop may touch
        # the right/bottom border
        crop_x = random.randint(0, label.width - self.patch_size)
        crop_y = random.randint(0, label.height - self.patch_size)
        label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))

        # additive JPEG noise: encode the clean patch and decode it again
        buffer = io.BytesIO()
        label.save(buffer, format='jpeg', quality=self.jpeg_quality)
        buffer.seek(0)
        with pil_image.open(buffer) as compressed:
            degraded = np.array(compressed).astype(np.float32)
        clean = np.array(label).astype(np.float32)

        # HWC -> CHW, then scale from [0, 255] to [0, 1]
        degraded = np.transpose(degraded, axes=[2, 0, 1])
        clean = np.transpose(clean, axes=[2, 0, 1])
        degraded /= 255.0
        clean /= 255.0

        return degraded, clean

    def __len__(self):
        """Return the number of training image files."""
        return len(self.image_files)
54 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import io
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | from torchvision import transforms
7 | import PIL.Image as pil_image
8 | from model import REDNet10, REDNet20, REDNet30
9 |
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    # Map each accepted --arch value to the model class it builds.
    architectures = {
        'REDNet10': REDNet10,
        'REDNet20': REDNet20,
        'REDNet30': REDNet30,
    }

    parser = argparse.ArgumentParser()
    # choices= rejects unknown names up front instead of failing later with a
    # NameError on an unassigned model.
    parser.add_argument('--arch', type=str, default='REDNet10', choices=list(architectures),
                        help='REDNet10, REDNet20, REDNet30')
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=10)
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    model = architectures[opt.arch]()

    # Load the checkpoint onto the CPU and copy it into the model, requiring the
    # parameter names to match the architecture exactly in both directions.
    checkpoint = torch.load(opt.weights_path, map_location=lambda storage, loc: storage)
    state_dict = model.state_dict()
    for n, p in checkpoint.items():
        if n not in state_dict:
            raise KeyError(n)
        state_dict[n].copy_(p)
    missing = sorted(set(state_dict) - set(checkpoint))
    if missing:
        raise KeyError('parameters missing from {}: {}'.format(opt.weights_path, ', '.join(missing)))

    model = model.to(device)
    model.eval()

    # splitext keeps dots inside the base name (e.g. "img.v2.png" -> "img.v2").
    filename = os.path.splitext(os.path.basename(opt.image_path))[0]

    with pil_image.open(opt.image_path) as source:
        image = source.convert('RGB')

    # Simulate the degradation the model is trained on: a JPEG round trip.
    buffer = io.BytesIO()
    image.save(buffer, format='jpeg', quality=opt.jpeg_quality)
    buffer.seek(0)
    degraded = pil_image.open(buffer)
    degraded.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))

    inputs = transforms.ToTensor()(degraded).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(inputs)

    # Round to the nearest 8-bit value; byte() alone would truncate and bias
    # every pixel downwards.
    pred = pred.mul(255.0).round().clamp(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    output = pil_image.fromarray(pred)  # (H, W, 3) uint8 is inferred as RGB
    output.save(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, opt.arch)))
60 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torch import nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | from torch.utils.data.dataloader import DataLoader
8 | from tqdm import tqdm
9 | from model import REDNet10, REDNet20, REDNet30
10 | from dataset import Dataset
11 | from utils import AverageMeter
12 |
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    # Local import: only the training entry point needs Python's RNG.
    import random

    # Map each accepted --arch value to the model class it builds.
    architectures = {
        'REDNet10': REDNet10,
        'REDNet20': REDNet20,
        'REDNet30': REDNet30,
    }

    parser = argparse.ArgumentParser()
    # choices= rejects unknown names up front instead of failing later with a
    # NameError on an unassigned model.
    parser.add_argument('--arch', type=str, default='REDNet10', choices=list(architectures),
                        help='REDNet10, REDNet20, REDNet30')
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=10)
    parser.add_argument('--patch_size', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--threads', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--use_fast_loader', action='store_true')
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    # torch's seed covers weight init and shuffling (and DataLoader derives its
    # worker seeds from it). Python's `random` picks the crop offsets inside
    # Dataset, which runs in this process when --threads is 0, so seed it too.
    torch.manual_seed(opt.seed)
    random.seed(opt.seed)

    model = architectures[opt.arch]().to(device)
    criterion = nn.MSELoss()

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality, opt.use_fast_loader)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            # pinned host memory only helps host-to-GPU copies
                            pin_memory=(device.type == 'cuda'),
                            drop_last=True)

    model.train()
    for epoch in range(opt.num_epochs):
        epoch_losses = AverageMeter()

        # drop_last=True discards the final partial batch, so the progress bar
        # total excludes it as well.
        with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch)))
80 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import nn
3 |
4 |
class REDNet10(nn.Module):
    """RED-Net with ``2 * num_layers`` layers and only a global residual connection.

    The encoder is ``num_layers`` 3x3 convolutions, each followed by ReLU; the
    first one uses stride 2 to downsample H and W. The decoder mirrors it with
    ``num_layers`` 3x3 transposed convolutions (ReLU after all but the last);
    the last one upsamples by 2 and maps back to 3 channels. The input is added
    to the decoder output and the sum goes through a final ReLU.

    Args:
        num_layers: number of encoder convolutions (and of decoder transposed
            convolutions). The default of 5 gives the 10-layer network.
        num_features: number of channels in the hidden feature maps.
    """

    def __init__(self, num_layers=5, num_features=64):
        super(REDNet10, self).__init__()
        conv_layers = []
        deconv_layers = []

        # Encoder: the strided first layer reduces the cost of every later layer.
        conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1),
                                         nn.ReLU(inplace=True)))
        for _ in range(num_layers - 1):
            conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                                             nn.ReLU(inplace=True)))

        # Decoder: size-preserving transposed convs, then a stride-2 one whose
        # output_padding=1 restores the size of an even-sized input.
        for _ in range(num_layers - 1):
            deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1),
                                               nn.ReLU(inplace=True)))
        deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1))

        self.conv_layers = nn.Sequential(*conv_layers)
        self.deconv_layers = nn.Sequential(*deconv_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Restore a 3-channel image batch and return a tensor of the same shape.

        The stride-2 encoder/decoder pair only reproduces even spatial sizes, so
        an odd height or width is replicate-padded by one pixel first and the
        result is cropped back to the input size.
        """
        height, width = x.shape[-2:]
        pad_h, pad_w = height % 2, width % 2
        if pad_h or pad_w:
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        residual = x
        out = self.conv_layers(x)
        out = self.deconv_layers(out)
        out += residual
        out = self.relu(out)

        if pad_h or pad_w:
            out = out[..., :height, :width]
        return out
33 |
34 |
class REDNet20(nn.Module):
    """RED-Net with ``2 * num_layers`` layers and symmetric skip connections.

    The encoder is ``num_layers`` 3x3 convolutions with ReLU, the first one with
    stride 2. The decoder mirrors it with ``num_layers`` 3x3 transposed
    convolutions, the last of which upsamples by 2 back to 3 channels. Every
    second encoder output (at most ``ceil(num_layers / 2) - 1`` of them) is
    added, deepest first, to the output of the mirrored decoder layer, followed
    by ReLU. Finally the input is added to the result (global residual) and a
    ReLU is applied.

    Args:
        num_layers: number of encoder convolutions (and of decoder transposed
            convolutions). The default of 10 gives the 20-layer network.
        num_features: number of channels in the hidden feature maps.
    """

    def __init__(self, num_layers=10, num_features=64):
        super(REDNet20, self).__init__()
        self.num_layers = num_layers

        conv_layers = []
        deconv_layers = []

        conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1),
                                         nn.ReLU(inplace=True)))
        for _ in range(num_layers - 1):
            conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                                             nn.ReLU(inplace=True)))

        for _ in range(num_layers - 1):
            deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1),
                                               nn.ReLU(inplace=True)))
        deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1))

        self.conv_layers = nn.Sequential(*conv_layers)
        self.deconv_layers = nn.Sequential(*deconv_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Restore a 3-channel image batch and return a tensor of the same shape.

        The stride-2 encoder/decoder pair only reproduces even spatial sizes, so
        an odd height or width is replicate-padded by one pixel first and the
        result is cropped back to the input size.
        """
        height, width = x.shape[-2:]
        pad_h, pad_w = height % 2, width % 2
        if pad_h or pad_w:
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        residual = x

        # Encoder: keep every second feature map for the skip connections.
        max_skips = math.ceil(self.num_layers / 2) - 1
        conv_feats = []
        for i in range(self.num_layers):
            x = self.conv_layers[i](x)
            if (i + 1) % 2 == 0 and len(conv_feats) < max_skips:
                conv_feats.append(x)

        # Decoder: decoder layer k mirrors encoder depth num_layers - k, so the
        # parity test pairs layers symmetrically; pop() yields the deepest
        # remaining encoder feature first.
        for i in range(self.num_layers):
            x = self.deconv_layers[i](x)
            if (i + 1 + self.num_layers) % 2 == 0 and conv_feats:
                x = self.relu(x + conv_feats.pop())

        x += residual
        x = self.relu(x)

        if pad_h or pad_w:
            x = x[..., :height, :width]
        return x
80 |
81 |
class REDNet30(nn.Module):
    """RED-Net with ``2 * num_layers`` layers and symmetric skip connections.

    The encoder is ``num_layers`` 3x3 convolutions with ReLU, the first one with
    stride 2. The decoder mirrors it with ``num_layers`` 3x3 transposed
    convolutions, the last of which upsamples by 2 back to 3 channels. Every
    second encoder output (at most ``ceil(num_layers / 2) - 1`` of them) is
    added, deepest first, to the output of the mirrored decoder layer, followed
    by ReLU. Finally the input is added to the result (global residual) and a
    ReLU is applied.

    Args:
        num_layers: number of encoder convolutions (and of decoder transposed
            convolutions). The default of 15 gives the 30-layer network.
        num_features: number of channels in the hidden feature maps.
    """

    def __init__(self, num_layers=15, num_features=64):
        super(REDNet30, self).__init__()
        self.num_layers = num_layers

        conv_layers = []
        deconv_layers = []

        conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1),
                                         nn.ReLU(inplace=True)))
        for _ in range(num_layers - 1):
            conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                                             nn.ReLU(inplace=True)))

        for _ in range(num_layers - 1):
            deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1),
                                               nn.ReLU(inplace=True)))
        deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1))

        self.conv_layers = nn.Sequential(*conv_layers)
        self.deconv_layers = nn.Sequential(*deconv_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Restore a 3-channel image batch and return a tensor of the same shape.

        The stride-2 encoder/decoder pair only reproduces even spatial sizes, so
        an odd height or width is replicate-padded by one pixel first and the
        result is cropped back to the input size.
        """
        height, width = x.shape[-2:]
        pad_h, pad_w = height % 2, width % 2
        if pad_h or pad_w:
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        residual = x

        # Encoder: keep every second feature map for the skip connections.
        max_skips = math.ceil(self.num_layers / 2) - 1
        conv_feats = []
        for i in range(self.num_layers):
            x = self.conv_layers[i](x)
            if (i + 1) % 2 == 0 and len(conv_feats) < max_skips:
                conv_feats.append(x)

        # Decoder: decoder layer k mirrors encoder depth num_layers - k, so the
        # parity test pairs layers symmetrically; pop() yields the deepest
        # remaining encoder feature first.
        for i in range(self.num_layers):
            x = self.deconv_layers[i](x)
            if (i + 1 + self.num_layers) % 2 == 0 and conv_feats:
                x = self.relu(x + conv_feats.pop())

        x += residual
        x = self.relu(x)

        if pad_h or pad_w:
            x = x[..., :height, :width]
        return x
127 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
class AverageMeter:
    """Keep the most recent value and a sample-weighted running mean.

    Each ``update(val, n)`` contributes ``val * n`` to the running total and
    ``n`` to the sample count, so ``avg`` is the mean over all samples recorded
    since the last ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as covering ``n`` samples and recompute ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
16 |
--------------------------------------------------------------------------------