├── .gitignore
├── data
│   ├── monarch.bmp
│   ├── monarch_ARCNN.png
│   ├── monarch_FastARCNN.png
│   └── monarch_jpeg_q10.png
├── utils.py
├── model.py
├── example.py
├── README.md
├── dataset.py
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea
--------------------------------------------------------------------------------
/data/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/ARCNN-pytorch/HEAD/data/monarch.bmp
--------------------------------------------------------------------------------
/data/monarch_ARCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/ARCNN-pytorch/HEAD/data/monarch_ARCNN.png
--------------------------------------------------------------------------------
/data/monarch_FastARCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/ARCNN-pytorch/HEAD/data/monarch_FastARCNN.png
--------------------------------------------------------------------------------
/data/monarch_jpeg_q10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/ARCNN-pytorch/HEAD/data/monarch_jpeg_q10.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
class AverageMeter(object):
    """Tracks the latest value, sum, count, and running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
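For illustration, a minimal sketch of how `AverageMeter` is typically used to track a running loss during training; the loop and variable names below are hypothetical and not taken from `main.py`:

```python
from utils import AverageMeter

epoch_losses = AverageMeter()

# Hypothetical training loop: accumulate each batch loss weighted by batch size.
for step in range(10):
    batch_size = 16
    batch_loss = 1.0 / (step + 1)  # stand-in for a real loss value
    epoch_losses.update(batch_loss, n=batch_size)

print('running average loss: {:.4f}'.format(epoch_losses.avg))
```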
/model.py:
--------------------------------------------------------------------------------
from torch import nn


class ARCNN(nn.Module):
    """AR-CNN: a 9-7-1-5 fully convolutional network for compression artifact reduction."""

    def __init__(self):
        super(ARCNN, self).__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 32, kernel_size=7, padding=3),
            nn.PReLU(),
            nn.Conv2d(32, 16, kernel_size=1),
            nn.PReLU()
        )
        self.last = nn.Conv2d(16, 3, kernel_size=5, padding=2)

        self._initialize_weights()

    def _initialize_weights(self):
        # Small Gaussian initialization for the convolutional weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)

    def forward(self, x):
        x = self.base(x)
        x = self.last(x)
        return x


class FastARCNN(nn.Module):
    """Fast AR-CNN: downsamples by stride 2, then restores resolution with a transposed convolution."""

    def __init__(self):
        super(FastARCNN, self).__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, stride=2, padding=4),
            nn.PReLU(),
            nn.Conv2d(64, 32, kernel_size=1),
            nn.PReLU(),
            nn.Conv2d(32, 32, kernel_size=7, padding=3),
            nn.PReLU(),
            nn.Conv2d(32, 64, kernel_size=1),
            nn.PReLU()
        )
        self.last = nn.ConvTranspose2d(64, 3, kernel_size=9, stride=2, padding=4, output_padding=1)

        self._initialize_weights()

    def _initialize_weights(self):
        # Only nn.Conv2d layers are re-initialized; the ConvTranspose2d keeps PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)

    def forward(self, x):
        x = self.base(x)
        x = self.last(x)
        return x
--------------------------------------------------------------------------------
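As a quick sanity check (not part of the repository), both models can be instantiated and run on a dummy batch. FastARCNN halves the spatial resolution with its stride-2 first convolution and restores it with the final `ConvTranspose2d`, so for even input sizes both networks return an output of the same shape as the input:

```python
import torch
from model import ARCNN, FastARCNN

x = torch.randn(1, 3, 120, 120)  # dummy RGB batch

for net in (ARCNN(), FastARCNN()):
    with torch.no_grad():
        y = net(x)
    print(net.__class__.__name__, tuple(y.shape))  # both print (1, 3, 120, 120)
```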
/example.py:
--------------------------------------------------------------------------------
import argparse
import os
import io
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import PIL.Image as pil_image
from model import ARCNN, FastARCNN

cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='ARCNN', help='ARCNN or FastARCNN')
    parser.add_argument('--weights_path', type=str, required=True)
    parser.add_argument('--image_path', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=10)
    opt = parser.parse_args()

    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    if opt.arch == 'ARCNN':
        model = ARCNN()
    elif opt.arch == 'FastARCNN':
        model = FastARCNN()
    else:
        raise ValueError('--arch must be ARCNN or FastARCNN')

    # Copy the pretrained parameters into the freshly built model.
    state_dict = model.state_dict()
    for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model = model.to(device)
    model.eval()

    filename = os.path.basename(opt.image_path).split('.')[0]

    input = pil_image.open(opt.image_path).convert('RGB')

    # JPEG-compress the clean image in memory to create the degraded input.
    buffer = io.BytesIO()
    input.save(buffer, format='jpeg', quality=opt.jpeg_quality)
    input = pil_image.open(buffer)
    input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))

    input = transforms.ToTensor()(input).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(input)

    # Scale the [0, 1] prediction back to 8-bit RGB and save the restored image.
    pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
    output = pil_image.fromarray(pred, mode='RGB')
    output.save(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, opt.arch)))
--------------------------------------------------------------------------------
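A typical invocation of `example.py` (the weights file name here is only a placeholder, since pretrained weights are not part of this listing) is `python example.py --arch ARCNN --weights_path ./ARCNN.pth --image_path data/monarch.bmp --outputs_dir ./outputs --jpeg_quality 10`, which writes both the JPEG-degraded input and the restored `{filename}_{arch}.png` into `--outputs_dir`.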
/README.md:
--------------------------------------------------------------------------------
# AR-CNN, Fast AR-CNN

This repository is an implementation of "Deep Convolution Networks for Compression Artifacts Reduction".
In contrast with the original paper, it uses RGB channels instead of the luminance channel in YCbCr space and a smaller batch size (16).

## Requirements
- PyTorch
- Tensorflow
- tqdm
- Numpy
- Pillow

**Tensorflow** is required for quickly fetching images in the training phase.
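`dataset.py` appears in the file tree but is not reproduced above, so the following is only an illustrative assumption of how TensorFlow's image ops could be used to JPEG-degrade clean patches on the fly during training; it is not necessarily the repository's actual data pipeline:

```python
import tensorflow as tf

def jpeg_degrade(clean_rgb_uint8, quality=10):
    # Hypothetical helper: encode the clean uint8 patch as JPEG at the given
    # quality and decode it again, yielding the artifact-laden network input.
    encoded = tf.io.encode_jpeg(clean_rgb_uint8, quality=quality)
    return tf.io.decode_jpeg(encoded, channels=3)
```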
## Results

Side-by-side comparisons of the JPEG-compressed input (quality 10) with the AR-CNN and Fast AR-CNN outputs; the corresponding images are `data/monarch_jpeg_q10.png`, `data/monarch_ARCNN.png`, and `data/monarch_FastARCNN.png`.