├── .gitignore
├── README.md
├── data
│   ├── monarch.bmp
│   ├── monarch_ARCNN.png
│   ├── monarch_S-Net.png
│   └── monarch_jpeg_q10.png
├── dataset.py
├── example.py
├── main.py
├── model.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # S-Net
2 |
This repository is a PyTorch implementation of "S-Net: A Scalable Convolutional Neural Network for JPEG Compression Artifact Reduction".
4 |
5 | ## Requirements
- Python 3.7
- PyTorch 1.0.0
- torchvision (used by `example.py`)
- TensorFlow 1.13.0
- tqdm 4.30.0
- NumPy 1.15.4
- Pillow 5.4.1
12 |
**TensorFlow** is required: `dataset.py` imports it unconditionally, and it decodes the training images when the `--use_fast_loader` option is set.
14 |
15 | ## Results
16 |
| Input | JPEG (Quality 10) |
|:---:|:---:|
| ![Input](./data/monarch.bmp) | ![JPEG (Quality 10)](./data/monarch_jpeg_q10.png) |
| **AR-CNN** | **S-Net - Metric 8** |
| ![AR-CNN](./data/monarch_ARCNN.png) | ![S-Net - Metric 8](./data/monarch_S-Net.png) |
43 |
44 | ## Usages
45 |
46 | ### Train
47 |
Model weights are saved to `--outputs_dir` at the end of every epoch.
To speed up training, add the **--use_fast_loader** option, which decodes images with TensorFlow (`tf.image.decode_jpeg`) instead of Pillow.
`--structure_type` accepts `classic` or `advanced`.

```bash
python main.py --num_metrics 8 \
               --structure_type "advanced" \
54 | --images_dir "" \
55 | --outputs_dir "" \
56 | --jpeg_quality 10 \
57 | --patch_size 48 \
58 | --batch_size 16 \
59 | --num_epochs 20 \
60 | --lr 1e-4 \
61 | --threads 8 \
62 | --seed 123 \
63 | --use_fast_loader
64 | ```
65 |
66 | ### Test
67 |
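The script JPEG-compresses the input image and then reduces the compression artifacts. Both results are written to `--outputs_dir`: `<name>_jpeg_q<quality>.png` (the compressed input) and `<name>_<arch>.png` (the restored image). `--structure_type` (`classic` or `advanced`) must match the value used for training, because the two structures have different layers.

```bash
python example.py --num_metrics 8 \
                  --structure_type "advanced" \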
73 | --weights_path "" \
74 | --image_path "" \
75 | --outputs_dir "" \
76 | --jpeg_quality 10
77 | ```
78 |
--------------------------------------------------------------------------------
/data/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SNet-pytorch/3fd6cc07f5646eeb9afafdb58d874afc0a4b5d46/data/monarch.bmp
--------------------------------------------------------------------------------
/data/monarch_ARCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SNet-pytorch/3fd6cc07f5646eeb9afafdb58d874afc0a4b5d46/data/monarch_ARCNN.png
--------------------------------------------------------------------------------
/data/monarch_S-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SNet-pytorch/3fd6cc07f5646eeb9afafdb58d874afc0a4b5d46/data/monarch_S-Net.png
--------------------------------------------------------------------------------
/data/monarch_jpeg_q10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SNet-pytorch/3fd6cc07f5646eeb9afafdb58d874afc0a4b5d46/data/monarch_jpeg_q10.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
3 |
4 | import random
5 | import glob
6 | import io
7 | import numpy as np
8 | import PIL.Image as pil_image
9 |
10 | import tensorflow as tf
# One-time TensorFlow (1.x API) setup at import time. `allow_growth` lets
# TensorFlow grow its GPU memory use on demand. Eager execution is enabled so
# that decoded tensors can be converted with `.numpy()` in Dataset.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
tf.enable_eager_execution(config=config)
14 |
15 |
class Dataset(object):
    """Training pairs of JPEG-compressed and clean image patches.

    Each item is a random ``patch_size`` x ``patch_size`` crop of one image
    from ``images_dir`` (the label). It is paired with the same crop after an
    in-memory JPEG round-trip at ``jpeg_quality`` (the network input).

    Args:
        images_dir: Directory holding the training images. Every regular file
            directly inside it is used.
        patch_size: Side length, in pixels, of the square training crop.
        jpeg_quality: Quality passed to Pillow's JPEG encoder.
        use_fast_loader: If True, read and decode files with TensorFlow
            (``tf.read_file`` + ``tf.image.decode_jpeg``) instead of Pillow.
    """

    def __init__(self, images_dir, patch_size, jpeg_quality, use_fast_loader=False):
        # Keep only regular files. A sub-directory matched by the glob would
        # otherwise be handed to the image decoder and crash the loader.
        self.image_files = sorted(
            path for path in glob.glob(os.path.join(images_dir, '*'))
            if os.path.isfile(path)
        )
        self.patch_size = patch_size
        self.jpeg_quality = jpeg_quality
        self.use_fast_loader = use_fast_loader

    def _load_rgb(self, path):
        """Decode the image at *path* into an RGB PIL image."""
        if self.use_fast_loader:
            # NOTE(review): decode_jpeg is applied to every file on this path.
            # Confirm the training images are JPEGs when using the fast loader.
            data = tf.read_file(path)
            data = tf.image.decode_jpeg(data, channels=3)
            return pil_image.fromarray(data.numpy())
        # convert() returns an in-memory copy, so the file can be closed here.
        with pil_image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(input, label)`` as float32 arrays of shape (3, patch_size, patch_size) in [0, 1].

        Raises:
            ValueError: if the image is smaller than ``patch_size`` in either
                dimension.
        """
        path = self.image_files[idx]
        label = self._load_rgb(path)

        if label.width < self.patch_size or label.height < self.patch_size:
            raise ValueError('{} is {}x{}, smaller than patch_size {}'.format(
                path, label.width, label.height, self.patch_size))

        # Randomly crop a square patch from the training image.
        crop_x = random.randint(0, label.width - self.patch_size)
        crop_y = random.randint(0, label.height - self.patch_size)
        label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))

        # Synthesize compression artifacts with an in-memory JPEG round-trip.
        with io.BytesIO() as buffer:
            label.save(buffer, format='jpeg', quality=self.jpeg_quality)
            buffer.seek(0)
            with pil_image.open(buffer) as compressed:
                jpeg_input = np.array(compressed, dtype=np.float32)

        label = np.array(label, dtype=np.float32)

        # HWC -> CHW, then scale 8-bit values to [0, 1].
        jpeg_input = np.transpose(jpeg_input, axes=[2, 0, 1])
        label = np.transpose(label, axes=[2, 0, 1])
        jpeg_input /= 255.0
        label /= 255.0

        return jpeg_input, label

    def __len__(self):
        """Return the number of training images."""
        return len(self.image_files)
54 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import io
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | from torchvision import transforms
7 | import PIL.Image as pil_image
8 | from model import S_Net
9 |
# Let cuDNN benchmark and pick the fastest convolution algorithms.
cudnn.benchmark = True

# Run on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
12 |
13 |
def _output_stem(image_path):
    """Return the base name of *image_path* with only its last extension removed.

    ``os.path.splitext`` keeps earlier dots, so ``'imgs/monarch.v2.bmp'`` maps
    to ``'monarch.v2'``. Splitting on the first ``'.'`` would give
    ``'monarch'``, and different inputs could then overwrite each other's outputs.
    """
    return os.path.splitext(os.path.basename(image_path))[0]


def _to_uint8_hwc(pred):
    """Convert a batched prediction tensor to an (H, W, C) uint8 NumPy array.

    Assumes *pred* has shape (1, C, H, W) with values nominally in [0, 1].
    Values are scaled to [0, 255], clamped, and rounded to the nearest
    integer before the cast, because ``.byte()`` on its own truncates.
    """
    pixels = pred.mul(255.0).clamp(0.0, 255.0).round()
    return pixels.squeeze(0).permute(1, 2, 0).byte().cpu().numpy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='S-Net')
    parser.add_argument('--num_metrics', type=int, default=8)
    parser.add_argument('--structure_type', type=str, default='classic')
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=10)
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    model = S_Net(opt.num_metrics, opt.structure_type)

    # Copy checkpoint tensors into the model's own parameters. A checkpoint
    # key the model does not have is an error. Model keys absent from the
    # checkpoint keep their freshly initialised values.
    state_dict = model.state_dict()
    checkpoint = torch.load(opt.weights_path, map_location=lambda storage, loc: storage)
    for n, p in checkpoint.items():
        if n in state_dict:
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model = model.to(device)
    model.eval()

    filename = _output_stem(opt.image_path)

    with pil_image.open(opt.image_path) as source:
        image = source.convert('RGB')

    # Simulate compression: JPEG-encode in memory, then decode again.
    with io.BytesIO() as buffer:
        image.save(buffer, format='jpeg', quality=opt.jpeg_quality)
        buffer.seek(0)
        with pil_image.open(buffer) as compressed:
            image = compressed.convert('RGB')
    image.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))

    inputs = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns one reconstruction per metric. Keep the last (deepest) one.
        pred = model(inputs)[-1]

    output = pil_image.fromarray(_to_uint8_hwc(pred), mode='RGB')
    output.save(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, opt.arch)))
57 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torch import nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | from torch.utils.data.dataloader import DataLoader
8 | from tqdm import tqdm
9 | from model import S_Net
10 | from dataset import Dataset
11 | from utils import AverageMeter
12 |
# Let cuDNN benchmark and pick the fastest convolution algorithms; the batch
# and patch size stay fixed for the whole run.
cudnn.benchmark = True

# Train on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
15 |
16 |
def _deep_supervision_loss(criterion, outs, labels):
    """Return the sum of ``criterion(out, labels)`` over every tensor in *outs*.

    The model returns one reconstruction per metric. Each is compared against
    the same clean labels, and the per-output losses are added together.
    """
    return sum(criterion(out, labels) for out in outs)


if __name__ == '__main__':
    import random

    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='S-Net')
    parser.add_argument('--num_metrics', type=int, default=8)
    parser.add_argument('--structure_type', type=str, default='classic')
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=10)
    parser.add_argument('--patch_size', type=int, default=48)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--threads', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--use_fast_loader', action='store_true')
    opt = parser.parse_args()

    os.makedirs(opt.outputs_dir, exist_ok=True)

    # Seed torch, and also Python's `random`, which Dataset uses to choose crop
    # positions. Without the second seed, crops drawn in this process (e.g.
    # with --threads 0) differ from run to run despite --seed.
    torch.manual_seed(opt.seed)
    random.seed(opt.seed)

    model = S_Net(opt.num_metrics, opt.structure_type).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality, opt.use_fast_loader)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            pin_memory=True,
                            drop_last=True)

    for epoch in range(opt.num_epochs):
        epoch_losses = AverageMeter()

        # drop_last=True, so only full batches are seen in an epoch.
        with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                loss = _deep_supervision_loss(criterion, outs, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(inputs))

        # Checkpoint files are numbered from 0, while the progress bar counts from 1.
        torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch)))
77 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
class EncoderBlock(nn.Module):
    """Lift a 3-channel image into a 256-channel feature map.

    The block applies a 5x5 convolution (3 -> 256) and then a 3x3 convolution
    (256 -> 256), each followed by an in-place ReLU. Padding keeps the
    spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and layer order define the checkpoint keys (net.<i>.*).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)
        return features
16 |
17 |
class DecoderBlock(nn.Module):
    """Map 256-channel features back to a 3-channel image.

    The block applies a 3x3 convolution (256 -> 256) and then a 5x5
    convolution (256 -> 3), each followed by an in-place ReLU. Padding keeps
    the spatial size, and the final ReLU makes every output value non-negative.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size); padding = kernel_size // 2.
        conv_specs = ((256, 256, 3), (256, 3, 5))
        layers = []
        for in_ch, out_ch, size in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=size, padding=size // 2))
            layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
30 |
31 |
class ConvolutionalUnit(nn.Module):
    """Residual unit applied to 256-channel features.

    ``structure_type`` selects the residual branch:
      * ``'classic'``: conv3x3 -> ReLU, added to the input unscaled.
      * ``'advanced'``: conv3x3 -> ReLU -> conv3x3, multiplied by 0.1 and
        then added to the input.
    Any other value raises ``ValueError``.
    """

    def __init__(self, structure_type):
        super().__init__()
        self.structure_type = structure_type

        if structure_type not in ('classic', 'advanced'):
            raise ValueError(structure_type)

        layers = [nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
        if structure_type == 'advanced':
            layers.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.net(x)
        if self.structure_type == 'advanced':
            # NOTE(review): the 0.1 factor looks like residual scaling to keep
            # the two-conv branch small; confirm against the paper.
            branch = branch * 0.1
        return x + branch
58 |
59 |
class S_Net(nn.Module):
    """Shared encoder, a chain of residual units, and one decoder per unit.

    Args:
        num_metrics: Number of ConvolutionalUnit/DecoderBlock pairs. This is
            also the number of reconstructions returned by ``forward``.
        structure_type: ``'classic'`` or ``'advanced'``, passed to every
            ConvolutionalUnit.
    """

    def __init__(self, num_metrics=8, structure_type='classic'):
        super().__init__()
        self.num_metrics = num_metrics

        self.encoder = EncoderBlock()
        # nn.Sequential serves only as an indexed container here; forward() never
        # chains through it. The attribute names define the checkpoint keys.
        self.convolution_units = nn.Sequential(
            *(ConvolutionalUnit(structure_type) for _ in range(num_metrics)))
        self.decoders = nn.Sequential(*(DecoderBlock() for _ in range(num_metrics)))

    def forward(self, x):
        """Return a list with one reconstruction per metric, shallowest first."""
        features = self.encoder(x)
        outs = []
        for index in range(self.num_metrics):
            # Each unit refines the previous unit's features; its decoder turns
            # them into an image without feeding anything back into the chain.
            features = self.convolution_units[index](features)
            outs.append(self.decoders[index](features))
        return outs
80 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Keep the most recent value and a running, count-weighted mean.

    Attributes (all reset to 0):
        val: value passed to the latest ``update`` call.
        sum: total of ``val * n`` over all updates since the last reset.
        count: total of ``n`` over all updates since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* with weight *n* (for example, a batch-mean loss and the batch size)."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
16 |
--------------------------------------------------------------------------------