├── .gitignore ├── README.md ├── data ├── monarch.bmp ├── monarch_ARCNN.png ├── monarch_REDNet10.png ├── monarch_REDNet20.png ├── monarch_REDNet30.png └── monarch_jpeg_q10.png ├── dataset.py ├── example.py ├── main.py ├── model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RED-Net 2 | 3 | This repository is a PyTorch implementation of the paper "Image Restoration Using Very Deep Convolutional Encoder-Decoder Networks with Symmetric Skip Connections".
4 | To reduce computational cost, it uses a stride of 2 in the first convolution layer and in the last transposed convolution layer. 5 | 6 | ## Requirements 7 | - PyTorch (with torchvision) 8 | - TensorFlow 9 | - tqdm 10 | - NumPy 11 | - Pillow 12 | 13 | **TensorFlow** is required because `dataset.py` imports it; it is used to load training images quickly (see `--use_fast_loader`). 14 | 15 | ## Results 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 26 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 38 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 50 | 53 | 54 |
Input
JPEG (Quality 10)
24 |
25 |
27 |
28 |
AR-CNN
RED-Net 10
36 |
37 |
39 |
40 |
RED-Net 20
RED-Net 30
48 |
49 |
51 |
52 |
55 | 56 | ## Usage 57 | 58 | ### Train 59 | 60 | During training, the model weights are saved to `--outputs_dir` at the end of every epoch as `<arch>_epoch_<N>.pth`, where N starts at 0.
61 | `--arch` accepts REDNet10, REDNet20, or REDNet30. To train faster, use the **--use_fast_loader** option. 62 | 63 | ```bash 64 | python main.py --arch "REDNet30" \ 65 | --images_dir "" \ 66 | --outputs_dir "" \ 67 | --jpeg_quality 10 \ 68 | --patch_size 50 \ 69 | --batch_size 16 \ 70 | --num_epochs 20 \ 71 | --lr 1e-4 \ 72 | --threads 8 \ 73 | --seed 123 \ 74 | --use_fast_loader 75 | ``` 76 | 77 | ### Test 78 | 79 | `--arch` must match the architecture of the checkpoint given by `--weights_path`. The script saves two PNG images to `--outputs_dir`: the JPEG-compressed input and the artifact-reduced output. 80 | 81 | ```bash 82 | python example.py --arch "REDNet30" \ 83 | --weights_path "" \ 84 | --image_path "" \ 85 | --outputs_dir "" \ 86 | --jpeg_quality 10 87 | ``` 88 | -------------------------------------------------------------------------------- /data/monarch.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch.bmp -------------------------------------------------------------------------------- /data/monarch_ARCNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_ARCNN.png -------------------------------------------------------------------------------- /data/monarch_REDNet10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet10.png -------------------------------------------------------------------------------- /data/monarch_REDNet20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet20.png
# NOTE(review): issues spotted in the Python sources dumped below (not fixed in this dump):
# - dataset.py imports tensorflow at module level, so TensorFlow is needed even without
#   --use_fast_loader. tf.ConfigProto / tf.enable_eager_execution / tf.read_file look like
#   TensorFlow 1.x APIs (TF 2.x keeps them under tf.compat.v1 / tf.io) -- confirm the pinned TF version.
# - Dataset.__getitem__ calls random.randint(0, width - patch_size). An image smaller than
#   patch_size makes the upper bound negative, and randint raises.
# - model.py: the first Conv2d (k=3, s=2, p=1) and the last ConvTranspose2d
#   (k=3, s=2, p=1, output_padding=1) produce an output of 2*ceil(H/2) x 2*ceil(W/2).
#   An odd height or width therefore likely makes `+= residual` fail with a shape mismatch
#   (e.g. example.py on an odd-sized image) -- TODO crop/pad the output to the input size.
# - main.py / example.py: an --arch other than REDNet10/20/30 leaves `model` unbound
#   (NameError). argparse `choices=` would reject such a value up front.
# - example.py copies checkpoint tensors by name and raises KeyError only for unexpected keys.
#   Parameters missing from the checkpoint keep their initial values without any warning.
# - main.py names checkpoints '<arch>_epoch_<epoch>.pth' with a 0-based epoch, while the
#   progress bar shows epoch + 1.
-------------------------------------------------------------------------------- /data/monarch_REDNet30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_REDNet30.png -------------------------------------------------------------------------------- /data/monarch_jpeg_q10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjn870/REDNet-pytorch/11ee46722a4fbee48b37f417e4329026d5b78bfa/data/monarch_jpeg_q10.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | 4 | import random 5 | import glob 6 | import io 7 | import numpy as np 8 | import PIL.Image as pil_image 9 | 10 | import tensorflow as tf 11 | config = tf.ConfigProto() 12 | config.gpu_options.allow_growth = True 13 | tf.enable_eager_execution(config=config) 14 | 15 | 16 | class Dataset(object): 17 | def __init__(self, images_dir, patch_size, jpeg_quality, use_fast_loader=False): 18 | self.image_files = sorted(glob.glob(images_dir + '/*')) 19 | self.patch_size = patch_size 20 | self.jpeg_quality = jpeg_quality 21 | self.use_fast_loader = use_fast_loader 22 | 23 | def __getitem__(self, idx): 24 | if self.use_fast_loader: 25 | label = tf.read_file(self.image_files[idx]) 26 | label = tf.image.decode_jpeg(label, channels=3) 27 | label = pil_image.fromarray(label.numpy()) 28 | else: 29 | label = pil_image.open(self.image_files[idx]).convert('RGB') 30 | 31 | # randomly crop patch from training set 32 | crop_x = random.randint(0, label.width - self.patch_size) 33 | crop_y = random.randint(0, label.height - self.patch_size) 34 | label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size)) 35 | 36 |
# additive jpeg noise 37 | buffer = io.BytesIO() 38 | label.save(buffer, format='jpeg', quality=self.jpeg_quality) 39 | input = pil_image.open(buffer) 40 | 41 | input = np.array(input).astype(np.float32) 42 | label = np.array(label).astype(np.float32) 43 | input = np.transpose(input, axes=[2, 0, 1]) 44 | label = np.transpose(label, axes=[2, 0, 1]) 45 | 46 | # normalization 47 | input /= 255.0 48 | label /= 255.0 49 | 50 | return input, label 51 | 52 | def __len__(self): 53 | return len(self.image_files) 54 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import io 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | from torchvision import transforms 7 | import PIL.Image as pil_image 8 | from model import REDNet10, REDNet20, REDNet30 9 | 10 | cudnn.benchmark = True 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--arch', type=str, default='REDNet10', help='REDNet10, REDNet20, REDNet30') 17 | parser.add_argument('--weights_path', type=str, required=True) 18 | parser.add_argument('--image_path', type=str, required=True) 19 | parser.add_argument('--outputs_dir', type=str, required=True) 20 | parser.add_argument('--jpeg_quality', type=int, default=10) 21 | opt = parser.parse_args() 22 | 23 | if not os.path.exists(opt.outputs_dir): 24 | os.makedirs(opt.outputs_dir) 25 | 26 | if opt.arch == 'REDNet10': 27 | model = REDNet10() 28 | elif opt.arch == 'REDNet20': 29 | model = REDNet20() 30 | elif opt.arch == 'REDNet30': 31 | model = REDNet30() 32 | 33 | state_dict = model.state_dict() 34 | for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items(): 35 | if n in state_dict.keys(): 36 | state_dict[n].copy_(p) 37 | else: 38 | raise
KeyError(n) 39 | 40 | model = model.to(device) 41 | model.eval() 42 | 43 | filename = os.path.basename(opt.image_path).split('.')[0] 44 | 45 | input = pil_image.open(opt.image_path).convert('RGB') 46 | 47 | buffer = io.BytesIO() 48 | input.save(buffer, format='jpeg', quality=opt.jpeg_quality) 49 | input = pil_image.open(buffer) 50 | input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality))) 51 | 52 | input = transforms.ToTensor()(input).unsqueeze(0).to(device) 53 | 54 | with torch.no_grad(): 55 | pred = model(input) 56 | 57 | pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy() 58 | output = pil_image.fromarray(pred, mode='RGB') 59 | output.save(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, opt.arch))) 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torch import nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | from torch.utils.data.dataloader import DataLoader 8 | from tqdm import tqdm 9 | from model import REDNet10, REDNet20, REDNet30 10 | from dataset import Dataset 11 | from utils import AverageMeter 12 | 13 | cudnn.benchmark = True 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--arch', type=str, default='REDNet10', help='REDNet10, REDNet20, REDNet30') 20 | parser.add_argument('--images_dir', type=str, required=True) 21 | parser.add_argument('--outputs_dir', type=str, required=True) 22 | parser.add_argument('--jpeg_quality', type=int, default=10) 23 | parser.add_argument('--patch_size', type=int, default=50) 24 | parser.add_argument('--batch_size', type=int, default=16) 25 | parser.add_argument('--num_epochs', type=int,
default=20) 26 | parser.add_argument('--lr', type=float, default=1e-4) 27 | parser.add_argument('--threads', type=int, default=8) 28 | parser.add_argument('--seed', type=int, default=123) 29 | parser.add_argument('--use_fast_loader', action='store_true') 30 | opt = parser.parse_args() 31 | 32 | if not os.path.exists(opt.outputs_dir): 33 | os.makedirs(opt.outputs_dir) 34 | 35 | torch.manual_seed(opt.seed) 36 | 37 | if opt.arch == 'REDNet10': 38 | model = REDNet10() 39 | elif opt.arch == 'REDNet20': 40 | model = REDNet20() 41 | elif opt.arch == 'REDNet30': 42 | model = REDNet30() 43 | 44 | model = model.to(device) 45 | criterion = nn.MSELoss() 46 | 47 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 48 | 49 | dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality, opt.use_fast_loader) 50 | dataloader = DataLoader(dataset=dataset, 51 | batch_size=opt.batch_size, 52 | shuffle=True, 53 | num_workers=opt.threads, 54 | pin_memory=True, 55 | drop_last=True) 56 | 57 | for epoch in range(opt.num_epochs): 58 | epoch_losses = AverageMeter() 59 | 60 | with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm: 61 | _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs)) 62 | for data in dataloader: 63 | inputs, labels = data 64 | inputs = inputs.to(device) 65 | labels = labels.to(device) 66 | 67 | preds = model(inputs) 68 | 69 | loss = criterion(preds, labels) 70 | epoch_losses.update(loss.item(), len(inputs)) 71 | 72 | optimizer.zero_grad() 73 | loss.backward() 74 | optimizer.step() 75 | 76 | _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg)) 77 | _tqdm.update(len(inputs)) 78 | 79 | torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch))) 80 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | 4 | 5 | class
REDNet10(nn.Module): 6 | def __init__(self, num_layers=5, num_features=64): 7 | super(REDNet10, self).__init__() 8 | conv_layers = [] 9 | deconv_layers = [] 10 | 11 | conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1), 12 | nn.ReLU(inplace=True))) 13 | for i in range(num_layers - 1): 14 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), 15 | nn.ReLU(inplace=True))) 16 | 17 | for i in range(num_layers - 1): 18 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1), 19 | nn.ReLU(inplace=True))) 20 | deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1)) 21 | 22 | self.conv_layers = nn.Sequential(*conv_layers) 23 | self.deconv_layers = nn.Sequential(*deconv_layers) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x): 27 | residual = x 28 | out = self.conv_layers(x) 29 | out = self.deconv_layers(out) 30 | out += residual 31 | out = self.relu(out) 32 | return out 33 | 34 | 35 | class REDNet20(nn.Module): 36 | def __init__(self, num_layers=10, num_features=64): 37 | super(REDNet20, self).__init__() 38 | self.num_layers = num_layers 39 | 40 | conv_layers = [] 41 | deconv_layers = [] 42 | 43 | conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1), 44 | nn.ReLU(inplace=True))) 45 | for i in range(num_layers - 1): 46 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), 47 | nn.ReLU(inplace=True))) 48 | 49 | for i in range(num_layers - 1): 50 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1), 51 | nn.ReLU(inplace=True))) 52 | deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1)) 53 | 54 | self.conv_layers = nn.Sequential(*conv_layers) 55 |
self.deconv_layers = nn.Sequential(*deconv_layers) 56 | self.relu = nn.ReLU(inplace=True) 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | conv_feats = [] 62 | for i in range(self.num_layers): 63 | x = self.conv_layers[i](x) 64 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 65 | conv_feats.append(x) 66 | 67 | conv_feats_idx = 0 68 | for i in range(self.num_layers): 69 | x = self.deconv_layers[i](x) 70 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 71 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 72 | conv_feats_idx += 1 73 | x = x + conv_feat 74 | x = self.relu(x) 75 | 76 | x += residual 77 | x = self.relu(x) 78 | 79 | return x 80 | 81 | 82 | class REDNet30(nn.Module): 83 | def __init__(self, num_layers=15, num_features=64): 84 | super(REDNet30, self).__init__() 85 | self.num_layers = num_layers 86 | 87 | conv_layers = [] 88 | deconv_layers = [] 89 | 90 | conv_layers.append(nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1), 91 | nn.ReLU(inplace=True))) 92 | for i in range(num_layers - 1): 93 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), 94 | nn.ReLU(inplace=True))) 95 | 96 | for i in range(num_layers - 1): 97 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features, num_features, kernel_size=3, padding=1), 98 | nn.ReLU(inplace=True))) 99 | deconv_layers.append(nn.ConvTranspose2d(num_features, 3, kernel_size=3, stride=2, padding=1, output_padding=1)) 100 | 101 | self.conv_layers = nn.Sequential(*conv_layers) 102 | self.deconv_layers = nn.Sequential(*deconv_layers) 103 | self.relu = nn.ReLU(inplace=True) 104 | 105 | def forward(self, x): 106 | residual = x 107 | 108 | conv_feats = [] 109 | for i in range(self.num_layers): 110 | x = self.conv_layers[i](x) 111 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 112 | conv_feats.append(x) 113 | 114 | conv_feats_idx = 0 115 |
for i in range(self.num_layers): 116 | x = self.deconv_layers[i](x) 117 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 118 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 119 | conv_feats_idx += 1 120 | x = x + conv_feat 121 | x = self.relu(x) 122 | 123 | x += residual 124 | x = self.relu(x) 125 | 126 | return x 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | def __init__(self): 3 | self.reset() 4 | 5 | def reset(self): 6 | self.val = 0 7 | self.avg = 0 8 | self.sum = 0 9 | self.count = 0 10 | 11 | def update(self, val, n=1): 12 | self.val = val 13 | self.sum += val * n 14 | self.count += n 15 | self.avg = self.sum / self.count 16 | --------------------------------------------------------------------------------