├── README.md
├── datasets
│   ├── DIV2K
│   │   └── DIV2K_train_HR
│   │       └── 0002.png
│   └── benchmark
│       ├── BSD100.csv
│       ├── Set14.csv
│       ├── Set5.csv
│       └── Urban100.csv
├── docs
│   ├── references.md
│   └── todo_list.md
├── exp
│   └── TBD.txt
├── notebooks
│   ├── figs
│   │   ├── Thumbs.db
│   │   ├── self_aug.PNG
│   │   ├── test1.png
│   │   └── test2.png
│   └── implementation.ipynb
└── src
    ├── __init__.py
    ├── data.py
    ├── eval.py
    ├── main.py
    ├── models
    │   ├── RCAN.py
    │   ├── RDN.py
    │   └── __pycache__
    │       └── RCAN.cpython-36.pyc
    ├── pytorch_ssim.py
    ├── trainer.py
    ├── try_test.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
# 2018_Deep_Learning_HW-Image Super Resolution

A simplified PyTorch implementation of **RDN** and **RCAN**
* [RDN](https://arxiv.org/pdf/1802.08797.pdf): Residual Dense Network for Image Super-Resolution
* [RCAN](https://arxiv.org/pdf/1807.02758.pdf): Image Super-Resolution Using Very Deep Residual Channel Attention Networks

# Contents
* Requirements
* Preparing Dataset and Benchmark
* Training (see the sketch below)
* Testing
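# Quick Start (sketch)
A hypothetical invocation, assuming DIV2K has been unpacked under
`datasets/DIV2K` with the folder names `main.py` expects
(`DIV2K_train_HR`, `DIV2K_train_LR_bicubic/X2`, ...):

    cd src
    python main.py --scale 2 --epochs 6000 --batch_size 16

All flags above exist in `src/main.py`; the paths and working directory are
assumptions, not tested commands.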
# Reference
* Official Implementations:
    * RDN: https://github.com/yulunzhang/RDN
    * RCAN: https://github.com/yulunzhang/RCAN
--------------------------------------------------------------------------------
/datasets/DIV2K/DIV2K_train_HR/0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/datasets/DIV2K/DIV2K_train_HR/0002.png
--------------------------------------------------------------------------------
/datasets/benchmark/BSD100.csv:
--------------------------------------------------------------------------------
method,psnr,ssim
bicubic,29.55380220639785,0.848024091720581
nearest,28.42105636021033,0.830420406460762
glasner,30.258145626885298,0.8653802001476287
Kim,31.10990343850525,0.887862583398819
SelfExSR,31.190067561532683,0.8893164145946503
SRCNN,31.05934739250512,0.8859882575273513
--------------------------------------------------------------------------------
/datasets/benchmark/Set14.csv:
--------------------------------------------------------------------------------
method,psnr,ssim
bicubic,30.07357190411235,0.8726970085075924
nearest,28.381406196315304,0.8496315479278564
glasner,31.233853426966785,0.8909385459763663
Kim,31.966101181520763,0.9062613504273551
SelfExSR,32.32999276754898,0.9086319761616843
SRCNN,31.80858995201594,0.9044969422476632
--------------------------------------------------------------------------------
/datasets/benchmark/Set5.csv:
--------------------------------------------------------------------------------
method,psnr,ssim
bicubic,33.6474676499156,0.9323098421096802
nearest,30.875881877963174,0.903334641456604
glasner,35.34632905578892,0.9459365963935852
Kim,36.1812524258189,0.9539006471633911
SelfExSR,36.374930673834484,0.9556310057640076
SRCNN,36.142361508203535,0.9515218615531922
--------------------------------------------------------------------------------
/datasets/benchmark/Urban100.csv:
--------------------------------------------------------------------------------
method,psnr,ssim
bicubic,26.650316488983684,0.8456345623731614
nearest,25.29819372382894,0.8212874639034271
glasner,27.822410986621108,0.873549565076828
Kim,28.716277045591163,0.8971735113859176
SelfExSR,29.37382386316921,0.9060872340202332
SRCNN,28.59192138166379,0.8930709475278854
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
# References
## NTIRE 2018
NTIRE 2018 Challenge on Single Image Super-Resolution: Methods and Results:
http://www.vision.ee.ethz.ch/~timofter/publications/NTIRE2018_SR_report_CVPRW-2018.pdf

## Coding
References for coding:
* RDN: https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
* SRGAN: https://github.com/leftthomas/SRGAN
* SSIM: https://github.com/Po-Hsun-Su/pytorch-ssim

-----

## Datasets
Sources of datasets:
* [DIV2K](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar)
* [Flickr2K](https://github.com/LimBee/NTIRE2017/issues/25)

-----

## Implementations
Other good implementations:
* https://github.com/togheppi/pytorch-super-resolution-model-collection
* https://github.com/yulunzhang/RCAN
* https://github.com/thstkdgus35/EDSR-PyTorch

-----

## Benchmarks
The following table is from: https://github.com/jbhuang0604/SelfExSR

| Dataset | Image source | Download full results |
|:-------:|:------------:|:---------------------:|
| **Set 5** | [Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html) | [link](https://uofi.box.com/shared/static/kfahv87nfe8ax910l85dksyl2q212voc.zip) (16.1 MB) |
| **Set 14** | [Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests) | [link](https://uofi.box.com/shared/static/igsnfieh4lz68l926l8xbklwsnnk8we9.zip) (86.0 MB) |
| **Urban 100** | [Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr) | [link](https://uofi.box.com/shared/static/65upg43jjd0a4cwsiqgl6o6ixube6klm.zip) (1.14 GB) |
| **BSD 100** | [Martin et al. ICCV 2001](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) | [link](https://uofi.box.com/shared/static/qgctsplb8txrksm9to9x01zfa4m61ngq.zip) (568 MB) |
| **Sun-Hays 80** | [Sun and Hays ICCP 2012](http://cs.brown.edu/~lbsun/SRproj2012/SR_iccp2012.html) | [link](https://uofi.box.com/shared/static/rirohj4773jl7ef752r330rtqw23djt8.zip) (311 MB) |

For a curated list, see: https://github.com/huangzehao/Super-Resolution.Benckmark

-----

## PyTorch
* PyTorch padding issue:
  https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338

* Adam with learning rate decay (sketch below):
    * https://github.com/XifengGuo/CapsNet-Keras/issues/9
    * https://www.reddit.com/r/MachineLearning/comments/5slcyi/d_mixing_learning_rate_decay_and_adam_is_it/
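A minimal sketch of the combination this project actually uses (see
`src/trainer.py`): Adam plus `StepLR` decay. The `model` below is a
placeholder, not the project's `SRNet`.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)  # placeholder for any nn.Module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

for epoch in range(3000):
    # ... run one training epoch ...
    scheduler.step()  # halves the learning rate every 1000 epochs
```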
--------------------------------------------------------------------------------
/docs/todo_list.md:
--------------------------------------------------------------------------------
* script for creating datasets
* patch testing
* single image
* video
--------------------------------------------------------------------------------
/exp/TBD.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/exp/TBD.txt
--------------------------------------------------------------------------------
/notebooks/figs/Thumbs.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/notebooks/figs/Thumbs.db
--------------------------------------------------------------------------------
/notebooks/figs/self_aug.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/notebooks/figs/self_aug.PNG
--------------------------------------------------------------------------------
/notebooks/figs/test1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/notebooks/figs/test1.png
--------------------------------------------------------------------------------
/notebooks/figs/test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/notebooks/figs/test2.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/src/__init__.py
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image
from utils import traverse_dir, get_patch, augment_random

from torchvision.transforms import ToTensor
from torch.utils.data import Dataset


# meta class for dataset classes
class MetaSet(Dataset):
    """
    Meta class for the various dataset sources.

    Usage:
        Subclasses must define the member variables set in __init__.

        Depending on the stage, 'file_list_hr' holds:
            training stage: a list of image paths
            validation stage: a list of image paths
            testing stage: a list of None

        Note that regardless of the stage, the length of 'file_list_hr'
        should equal that of 'file_list_lr'.
    """
    def __init__(self):
        # basic
        self.args = None
        self.file_list = None
        self.file_list_hr = None
        self.file_list_lr = None
        self.num_train = None

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        file_hr = self.file_list_hr[idx]
        file_lr = self.file_list_lr[idx]

        pil_hr = Image.open(file_hr).convert('RGB')
        pil_lr = Image.open(file_lr).convert('RGB')

        if self.args.need_patch:
            img_hr = np.array(pil_hr)
            img_lr = np.array(pil_lr)
            img_lr_patch, img_hr_patch = get_patch(img_lr, img_hr, self.args.patch_size, self.scale)
            img_lr_patch, img_hr_patch = augment_random(img_lr_patch, img_hr_patch)
            pil_hr = Image.fromarray(img_hr_patch, "RGB")
            pil_lr = Image.fromarray(img_lr_patch, "RGB")

        hr_tensor = ToTensor()(pil_hr)
        lr_tensor = ToTensor()(pil_lr)
        return filename, lr_tensor, hr_tensor

    def __len__(self):
        return self.num_train

    def append(self, new_set):
        self.file_list += new_set.file_list
        self.file_list_hr += new_set.file_list_hr
        self.file_list_lr += new_set.file_list_lr
        self.num_train += new_set.num_train


class BenchmarkSet(MetaSet):
    def __init__(self, args, dir_benchmark):
        # basic
        self.args = args
        self.scale = args.scale
        self.dir_benchmark = dir_benchmark

        # list arrangement
        self.file_list_hr = traverse_dir(dir_benchmark, extension=('png'), str_='HR', is_sort=True)
        self.file_list_lr = [f.replace('HR', 'LR') for f in self.file_list_hr]
        self.file_list = [f.replace('HR', 'new')[22:] for f in self.file_list_hr]

        # num of train
        self.num_train = len(self.file_list)


class DIV2KDataSet(MetaSet):
    def __init__(self, args, scale, data_dir, dir_hr, dir_lr):
        # basic
        self.args = args
        self.scale = scale

        # src dir
        self.dirHR = os.path.join(data_dir, dir_hr)
        self.dirLR = os.path.join(data_dir, dir_lr, 'X'+str(self.scale))
        self.file_list = traverse_dir(self.dirHR, extension=('png'), is_pure=True, is_ext=False)

        # re-arrange filename
        self.file_list_hr = [os.path.join(self.dirHR, f + '.png') for f in self.file_list]
        self.file_list_lr = [os.path.join(self.dirLR, f + 'x'+str(self.scale)+'.png') for f in self.file_list]
        self.file_list = [f + '.png' for f in self.file_list]  # for saving

        # num of train
        self.num_train = len(self.file_list)
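
# Example usage (sketch): 'Namespace' stands in for the argparse args built
# in main.py, and the paths assume DIV2K is unpacked as in the README.
#
#   from argparse import Namespace
#   args = Namespace(need_patch=True, patch_size=196, scale=2)
#   train_set = DIV2KDataSet(args, args.scale, '../datasets/DIV2K',
#                            'DIV2K_train_HR', 'DIV2K_train_LR_bicubic')
#   filename, lr, hr = train_set[0]
#   # hr: 3 x 196 x 196 patch; lr: 3 x 98 x 98 (patch_size // scale)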
| """ 13 | Meta Class for various src 14 | 15 | Usage: 16 | Must define four vars in the __init__ 17 | 18 | Depending on 'file_list_hr': 19 | traing stage: list of img 20 | validation stage: list of img 21 | testing stage: list of None 22 | 23 | Note that no matter what the stage is, the length of 'file_list_hr' 24 | should be the same as 'file_list_lr' 25 | """ 26 | def __init__(self): 27 | # basic 28 | self.args = None 29 | self.file_list = None 30 | self.file_list_hr = None 31 | self.file_list_lr = None 32 | self.num_train = None 33 | 34 | def __getitem__(self, idx): 35 | filename = self.file_list[idx] 36 | file_hr = self.file_list_hr[idx] 37 | file_lr = self.file_list_lr[idx] 38 | 39 | pil_hr = Image.open(file_hr).convert('RGB') 40 | pil_lr = Image.open(file_lr).convert('RGB') 41 | 42 | if self.args.need_patch: 43 | img_hr = np.array(pil_hr) 44 | img_lr = np.array(pil_lr) 45 | img_lr_patch, img_hr_patch = get_patch(img_lr, img_hr, self.args.patch_size, self.scale) 46 | img_lr_patch, img_hr_patch = augment_random(img_lr_patch, img_hr_patch) 47 | pil_hr = Image.fromarray(img_hr_patch, "RGB") 48 | pil_lr = Image.fromarray(img_lr_patch, "RGB") 49 | 50 | hr_tensor = ToTensor()(pil_hr) 51 | lr_tensor = ToTensor()(pil_lr) 52 | return filename, lr_tensor, hr_tensor 53 | 54 | def __len__(self): 55 | return self.num_train 56 | 57 | def append(self, new_set): 58 | self.file_list += new_set.file_list 59 | self.file_list_hr += new_set.file_list_hr 60 | self.file_list_lr += new_set.file_list_lr 61 | self.num_train += new_set.num_train 62 | 63 | 64 | class BenchmarkSet(MetaSet): 65 | def __init__(self, args, dir_benchmark): 66 | # basic 67 | self.args = args 68 | self.scale = args.scale 69 | self.dir_benchmark = dir_benchmark 70 | 71 | # list arrangement 72 | self.file_list_hr = traverse_dir(dir_benchmark, extension=('png'), str_='HR', is_sort=True) 73 | self.file_list_lr = [f.replace('HR', 'LR') for f in self.file_list_hr] 74 | self.file_list = [f.replace('HR', 'new')[22:] for f in self.file_list_hr] 75 | 76 | # num of train 77 | self.num_train = len(self.file_list) 78 | 79 | 80 | class DIV2KDataSet(MetaSet): 81 | def __init__(self, args, scale, data_dir, dir_hr, dir_lr): 82 | # basic 83 | self.args = args 84 | self.scale = scale 85 | 86 | # src dir 87 | self.dirHR = os.path.join(data_dir, dir_hr) 88 | self.dirLR = os.path.join(data_dir, dir_lr, 'X'+str(self.scale)) 89 | self.file_list = traverse_dir(self.dirHR, extension=('png'), is_pure=True, is_ext=False) 90 | 91 | # re-arrange filename 92 | self.file_list_hr = [os.path.join(self.dirHR, f + '.png') for f in self.file_list] 93 | self.file_list_lr = [os.path.join(self.dirLR, f + 'x'+str(self.scale)+'.png') for f in self.file_list] 94 | self.file_list = [f + '.png' for f in self.file_list] # for saving 95 | 96 | # num of train 97 | self.num_train = len(self.file_list) 98 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_ssim 4 | import pandas as pd 5 | import numpy as np 6 | from PIL import Image 7 | from math import log10 8 | from utils import traverse_dir, rgb2y_channel 9 | from torch.autograd import Variable 10 | from torchvision.transforms import ToTensor 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | 14 | class EvalDataset(Dataset): 15 | """ 16 | Class for evaluation datasets 17 | 18 | Attributes 19 | ---------- 20 | zip_list : list of tuples 21 | 

def calc_score(img1, img2, crop=0):
    """
    Calculate PSNR and SSIM
    """
    with torch.no_grad():
        img1 = Variable(img1)
        img2 = Variable(img2)
        if torch.cuda.is_available():
            img1 = img1.cuda()
            img2 = img2.cuda()
        psnr = calc_psnr(img1, img2)
        ssim = calc_ssim(img1, img2)
        return psnr, ssim


def preproc(filename):
    """
    Convert to YCbCr and return a PIL image holding only the Y channel
    """
    img_pil = Image.open(filename).convert('RGB')
    img_Y = rgb2y_channel(np.array(img_pil))
    img_Y_pil = Image.fromarray(np.uint8(img_Y), "L")
    return img_Y_pil


def eval_image_file(img_file1, img_file2):
    """
    Example function.
    Evaluate two images. The inputs are filenames
    """
    img1_pil = preproc(img_file1)
    img2_pil = preproc(img_file2)
    img1_tensor = torch.unsqueeze(ToTensor()(img1_pil), 0)
    img2_tensor = torch.unsqueeze(ToTensor()(img2_pil), 0)
    return calc_score(img1_tensor, img2_tensor)
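
# Usage (sketch): the paths are placeholders for an HR ground truth and a
# super-resolved output of the same size.
#
#   psnr, ssim = eval_image_file('gt.png', 'sr_x2.png')
#   print(psnr, ssim)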

def eval_filelist(filelist1, filelist2, verbose=True):
    """
    Evaluate a pair of file lists
    """
    if len(filelist1) != len(filelist2):
        raise ValueError('the amounts of files should be equal')

    score_info = {'psnr': [], 'ssim': []}
    zip_list = list(zip(filelist1, filelist2))
    eval_set = EvalDataset(zip_list)
    eval_loader = DataLoader(dataset=eval_set, num_workers=4, batch_size=1, shuffle=False)
    for idx, (fn1, fn2, img1, img2) in enumerate(eval_loader):
        psnr, ssim = calc_score(img1, img2)
        if verbose:
            print('[%d] - %s | %s' %
                  (idx, os.path.basename(fn1[0]), os.path.basename(fn2[0])))
            print('    psnr: %.6f | ssim: %.6f\n' % (psnr, ssim))
        score_info['psnr'].append(psnr)
        score_info['ssim'].append(ssim)
    return score_info


def eval_benchmark(dir_benchmark, scale=2, verbose=False, dir_='./'):
    """
    Run the benchmarks.
    Source: https://github.com/jbhuang0604/SelfExSR
    [Warning] The results are slightly lower than the original MATLAB implementation
    """
    if not os.path.exists(dir_):
        os.makedirs(dir_)

    # define benchmarks and methods
    benchmarks = ['Set5', 'Set14', 'BSD100', 'Urban100']
    methods = ['bicubic', 'nearest', 'glasner', 'Kim', 'SelfExSR', 'SRCNN']

    # scale
    scale_str = 'image_SRF_' + str(scale)

    # start eval
    print('{:=^40}'.format(' running benchmarks '))
    for benchmark in benchmarks:
        print('[%s]' % benchmark)
        root = os.path.join(dir_benchmark, benchmark, scale_str)
        method_result = {'psnr': [], 'ssim': []}

        for method in methods:
            print(' - %s' % method)

            # get list of files
            filelist_hr = traverse_dir(root, extension=('png'), str_='HR', is_sort=True)
            filelist_sr = [f.replace('HR', method) for f in filelist_hr]

            # eval
            score_info = eval_filelist(filelist_hr, filelist_sr, verbose=verbose)

            method_result['psnr'].append(np.mean(score_info['psnr']))
            method_result['ssim'].append(np.mean(score_info['ssim']))

        # save result
        data_frame = pd.DataFrame(method_result, index=methods)
        data_frame.to_csv(os.path.join(dir_, benchmark+'.csv'), index_label='method')
    print('{:=^40}'.format(' Done!!! '))


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '2'
    eval_benchmark('datasets/benchmark', dir_='datasets/benchmark')
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import trainer
from data import DIV2KDataSet
# from models.RDN import SRNet
from models.RCAN import SRNet

parser = argparse.ArgumentParser(description='Semantic aware super-resolution')

# ----- Global Setting -----
GPU = 1
EPOCH = 6000
P = 196
STEP = 1000  # for optimizer scheduler
nohup = '../exp'
EXP_DIR = '../debug'

# train
parser.add_argument('--gpu', type=int, default=GPU, help='gpu')
parser.add_argument('--epochs', type=int, default=EPOCH, help='number of epochs')
parser.add_argument('--is_finetuning', default=False, help='finetune the model')
parser.add_argument('--step_print_loss', type=int, default=100, help='steps between loss printouts')
parser.add_argument('--step_save', type=int, default=800, help='steps between checkpoint saves')

parser.add_argument('--num_threads', type=int, default=4, help='number of threads for data loading')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')

parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--scheduler_step_size', type=float, default=STEP, help='period of learning rate decay')
parser.add_argument('--scheduler_gamma', type=float, default=0.5, help='decay ratio')

# data
parser.add_argument('--data_dir', default='../datasets/DIV2K', help='dataset directory')
parser.add_argument('--benchmark_dir', default='../datasets/benchmark', help='benchmark directory')
parser.add_argument('--need_patch', default=True, help='get patch from image')
parser.add_argument('--patch_size', type=int, default=P, help='patch size (P)')
parser.add_argument('--scale', type=int, default=2, help='scale')
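
# Note (sketch): '--is_finetuning' and '--need_patch' are declared without
# type=..., so any non-empty string passed on the command line (even
# "False") parses as truthy. A stricter alternative for the boolean flag:
#
#   parser.add_argument('--is_finetuning', action='store_true',
#                       help='finetune the model')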

# ----- model (RDN) -----
# C = 6
# G = 32
# D = 10
# F = 64
# EXP_NAME = '[%s]_GPU_%d---EPOCH_%d-P_%d---C_%d-G_%d-D_%d-F_%d' % (nohup, GPU, EPOCH, P, C, G, D, F)
# parser.add_argument('--num_dense', type=int, default=C, help='number of conv layers in RDB (C)')
# parser.add_argument('--growth_rate', type=int, default=G, help='growth rate of dense net (G)')
# parser.add_argument('--num_RDB', type=int, default=D, help='number of RDB blocks (D)')
# parser.add_argument('--num_feat', type=int, default=F, help='number of conv feature maps (F)')
# parser.add_argument('--num_channel', type=int, default=3, help='number of color channels to use')

# ----- model (RCAN) -----
F = 64
NB = 12
NG = 6
EXP_NAME = '[%s]_GPU_%s---EPOCH_%d-P_%d---F_%d-NB_%d-NG_%d' % (nohup, str(GPU), EPOCH, P, F, NB, NG)
parser.add_argument('--num_feat', type=int, default=F, help='number of conv feature maps (F)')
parser.add_argument('--num_channel', type=int, default=3, help='number of color channels to use')
parser.add_argument('--n_resblocks', type=int, default=NB, help='number of residual blocks (NB)')
parser.add_argument('--n_resgroups', type=int, default=NG, help='number of residual groups (NG)')
parser.add_argument('--reduction', type=int, default=16, help='channel reduction ratio in the CA layer')

# Saver
parser.add_argument('--save_dir', default=EXP_DIR, help='save directory')
parser.add_argument('--exp_name', default=EXP_NAME, help='experiment name for saving results')

args = parser.parse_args()


def main():
    # prepare datasets
    train_set = DIV2KDataSet(args, args.scale, args.data_dir, 'DIV2K_train_HR', 'DIV2K_train_LR_bicubic')
    test_set = DIV2KDataSet(args, args.scale, args.data_dir, 'DIV2K_valid_HR', 'DIV2K_valid_LR_bicubic')

    # append datasets
    # train_set_flickr2k = DIV2KDataSet(args, args.scale, '../datasets/Flickr2K',
    #                                   'Flickr2K_HR', 'Flickr2K_LR_bicubic')
    # train_set.append(train_set_flickr2k)
    print(train_set.__len__())

    # model
    my_model = SRNet(args)

    if isinstance(args.gpu, list) and len(args.gpu) > 1:
        import torch.nn as nn
        my_model = nn.DataParallel(my_model, args.gpu)
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    trainer.network_paras(my_model)

    # restore for testing/fine-tuning
    if args.is_finetuning:
        my_model = trainer.restore(args, my_model)

    # train
    trainer.train(args, my_model, train_set)

    # run benchmark
    trainer.run_benchmark(args, my_model, args.benchmark_dir)

    # test
    trainer.test(args, my_model, test_set)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/models/RCAN.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def conv(in_channels, out_channels, kernel_size=3, bias=True):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=(kernel_size-1)//2,  # same padding
        bias=bias)


# Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature map --> one value per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # feature channel downscale and upscale --> per-channel weight
        self.conv_du = nn.Sequential(
            conv(channel, channel // reduction, 1),
            nn.ReLU(inplace=True),
            conv(channel // reduction, channel, 1),
            nn.Sigmoid())

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(self, args):
        super(RCAB, self).__init__()
        # init
        self.num_feat = args.num_feat
        self.reduction = args.reduction

        # body
        modules_body = [
            conv(self.num_feat, self.num_feat),
            nn.ReLU(True),
            conv(self.num_feat, self.num_feat),
            CALayer(self.num_feat, self.reduction)
        ]
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


# Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, args):
        super(ResidualGroup, self).__init__()
        # init
        self.n_resblocks = args.n_resblocks
        self.num_feat = args.num_feat

        # body
        modules_body = [RCAB(args) for _ in range(self.n_resblocks)]
        modules_body.append(conv(self.num_feat, self.num_feat))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


# Residual Channel Attention Network (RCAN)
class SRNet(nn.Module):
    def __init__(self, args):
        super(SRNet, self).__init__()
        # init
        self.num_channel = args.num_channel
        self.num_feat = args.num_feat
        self.scale = args.scale
        self.n_resgroups = args.n_resgroups

        # shallow feature extraction
        self.shallow = conv(self.num_channel, self.num_feat)

        # residual in residual (RIR)
        modules_body = [ResidualGroup(args) for _ in range(self.n_resgroups)]
        modules_body.append(conv(self.num_feat, self.num_feat))
        self.body = nn.Sequential(*modules_body)

        # Upsampler
        self.conv_up = conv(self.num_feat, self.num_feat*self.scale*self.scale)
        self.upsample = nn.PixelShuffle(self.scale)
        self.conv_out = conv(self.num_feat, self.num_channel)

    def forward(self, x):
        # shallow feature
        x = self.shallow(x)

        # RIR
        res = self.body(x)

        # long skip connection
        res += x

        # upsample
        up = self.conv_up(res)
        up = self.upsample(up)
        x = self.conv_out(up)

        return x
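
# Shape sanity check (sketch): channel attention rescales activations but
# never changes their shape.
#
#   import torch
#   ca = CALayer(channel=64, reduction=16)
#   x = torch.randn(1, 64, 48, 48)
#   assert ca(x).shape == x.shape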
--------------------------------------------------------------------------------
/src/models/RDN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super(BasicBlock, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, bias=False)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat((x, out), 1)  # <- build the local dense connection by 'concat' (important!!)
        # print(out.size())  # print for sanity check
        return out


# Residual Dense Block
class ResDenseBlock(nn.Module):
    '''
    good explanation of how it works:
    https://discuss.pytorch.org/t/resolved-how-to-understand-the-densenet-implementation/3964/4
    '''
    def __init__(self, nChannels, nDenselayer, growthRate):
        super(ResDenseBlock, self).__init__()
        nChannels_ = nChannels
        modules = []
        for i in range(nDenselayer):
            modules.append(BasicBlock(nChannels_, growthRate))
            nChannels_ += growthRate
        self.dense_layers = nn.Sequential(*modules)  # <- key step: each block concatenates its input, so channels accumulate
        self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        out = out + x
        return out


# Residual Dense Network
class SRNet(nn.Module):
    def __init__(self, args):
        super(SRNet, self).__init__()
        self.num_channel = args.num_channel
        self.num_dense = args.num_dense
        self.num_feat = args.num_feat
        self.num_RDB = args.num_RDB
        self.scale = args.scale
        self.growth_rate = args.growth_rate

        # F-1
        self.conv1 = nn.Conv2d(self.num_channel, self.num_feat, kernel_size=3, padding=1, bias=True)
        # F0
        self.conv2 = nn.Conv2d(self.num_feat, self.num_feat, kernel_size=3, padding=1, bias=True)

        # RDBs
        self.RDBs = []
        for i in range(self.num_RDB):
            tmp_RDB = ResDenseBlock(self.num_feat, self.num_dense, self.growth_rate)
            setattr(self, 'RDB%i' % i, tmp_RDB)  # register the layer on the Module so its parameters are tracked
            self.RDBs.append(tmp_RDB)            # (nn.ModuleList would be the idiomatic alternative)

        # global feature fusion (GFF)
        self.GFF_1x1 = nn.Conv2d(self.num_feat*self.num_RDB, self.num_feat, kernel_size=1, padding=0, bias=True)
        self.GFF_3x3 = nn.Conv2d(self.num_feat, self.num_feat, kernel_size=3, padding=1, bias=True)

        # Upsampler
        self.conv_up = nn.Conv2d(self.num_feat, self.num_feat*self.scale*self.scale,
                                 kernel_size=3, padding=1, bias=True)
        self.upsample = nn.PixelShuffle(self.scale)
        self.conv3 = nn.Conv2d(self.num_feat, self.num_channel, kernel_size=3, padding=1, bias=True)

    def forward(self, x):

        # shallow feature extraction
        F_ = self.conv1(x)
        F_N = self.conv2(F_)

        # RDBs
        F_s = []
        for i in range(self.num_RDB):
            F_N = self.RDBs[i](F_N)
            F_s.append(F_N)

        # GFF
        FF = torch.cat(F_s, 1)
        FdLF = self.GFF_1x1(FF)
        FGF = self.GFF_3x3(FdLF)

        # DFF
        FDF = FGF + F_

        # upscale
        up = self.conv_up(FDF)
        up = self.upsample(up)
        output = self.conv3(up)

        return output
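
# Channel bookkeeping (sketch): inside ResDenseBlock, BasicBlock i receives
# nChannels + i * growthRate channels, so with nChannels=64, nDenselayer=6,
# growthRate=32 the final concatenation carries 64 + 6*32 = 256 channels,
# which conv_1x1 fuses back down to 64:
#
#   import torch
#   rdb = ResDenseBlock(nChannels=64, nDenselayer=6, growthRate=32)
#   x = torch.randn(1, 64, 32, 32)
#   assert rdb(x).shape == x.shape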
--------------------------------------------------------------------------------
/src/models/__pycache__/RCAN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wayne391/2018_Deep_Learning_HW-Image_SR/95e82e7ad1187552c03826462087e4381fe56ec3/src/models/__pycache__/RCAN.cpython-36.pyc
--------------------------------------------------------------------------------
/src/pytorch_ssim.py:
--------------------------------------------------------------------------------
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
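
# Usage (sketch): identical images score 1.0, unrelated noise scores near 0.
#
#   a = torch.rand(1, 3, 64, 64)
#   print(ssim(a, a).item())                   # -> 1.0
#   print(ssim(a, torch.rand_like(a)).item())  # -> small value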
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
import os
import time
import datetime
import numpy as np
from utils import Saver
from eval import calc_score

import torch
from torch.nn import init
from torch.autograd import Variable
from torchvision.transforms import ToPILImage


def backup_codes(args):
    import shutil
    import glob
    out_dir = os.path.join(args.save_dir, args.exp_name, 'codes')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    pyfiles = glob.glob("./*.py")
    for pf in pyfiles:
        shutil.copy2(pf, out_dir)


def restore(args, model):
    # load
    saver = Saver(args)
    return saver.load(model)


def print_lr(optimizer):
    for param_group in optimizer.param_groups:
        print(param_group['lr'])


def network_paras(model):
    # count only trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


def run_benchmark(args, model, dir_benchmark, out_dir='default', is_compare=True):
    import pandas as pd
    from data import BenchmarkSet

    # mkdir
    if out_dir == 'default':
        out_dir = os.path.join(args.save_dir, args.exp_name, 'benchmark_result')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # misc
    benchmarks = ['Set5', 'Set14', 'BSD100', 'Urban100']
    scale_str = 'image_SRF_' + str(args.scale)
    result = {'psnr': [], 'ssim': []}

    # start running
    print('{:=^40}'.format(' testing benchmarks '))
    for benchmark in benchmarks:
        print('[%s]' % benchmark)

        # folder arrangement
        root = os.path.join(dir_benchmark, benchmark, scale_str)
        save_dir = os.path.join(out_dir, root[len(dir_benchmark)+1:])
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # get set
        benchmark_set = BenchmarkSet(args, root)

        # testing
        psnr, ssim = test(args, model, benchmark_set, out_dir=out_dir)

        # append
        result['psnr'].append(psnr)
        result['ssim'].append(ssim)

        if is_compare:
            data_frame = pd.read_csv(os.path.join(dir_benchmark, benchmark+'.csv'))
            new_row = pd.Series({'psnr': psnr, 'ssim': ssim, 'method': args.exp_name}, name='new')
            data_frame = data_frame.append(new_row)
            data_frame.to_csv(os.path.join(out_dir, benchmark+'.csv'))


def test(args, model, test_set, out_dir='default', is_save=True):
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=int(args.num_threads))

    # mkdir
    if out_dir == 'default':
        out_dir = os.path.join(args.save_dir, args.exp_name, 'testing_result')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # misc
    total_psnr = 0.0
    total_ssim = 0.0
    flag_loss = False
    num_img = test_set.__len__()

    # ensure whole-image testing and eval mode
    args.need_patch = False
    model.cuda()
    model.eval()

    # start testing
    print('{:=^40}'.format(' testing start '))
    time_start = time.time()
    with torch.no_grad():
        for idx, (fn, im_lr, im_hr) in enumerate(test_loader):
            # forward
            im_lr = Variable(im_lr.cuda())
            output = model(im_lr)

            # clip value range to [0.0, 1.0]
            output = torch.clamp(output, min=0.0, max=1.0).cpu()

            # to PIL
            pil = ToPILImage()(torch.squeeze(output, 0))
            pil_lr = ToPILImage()(torch.squeeze(im_lr.cpu(), 0))
            print('(%d/%d) size: %s -> %s' % (
                idx,
                num_img,
                'x'.join(map(str, list(pil_lr.size))),
                'x'.join(map(str, list(pil.size)))))

            # save PIL
            if out_dir is not None:
                out_path = os.path.join(out_dir, fn[0])
                print(' => %s' % out_path)
                pil.save(out_path)

            # compute loss
            if im_hr is not None:
                flag_loss = True
                im_hr = Variable(im_hr.cuda())
                psnr, ssim = calc_score(output, im_hr)
                total_psnr += psnr
                total_ssim += ssim
                print('    psnr: %.5f, ssim: %.5f\n' % (psnr, ssim))

    print('{:=^40}'.format(' Finish '))
    runtime = time.time() - time_start
    print('testing time:', str(datetime.timedelta(seconds=runtime))+'\n')

    # record results
    if flag_loss:
        psnr_mean = total_psnr/num_img
        ssim_mean = total_ssim/num_img
        log = 'psnr: %.6f\nssim: %.6f\n' % (psnr_mean, ssim_mean)
        print(log)
        if is_save:
            log_file = os.path.join(out_dir, 'eval_result.txt')
            with open(log_file, "w") as text_file:
                text_file.write(log)
        return psnr_mean, ssim_mean
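
# Scheduler arithmetic (sketch): with the defaults in main.py
# (lr=1e-4, step_size=1000, gamma=0.5), StepLR yields
#
#   lr(epoch) = 1e-4 * 0.5 ** (epoch // 1000)
#
# e.g. epoch 2500 -> 1e-4 * 0.25 = 2.5e-5.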

def train(args, model, train_set):
    # to cuda
    model.cuda()
    model.train()

    # dataloader
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=int(args.num_threads))

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)

    # saver
    saver = Saver(args)

    # loss function
    criterion = torch.nn.L1Loss()

    # time
    time_start_train = time.time()

    # misc
    num_batch = train_set.__len__() // args.batch_size
    counter = 0
    backup_codes(args)

    # compute paras
    params = network_paras(model)
    log = "num of parameters: {:,}".format(params)
    saver.save_log(log)
    print(log)

    # init weights
    def weights_init(m):
        if isinstance(m, torch.nn.Conv2d):
            init.kaiming_normal_(m.weight.data)

    if not args.is_finetuning:
        model.apply(weights_init)

    # start training
    print('{:=^40}'.format(' training start '))
    for epoch in range(args.epochs):
        scheduler.step(epoch)
        running_loss = 0.0
        for bidx, (_, im_lr, im_hr) in enumerate(train_loader):
            im_lr = Variable(im_lr.cuda())
            im_hr = Variable(im_hr.cuda())

            # zero the parameter gradients
            model.zero_grad()

            # forward
            output = model(im_lr)

            # loss
            loss = criterion(output, im_hr)

            # backward & update
            loss.backward()
            optimizer.step()

            # accumulate running loss
            running_loss += loss.cpu().item()

            # print every N batches (the reported loss is the sum since the last printout)
            if counter % args.step_print_loss == 0:
                # time
                acc_time = time.time() - time_start_train

                # log
                log = 'epoch: (%d/%d) [%5d/%5d], loss: %.6f | time: %s' % \
                    (epoch, args.epochs, bidx, num_batch, running_loss, str(datetime.timedelta(seconds=acc_time)))

                print(log)
                saver.save_log(log)
                running_loss = 0.0

                print_lr(optimizer)

            if counter and counter % args.step_save == 0:
                # save
                saver.save_model(model)

            # counter increment
            counter += 1

    print('{:=^40}'.format(' Finish '))
    runtime = time.time() - time_start_train
    print('training time:', str(datetime.timedelta(seconds=runtime))+'\n\n')
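
# Fine-tuning sketch (mirrors the is_finetuning path in main.py): restore
# the latest checkpoint, then keep training.
#
#   model = restore(args, model)
#   train(args, model, train_set)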
--------------------------------------------------------------------------------
/src/try_test.py:
--------------------------------------------------------------------------------
import math


def set_lr(base_lr, epoch, step, gamma=0.5):
    # Reconstructed helper (it was missing from this scratch script):
    # step decay matching the StepLR settings in main.py.
    return base_lr * math.pow(gamma, epoch // step)


for epoch in range(0, 10000, 2000):
    print(set_lr(0.0001, epoch, 2000))
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import random
import numpy as np


class Saver(object):
    def __init__(self, args):
        self.args = args
        self.save_dir = os.path.join(args.save_dir, args.exp_name)

        # folder for saving
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # log
        log_file = os.path.join(self.save_dir, 'log.txt')
        if os.path.exists(log_file):
            self.logFile = open(log_file, 'a')
        else:
            self.logFile = open(log_file, 'w')

    def save_log(self, log):
        self.logFile.write(log + '\n')

    def save_model(self, model, name='model'):
        torch.save(
            model.state_dict(),
            os.path.join(self.save_dir, name+'_state.pt'))
        # torch.save(
        #     model,
        #     os.path.join(self.save_dir, name+'_obj.pt'))

    def load(self, model, name='model'):
        load_path = os.path.join(self.save_dir, name+'_state.pt')
        model.load_state_dict(torch.load(load_path))
        print(" [*] loaded from %s" % load_path)
        return model


def traverse_dir(
        root_dir,
        extension=('.jpg', 'png'),
        str_=None,
        is_pure=False,
        verbose=False,
        is_sort=False,
        is_ext=True):
    """
    Traverse root_dir recursively and collect the files that match
    'extension' (and, optionally, contain the substring 'str_')
    """
    if verbose:
        print('[*] Scanning...')
    file_list = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.endswith(extension):
                if str_ is not None:
                    if str_ not in file:
                        continue
                mix_path = os.path.join(root, file)
                pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
                if not is_ext:
                    ext = pure_path.split('.')[-1]
                    pure_path = pure_path[:-(len(ext)+1)]
                if verbose:
                    print(pure_path)
                file_list.append(pure_path)
    if verbose:
        print('Total: %d images' % len(file_list))
        print('Done!!!')
    if is_sort:
        file_list.sort()
    return file_list
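
# Example (sketch): collect the HR benchmark images, as BenchmarkSet does.
#
#   file_list_hr = traverse_dir('../datasets/benchmark/Set5/image_SRF_2',
#                               extension=('png'), str_='HR', is_sort=True)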

ycbcr_para = np.array(
    [[65.481, 128.553, 24.966],
     [-37.797, -74.203, 112.0],
     [112.0, -93.786, -18.214]])


def rgb2ycbcr(npy):
    """
    Customized RGB to YCbCr conversion
    (same coefficients as MATLAB's rgb2ycbcr)
    """
    shape = npy.shape
    if len(shape) == 3:
        npy = npy.reshape((shape[0] * shape[1], 3))
    ycbcr = np.dot(npy, ycbcr_para.transpose() / 255.)
    ycbcr[:, 0] += 16.
    ycbcr[:, 1:] += 128.
    return ycbcr.reshape(shape)


def rgb2y_channel(npy):
    """
    Convert RGB to the Y channel
    """
    npy_ycbcr = rgb2ycbcr(npy)
    npy_y = npy_ycbcr[:, :, 0]
    return npy_y
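
# Quick check (sketch): for a white pixel (255, 255, 255),
# Y = 16 + (65.481 + 128.553 + 24.966) = 235, and black maps to Y = 16,
# i.e. the BT.601 studio-swing range [16, 235].
#
#   print(rgb2y_channel(np.full((1, 1, 3), 255.0)))  # -> [[235.]]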

def get_patch(img_lr, img_hr, patch_size, scale):
    (ih, iw, c) = img_lr.shape
    tp = patch_size      # HR patch size
    ip = tp // scale     # LR patch size
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    (tx, ty) = (scale * ix, scale * iy)
    img_lr = img_lr[iy:iy + ip, ix:ix + ip, :]
    img_hr = img_hr[ty:ty + tp, tx:tx + tp, :]
    return img_lr, img_hr


# codes for data augmentation
PARA_LIST = [(0, 0), (0, 1), (0, 2), (0, 3),
             (1, 0), (1, 1), (1, 2), (1, 3)]

PARA_LIST_INV = [(0, 0), (0, 3), (0, 2), (0, 1),
                 (1, 0), (1, 1), (1, 2), (1, 3)]


def flip_vert(img):
    return img[::-1, :, :]


def flip_hori(img):
    return img[:, ::-1, :]


def rot(img, k=1):
    return np.rot90(img, k)


def augment_random(img1, img2):
    rand = random.randint(0, 7)
    p, k = PARA_LIST[rand]
    if p:
        img1 = flip_vert(img1)
        img2 = flip_vert(img2)
    return rot(img1, k), rot(img2, k)


'''
Self-ensemble (from EDSR):
First, generate all eight flip/rotation variants of the input with
'augment_all'. Then run the model on each variant, undo the transforms
with the inverse parameters in PARA_LIST_INV, and average the eight
aligned outputs ('self_ensemble').
'''


def augment_all(img):
    proc_list = []
    for p, k in PARA_LIST:
        if p:
            proc_img = flip_vert(img)
        else:
            proc_img = img[:, :, :]
        proc_list.append(rot(proc_img, k))
    return proc_list


def self_ensemble(aug_list):
    proc_list = []
    sum_all = np.zeros_like(aug_list[0], dtype=float)
    for idx, (p, k) in enumerate(PARA_LIST_INV):
        img = aug_list[idx]
        if p:
            proc_img = flip_vert(img)
        else:
            proc_img = img
        result = rot(proc_img, k)
        sum_all = np.add(sum_all, result)
        proc_list.append(result)
    return sum_all/8.0, proc_list
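
# Usage (sketch): 'model_fn' stands in for a trained network applied to a
# single HWC numpy image; self_ensemble inverts each transform before
# averaging, so the eight outputs align again.
#
#   aug_list = augment_all(img_lr)
#   sr_list = [model_fn(a) for a in aug_list]
#   sr_mean, _ = self_ensemble(sr_list)
--------------------------------------------------------------------------------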