├── .gitignore ├── README.md ├── dataset.py ├── edsr.py ├── generateTrainData.py └── srresnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SISRBNN 2 | Try to implement CVPR 2019 "Efficient Super Resolution Using Binarized Neural Network" by PyTorch. 3 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | -------------------------------------------------------------------------------- /edsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, num_features, scaling_factor=0.1): 7 | super(ResBlock, self).__init__() 8 | self.scaling_factor = scaling_factor 9 | self.feas = nn.Sequential( 10 | nn.Conv2d(num_features, num_features, 3, padding=1), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(num_features, num_features, 3, padding=1) 13 | ) 14 | 15 | def forward(self, x): 16 | return x + self.scaling_factor*self.feas(x) 17 | 18 | 19 | class PixShuffleBlock(nn.Module): 20 | def __init__(self, in_features, upscale_factor=2): 21 | super(PixShuffleBlock, self).__init__() 22 | self.conv = nn.Conv2d(in_features, in_features*upscale_factor*upscale_factor, 3, padding=1) 23 | self.up = nn.PixelShuffle(upscale_factor) 24 | 25 | def forward(self, x): 26 | fea = self.conv(x) 27 | fea = self.up(fea) 28 | return fea 29 | 30 | 31 | class EDSR(nn.Module): 32 | def __init__(self, num_resblocks, num_features, scaling_factor): 33 | super(EDSR, self).__init__() 34 | self.conv = nn.Conv2d(3, num_features, 3, padding=1) 35 | self.res_blocks = nn.Sequential( 36 | *[ResBlock(num_features) for i in range(num_resblocks)] 37 | ) 38 | self.up_sampler = PixShuffleBlock(num_features, upscale_factor=scaling_factor) 39 | self.output = nn.Conv2d(num_features, 3, 3, padding=1) 40 | 41 | def forward(self, x): 42 | fea = self.conv(x) 43 | fea = self.res_blocks(fea) 44 | fea = self.up_sampler(fea) 45 | fea = self.output(fea) 46 | return fea 47 | 48 | 49 | if __name__=="__main__": 50 | net = EDSR(32, 256, 4) 51 | net.cuda() 52 | lr = torch.rand(4,3,64,64).cuda() 53 | print(lr.shape) 54 | hr = net(lr) 55 | print(hr.shape) -------------------------------------------------------------------------------- /generateTrainData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import h5py 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | 9 | def normalize(img, mean_value): 10 | return img - mean_value 11 | 12 | def get_mean(root_path, file_names): 13 | all_mean_value = [0., 0., 0.] 14 | for i, file_name in enumerate(file_names): 15 | img = cv2.imread(os.path.join(root_path, file_name)) 16 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 17 | img = np.transpose(img, [2,0,1])/255. 18 | mean_value = img.mean(-1).mean(-1) 19 | print('img {}, mean value {}'.format(i, mean_value)) 20 | all_mean_value += mean_value 21 | print('all {} imgs, mean value {}'.format(len(file_names), all_mean_value/len(file_names))) 22 | return all_mean_value/len(file_names) 23 | 24 | def get_img(file_path, mean_value): 25 | img = cv2.imread(file_path) 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = np.transpose(img, [2,0,1])/255. 28 | img = normalize(img, np.array(mean_value)[:, np.newaxis, np.newaxis]) 29 | return img 30 | 31 | def generate_h5(hr_path, lr_path, scaling_factor, h5_name): 32 | hr_file_names = glob.glob(os.path.join(hr_path, '*.png')) 33 | lr_file_name_tem = '{}' + scaling_factor + '.png' 34 | # mean_value = get_mean(root_path, file_names) 35 | mean_value = [0.44845608, 0.43749626, 0.40452776] 36 | h5f = h5py.File(os.path.join(hr_path, h5_name), mode='w') 37 | for i, hr_file_name in enumerate(hr_file_names): 38 | hr = get_img(os.path.join(hr_path, hr_file_name), mean_value) 39 | lr_file_name = lr_file_name_tem.format(hr_file_name.split('/')[-1].split('.')[0]) 40 | print('here') 41 | print(os.path.join(lr_path, scaling_factor, lr_file_name)) 42 | lr = get_img(os.path.join(lr_path, scaling_factor.upper(), lr_file_name), mean_value) 43 | print(hr.shape, hr.max(), hr.mean(), hr.min()) 44 | print(lr.shape, lr.max(), lr.mean(), lr.min()) 45 | exit() 46 | 47 | 48 | if __name__=="__main__": 49 | hr_path = '/data0/langzhiqiang/DIV2K/DIV2K_train_HR/' 50 | lr_path = '/data0/langzhiqiang/DIV2K/DIV2K_train_LR_bicubic/' 51 | scaling_factor = 'x4' 52 | h5_name = 'train.h5' 53 | generate_h5(hr_path, lr_path, scaling_factor, h5_name) -------------------------------------------------------------------------------- /srresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | 6 | class ResBlock(nn.Module): 7 | def __init__(self, num_features): 8 | super(ResBlock, self).__init__() 9 | self.feas = nn.Sequential( 10 | nn.Conv2d(num_features, num_features, 3, padding=1), 11 | nn.BatchNorm2d(num_features), nn.PReLU(), 12 | nn.Conv2d(num_features, num_features, 3, padding=1), 13 | nn.BatchNorm2d(num_features)) 14 | 15 | def forward(self, x): 16 | return x + self.feas(x) 17 | 18 | 19 | class PixShuffleBlock(nn.Module): 20 | def __init__(self, in_features, upscale_factor=2): 21 | super(PixShuffleBlock, self).__init__() 22 | self.conv = nn.Conv2d( 23 | in_features, 24 | in_features * upscale_factor * upscale_factor, 25 | 3, 26 | padding=1) 27 | self.up = nn.PixelShuffle(upscale_factor) 28 | self.prelu = nn.PReLU() 29 | 30 | def forward(self, x): 31 | fea = self.conv(x) 32 | fea = self.up(fea) 33 | fea = self.prelu(fea) 34 | return fea 35 | 36 | 37 | class SRResNet(nn.Module): 38 | def __init__(self, num_resblocks, num_features, scaling_factor): 39 | super(SRResNet, self).__init__() 40 | self.conv1 = nn.Conv2d(3, num_features, 3, padding=1) 41 | self.prelu = nn.PReLU() 42 | self.res_blocks = nn.Sequential( 43 | *[ResBlock(num_features) for i in range(num_resblocks)]) 44 | self.conv2 = nn.Conv2d(num_features, num_features, 3, padding=1) 45 | self.bn = nn.BatchNorm2d(num_features) 46 | self.pix_blocks = nn.Sequential(*[ 47 | PixShuffleBlock(num_features) 48 | for i in range(int(math.log2(scaling_factor))) 49 | ]) 50 | self.conv3 = nn.Conv2d(num_features, 3, 1, padding=1) 51 | 52 | def forward(self, x): 53 | fea_0 = self.prelu(self.conv1(x)) 54 | fea = self.res_blocks(fea_0) 55 | fea = fea_0 + self.bn(self.conv2(fea)) 56 | fea = self.pix_blocks(fea) 57 | fea = self.conv3(fea) 58 | return fea 59 | 60 | 61 | if __name__ == "__main__": 62 | net = SRResNet(6, 64, 4) 63 | x = torch.rand([6, 3, 64, 64]) 64 | y = net(x) 65 | print(x.shape) 66 | print(y.shape) --------------------------------------------------------------------------------