├── .gitignore ├── datasets.py ├── generate_testset.py ├── eval.py ├── models.py ├── utils.py ├── generate_trainset.py ├── README.md └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class TrainDataset(Dataset): 6 | def __init__(self, h5_file): 7 | super(TrainDataset, self).__init__() 8 | self.h5_file = h5_file 9 | 10 | def __getitem__(self, idx): 11 | with h5py.File(self.h5_file, 'r') as f: 12 | lr = f['lr'][idx] 13 | hr = f['hr'][idx] 14 | return lr, hr 15 | 16 | def __len__(self): 17 | with h5py.File(self.h5_file, 'r') as f: 18 | return len(f['lr']) 19 | 20 | 21 | class EvalDataset(Dataset): 22 | def __init__(self, h5_file): 23 | super(EvalDataset, self).__init__() 24 | self.h5_file = h5_file 25 | 26 | def __getitem__(self, idx): 27 | with h5py.File(self.h5_file, 'r') as f: 28 | lr = f['lr'][str(idx)][()] 29 | hr = f['hr'][str(idx)][()] 30 | return lr, hr 31 | 32 | def __len__(self): 33 | with h5py.File(self.h5_file, 'r') as f: 34 | return len(f['lr']) 35 | -------------------------------------------------------------------------------- /generate_testset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import h5py 4 | import numpy as np 5 | from utils import load_image, modcrop, generate_lr, image_to_array, rgb_to_y, normalize 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--images-dir', type=str, required=True) 11 | parser.add_argument('--output-path', type=str, required=True) 12 | parser.add_argument('--scale', type=int, default=2) 13 | args = parser.parse_args() 14 | 15 | h5_file = h5py.File(args.output_path, 'w') 16 | 17 | lr_group = h5_file.create_group('lr') 18 | hr_group = h5_file.create_group('hr') 19 | 20 | for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))): 21 | hr = load_image(image_path) 22 | hr = modcrop(hr, args.scale) 23 | lr = generate_lr(hr, args.scale) 24 | 25 | hr = image_to_array(hr) 26 | lr = image_to_array(lr) 27 | 28 | hr = np.expand_dims(normalize(rgb_to_y(hr.astype(np.float32), 'chw')), 0) 29 | lr = np.expand_dims(normalize(rgb_to_y(lr.astype(np.float32), 'chw')), 0) 30 | 31 | hr_group.create_dataset(str(i), data=hr) 32 | lr_group.create_dataset(str(i), data=lr) 33 | 34 | h5_file.close() 35 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | from torch.utils.data.dataloader import DataLoader 6 | 7 | from models import DRRN 8 | from datasets import EvalDataset 9 | from utils import AverageMeter, denormalize, PSNR, load_weights 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--weights-file', type=str, required=True) 15 | parser.add_argument('--eval-file', type=str, required=True) 16 | parser.add_argument('--eval-scale', type=int, required=True) 17 | parser.add_argument('--B', type=int, default=1) 18 | parser.add_argument('--U', type=int, default=9) 19 | parser.add_argument('--num-features', type=int, default=128) 20 
| args = parser.parse_args() 21 | 22 | cudnn.benchmark = True 23 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 24 | 25 | model = DRRN(B=args.B, U=args.U, num_features=args.num_features).to(device) 26 | model = load_weights(model, args.weights_file) 27 | 28 | eval_dataset = EvalDataset(args.eval_file) 29 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1) 30 | 31 | if args.eval_file is not None: 32 | model.eval() 33 | epoch_psnr = AverageMeter() 34 | 35 | for data in eval_dataloader: 36 | inputs, labels = data 37 | 38 | inputs = inputs.to(device) 39 | labels = labels.to(device) 40 | 41 | with torch.no_grad(): 42 | preds = model(inputs) 43 | 44 | preds = denormalize(preds.squeeze(0).squeeze(0)) 45 | labels = denormalize(labels.squeeze(0).squeeze(0)) 46 | 47 | epoch_psnr.update(PSNR(preds, labels, shave_border=args.eval_scale), len(inputs)) 48 | 49 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) 50 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ConvLayer(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size=3): 6 | super(ConvLayer, self).__init__() 7 | self.module = nn.Sequential( 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False) 10 | ) 11 | 12 | def forward(self, x): 13 | return self.module(x) 14 | 15 | 16 | class ResidualUnit(nn.Module): 17 | def __init__(self, num_features): 18 | super(ResidualUnit, self).__init__() 19 | self.module = nn.Sequential( 20 | ConvLayer(num_features, num_features), 21 | ConvLayer(num_features, num_features) 22 | ) 23 | 24 | def forward(self, h0, x): 25 | return h0 + self.module(x) 26 | 27 | 28 | class RecursiveBlock(nn.Module): 29 | def __init__(self, in_channels, out_channels, U): 30 | super(RecursiveBlock, self).__init__() 31 | self.U = U 32 | self.h0 = ConvLayer(in_channels, out_channels) 33 | self.ru = ResidualUnit(out_channels) 34 | 35 | def forward(self, x): 36 | h0 = self.h0(x) 37 | x = h0 38 | for i in range(self.U): 39 | x = self.ru(h0, x) 40 | return x 41 | 42 | 43 | class DRRN(nn.Module): 44 | def __init__(self, B, U, num_channels=1, num_features=128): 45 | super(DRRN, self).__init__() 46 | self.rbs = nn.Sequential(*[RecursiveBlock(num_channels if i == 0 else num_features, num_features, U) for i in range(B)]) 47 | self.rec = ConvLayer(num_features, num_channels) 48 | self._initialize_weights() 49 | 50 | def _initialize_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | 57 | def forward(self, x): 58 | residual = x 59 | x = self.rbs(x) 60 | x = self.rec(x) 61 | x += residual 62 | return x 63 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import PIL.Image as pil_image 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def load_image(path): 7 | return pil_image.open(path).convert('RGB') 8 | 9 | 10 | def generate_lr(image, scale): 11 | image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC) 12 | image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC) 13 | return image 14 | 15 | 
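# NOTE: generate_lr builds the network input by bicubic downscaling and then
# bicubic upscaling back to the original size, so the "LR" image has the same
# resolution as its HR counterpart -- DRRN takes the interpolated image as
# input and predicts only the residual to the HR image (see the global skip
# connection in models.py).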
16 | def modcrop(image, modulo): 17 | w = image.width - image.width % modulo 18 | h = image.height - image.height % modulo 19 | return image.crop((0, 0, w, h)) 20 | 21 | 22 | def generate_patch(image, patch_size, stride): 23 | for i in range(0, image.height - patch_size + 1, stride): 24 | for j in range(0, image.width - patch_size + 1, stride): 25 | yield image.crop((j, i, j + patch_size, i + patch_size)) 26 | 27 | 28 | def image_to_array(image): 29 | return np.array(image).transpose((2, 0, 1)) 30 | 31 | 32 | def normalize(x): 33 | return x / 255.0 34 | 35 | 36 | def denormalize(x): 37 | if isinstance(x, torch.Tensor): 38 | return (x * 255.0).clamp(0.0, 255.0) 39 | elif isinstance(x, np.ndarray): 40 | return (x * 255.0).clip(0.0, 255.0) 41 | else: 42 | raise TypeError('denormalize supports torch.Tensor or np.ndarray, got {}'.format(type(x))) 43 | 44 | 45 | def rgb_to_y(img, dim_order='hwc'):  # ITU-R BT.601 RGB -> Y, scaled for [0, 255] inputs 46 | if dim_order == 'hwc': 47 | return 16. + (65.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256. 48 | else: 49 | return 16. + (65.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256. 50 | 51 | 52 | def PSNR(a, b, max_val=255.0, shave_border=0): 53 | assert type(a) == type(b) 54 | assert isinstance(a, (torch.Tensor, np.ndarray)) 55 | 56 | a = a[shave_border:a.shape[0]-shave_border, shave_border:a.shape[1]-shave_border] 57 | b = b[shave_border:b.shape[0]-shave_border, shave_border:b.shape[1]-shave_border] 58 | 59 | if isinstance(a, torch.Tensor): 60 | return 10. * ((max_val ** 2) / ((a - b) ** 2).mean()).log10() 61 | elif isinstance(a, np.ndarray): 62 | return 10. * np.log10((max_val ** 2) / np.mean((a - b) ** 2)) 63 | else: 64 | raise TypeError('PSNR supports torch.Tensor or np.ndarray, got {}'.format(type(a))) 65 | 66 | 67 | def load_weights(model, path): 68 | state_dict = model.state_dict() 69 | for n, p in torch.load(path, map_location=lambda storage, loc: storage).items(): 70 | if n in state_dict.keys(): 71 | state_dict[n].copy_(p) 72 | else: 73 | raise KeyError(n) 74 | return model 75 | 76 | 77 | class AverageMeter(object): 78 | def __init__(self): 79 | self.reset() 80 | 81 | def reset(self): 82 | self.val = 0 83 | self.avg = 0 84 | self.sum = 0 85 | self.count = 0 86 | 87 | def update(self, val, n=1): 88 | self.val = val 89 | self.sum += val * n 90 | self.count += n 91 | self.avg = self.sum / self.count 92 | -------------------------------------------------------------------------------- /generate_trainset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import h5py 4 | import numpy as np 5 | from utils import load_image, modcrop, generate_lr, generate_patch, image_to_array, rgb_to_y, normalize 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--images-dir', type=str, required=True) 11 | parser.add_argument('--output-path', type=str, required=True) 12 | parser.add_argument('--patch-size', type=int, default=31) 13 | parser.add_argument('--stride', type=int, default=21) 14 | args = parser.parse_args() 15 | 16 | hr_patches = [] 17 | lr_patches = [] 18 | 19 | for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))): 20 | # Extract patches at scales 2, 3 and 4 so a single model covers x234. 21 | for scale in (2, 3, 4): 22 | hr = load_image(image_path) 23 | hr = modcrop(hr, scale) 24 | lr = generate_lr(hr, scale) 25 | 26 | for image, patches in ((hr, hr_patches), (lr, lr_patches)): 27 | for patch in generate_patch(image, args.patch_size, args.stride): 28 | patch = image_to_array(patch) 29 | patch = np.expand_dims(normalize(rgb_to_y(patch.astype(np.float32), 'chw')), 0) 30 | patches.append(patch) 31 | 32 | print('Images: {}, Patches: {}'.format(i + 1, len(hr_patches))) 33 | 34 | h5_file = h5py.File(args.output_path, 'w') 35 | 36 | h5_file.create_dataset('hr', data=np.array(hr_patches)) 37 | h5_file.create_dataset('lr', data=np.array(lr_patches)) 38 | 39 | h5_file.close() 40 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRRN 2 | 3 | This repository is an implementation of the paper ["Image Super-Resolution via Deep Recursive Residual Network"](http://cvlab.cse.msu.edu/project-super-resolution.html). 4 | 5 | ## Requirements 6 | 7 | - PyTorch 1.0.0 8 | - Numpy 1.15.4 9 | - Pillow 5.4.1 10 | - h5py 2.8.0 11 | - tqdm 4.30.0 12 | 13 | ## Prepare 14 | 15 | The images used to build the training (**291-image**) and evaluation (**Set5**) datasets can be downloaded from the paper authors' [implementation](https://github.com/tyshiwo/DRRN_CVPR17/tree/master/data). 16 | 17 | You can also use pre-created dataset files with the same settings as the paper. 18 | 19 | | Dataset | Scale | Type | Link | 20 | |---------|-------|------|------| 21 | | 291-image | 2, 3, 4 | Train | [Download](https://www.dropbox.com/s/w67yqju1suxejxn/291-image_x234.h5?dl=0) | 22 | | Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/b4a48onyqedx8dz/Set5_x2.h5?dl=0) | 23 | | Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/if01dprb3tzc8jr/Set5_x3.h5?dl=0) | 24 | | Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/cdoxdgz99imy9ik/Set5_x4.h5?dl=0) | 25 | 26 | ### Generate training dataset 27 | 28 | ```bash 29 | python generate_trainset.py --images-dir "BLAH_BLAH/Train_291" \ 30 | --output-path "BLAH_BLAH/Train_291_x234.h5" \ 31 | --patch-size 31 \ 32 | --stride 21 33 | ``` 34 | 35 | ### Generate test dataset 36 | 37 | ```bash 38 | python generate_testset.py --images-dir "BLAH_BLAH/Set5" \ 39 | --output-path "BLAH_BLAH/Set5_x2.h5" \ 40 | --scale 2 41 | ``` 42 | 43 | ## Train 44 | 45 | Model weights will be stored in the `--outputs-dir` after every epoch. 
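During training, the learning rate starts at `--lr` and is halved every 10 epochs, and gradient norms are clipped to `--clip-grad` divided by the current learning rate, so the effective update size stays bounded as the rate decays. The relevant lines from `train.py`:

```python
lr = args.lr * (0.5 ** ((epoch + 1) // 10))  # halve the learning rate every 10 epochs
nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.clip_grad / lr)
```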
46 | 47 | ```bash 48 | python train.py --train-file "BLAH_BLAH/Train_291_x234.h5" \ 49 | --outputs-dir "BLAH_BLAH/DRRN_B1U9" \ 50 | --B 1 \ 51 | --U 9 \ 52 | --num-features 128 \ 53 | --lr 0.1 \ 54 | --clip-grad 0.01 \ 55 | --batch-size 128 \ 56 | --num-epochs 50 \ 57 | --num-workers 8 \ 58 | --seed 123 59 | ``` 60 | 61 | You can also evaluate after every epoch during training by passing the `--eval-file` and `--eval-scale` options. In addition, the best weights will be stored in the `--outputs-dir` as `best.pth`. 62 | 63 | ```bash 64 | python train.py --train-file "BLAH_BLAH/Train_291_x234.h5" \ 65 | --outputs-dir "BLAH_BLAH/DRRN_B1U9" \ 66 | --eval-file "BLAH_BLAH/Set5_x2.h5" \ 67 | --eval-scale 2 \ 68 | --B 1 \ 69 | --U 9 \ 70 | --num-features 128 \ 71 | --lr 0.1 \ 72 | --clip-grad 0.01 \ 73 | --batch-size 128 \ 74 | --num-epochs 50 \ 75 | --num-workers 8 \ 76 | --seed 123 77 | ``` 78 | 79 | ## Evaluate 80 | 81 | The pre-trained weights can be downloaded from the following links. 82 | 83 | | Model | Link | 84 | |-------|------| 85 | | DRRN_B1U9 | [Download](https://www.dropbox.com/s/1ozete9panliycb/drrn_x234.pth?dl=0) | 86 | 87 | ```bash 88 | python eval.py --weights-file "BLAH_BLAH/DRRN_B1U9/best.pth" \ 89 | --eval-file "BLAH_BLAH/Set5_x2.h5" \ 90 | --eval-scale 2 \ 91 | --B 1 \ 92 | --U 9 \ 93 | --num-features 128 94 | ``` 95 | 96 | ## Results 97 | 98 | Our model was trained and evaluated on the **Y (luminance) channel**. 99 | 100 | For better performance, we modified the original implementation as follows. 101 | 102 | - **Batch normalization** was removed from the residual unit. 103 | - **No bias** was used in the convolution layers. 104 | 105 | ### Performance comparison on Set5 106 | 107 | | Eval. Mat | Scale | DRRN_B1U9 (Paper) | DRRN_B1U9 (Ours) | 108 | |-----------|-------|-------------------|------------------| 109 | | PSNR | 2 | 37.66 | **37.62** | 110 | | PSNR | 3 | 33.93 | **33.86** | 111 | | PSNR | 4 | 31.58 | **31.52** | 112 | 113 | ## References 114 | 115 | 1. 
[https://github.com/tyshiwo/DRRN_CVPR17](https://github.com/tyshiwo/DRRN_CVPR17) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import copy 4 | 5 | import torch 6 | from torch import nn 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.utils.data.dataloader import DataLoader 10 | from tqdm import tqdm 11 | 12 | from models import DRRN 13 | from datasets import TrainDataset, EvalDataset 14 | from utils import AverageMeter, denormalize, PSNR, load_weights 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--train-file', type=str, required=True) 20 | parser.add_argument('--outputs-dir', type=str, required=True) 21 | parser.add_argument('--eval-file', type=str) 22 | parser.add_argument('--eval-scale', type=int) 23 | parser.add_argument('--weights-file', type=str) 24 | parser.add_argument('--B', type=int, default=1) 25 | parser.add_argument('--U', type=int, default=9) 26 | parser.add_argument('--num-features', type=int, default=128) 27 | parser.add_argument('--lr', type=float, default=0.1) 28 | parser.add_argument('--clip-grad', type=float, default=0.01) 29 | parser.add_argument('--batch-size', type=int, default=128) 30 | parser.add_argument('--num-epochs', type=int, default=50) 31 | parser.add_argument('--num-workers', type=int, default=8) 32 | parser.add_argument('--seed', type=int, default=123) 33 | args = parser.parse_args() 34 | 35 | args.outputs_dir = os.path.join(args.outputs_dir, 'x234') 36 | if not os.path.exists(args.outputs_dir): 37 | os.makedirs(args.outputs_dir) 38 | 39 | cudnn.benchmark = True 40 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 41 | torch.manual_seed(args.seed) 42 | 43 | model = DRRN(B=args.B, U=args.U, num_features=args.num_features).to(device) 44 | 45 | if args.weights_file is not None: 46 | model = load_weights(model, args.weights_file) 47 | 48 | criterion = nn.MSELoss(reduction='sum') 49 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4) 50 | 51 | train_dataset = TrainDataset(args.train_file) 52 | train_dataloader = DataLoader(dataset=train_dataset, 53 | batch_size=args.batch_size, 54 | shuffle=True, 55 | num_workers=args.num_workers, 56 | pin_memory=True) 57 | 58 | if args.eval_file is not None: 59 | eval_dataset = EvalDataset(args.eval_file) 60 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1) 61 | 62 | best_weights = copy.deepcopy(model.state_dict()) 63 | best_epoch = 0 64 | best_psnr = 0.0 65 | 66 | for epoch in range(args.num_epochs): 67 | lr = args.lr * (0.5 ** ((epoch + 1) // 10)) 68 | 69 | for param_group in optimizer.param_groups: 70 | param_group['lr'] = lr 71 | 72 | model.train() 73 | epoch_losses = AverageMeter() 74 | 75 | with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t: 76 | t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1)) 77 | 78 | for data in train_dataloader: 79 | inputs, labels = data 80 | 81 | inputs = inputs.to(device) 82 | labels = labels.to(device) 83 | 84 | preds = model(inputs) 85 | 86 | loss = criterion(preds, labels) / (2 * len(inputs)) 87 | epoch_losses.update(loss.item(), len(inputs)) 88 | 89 | optimizer.zero_grad() 90 | loss.backward() 91 | 92 | nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.clip_grad / lr) 93 | 94 | 
optimizer.step() 95 | 96 | t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr=lr) 97 | t.update(len(inputs)) 98 | 99 | torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch))) 100 | 101 | if args.eval_file is not None: 102 | model.eval() 103 | epoch_psnr = AverageMeter() 104 | 105 | for data in eval_dataloader: 106 | inputs, labels = data 107 | 108 | inputs = inputs.to(device) 109 | labels = labels.to(device) 110 | 111 | with torch.no_grad(): 112 | preds = model(inputs) 113 | 114 | preds = denormalize(preds.squeeze(0).squeeze(0)) 115 | labels = denormalize(labels.squeeze(0).squeeze(0)) 116 | 117 | epoch_psnr.update(PSNR(preds, labels, shave_border=args.eval_scale), len(inputs)) 118 | 119 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) 120 | 121 | if epoch_psnr.avg > best_psnr: 122 | best_epoch = epoch 123 | best_psnr = epoch_psnr.avg 124 | best_weights = copy.deepcopy(model.state_dict()) 125 | 126 | if args.eval_file is not None: 127 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) 128 | torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth')) 129 | --------------------------------------------------------------------------------
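The pieces above can also be combined to super-resolve a single image. The following is a minimal, hypothetical sketch (not a file in this repository); the weight and image paths are placeholders, and the model operates on the Y channel only, as in `eval.py`:

```python
import numpy as np
import PIL.Image as pil_image
import torch

from models import DRRN
from utils import (load_image, modcrop, generate_lr, image_to_array,
                   rgb_to_y, normalize, denormalize, load_weights)

scale = 2
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = DRRN(B=1, U=9, num_features=128).to(device)
model = load_weights(model, 'BLAH_BLAH/DRRN_B1U9/best.pth')  # placeholder path
model.eval()

# Degrade an HR image the same way generate_testset.py does.
hr = modcrop(load_image('BLAH_BLAH/butterfly.png'), scale)  # placeholder path
lr = generate_lr(hr, scale)

# Normalized Y-channel tensor of shape (1, 1, H, W).
y = normalize(rgb_to_y(image_to_array(lr).astype(np.float32), 'chw'))
x = torch.from_numpy(y).unsqueeze(0).unsqueeze(0).to(device)

with torch.no_grad():
    preds = model(x)

# Back to [0, 255] and save the super-resolved Y channel as a grayscale image.
sr_y = denormalize(preds.squeeze(0).squeeze(0)).cpu().numpy().astype(np.uint8)
pil_image.fromarray(sr_y).save('butterfly_sr_y.png')
```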