├── .gitignore ├── LICENSE ├── README.md ├── dataset.py ├── main.py ├── networks.py ├── results.png ├── run.sh └── wavelet_weights_c2.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 hhb072 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WaveletSRNet 2 | A pytorch implementation of Paper ["Wavelet-srnet: A wavelet-based cnn for multi-scale face super resolution"](http://openaccess.thecvf.com/content_iccv_2017/html/Huang_Wavelet-SRNet_A_Wavelet-Based_ICCV_2017_paper.html) 3 | 4 | ## Prerequisites 5 | * Python 2.7 6 | * PyTorch 7 | 8 | ## Run 9 | 10 | Use the default hyparameters except changing the parameter "upscale" according to the expected upscaling factor(2, 3, 4 for 4, 8, 16 upcaling factors, respectively). 11 | 12 | >CUDA_VISIBLE_DEVICES=1 python main.py --ngpu=1 --test --start_epoch=0 --test_iter=1000 --batchSize=64 --test_batchSize=32 --nrow=4 --upscale=3 --input_height=128 --output_height=128 --crop_height=128 --lr=2e-4 --nEpochs=500 --cuda 13 | 14 | ## Results 15 | 16 | ![](https://github.com/hhb072/WaveletSRNet/blob/master/results.png) 17 | 18 | ## Citation 19 | 20 | If you use our codes, please cite the following paper: 21 | 22 | @inproceedings{huang2017wavelet, 23 | title={Wavelet-srnet: A wavelet-based cnn for multi-scale face super resolution}, 24 | author={Huang, Huaibo and He, Ran and Sun, Zhenan and Tan, Tieniu}, 25 | booktitle={IEEE International Conference on Computer Vision}, 26 | pages={1689--1697},     27 | year={2017} 28 | } 29 | 30 | **The released codes are only allowed for non-commercial use.** 31 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | from os import listdir 5 | from os.path import join 6 | from PIL import Image, ImageOps 7 | import random 8 | import torchvision.transforms as transforms 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 12 | 13 | def readlinesFromFile(path, datasize): 14 | print("Load from file %s" % path) 15 | f=open(path) 16 | data=[] 17 | for idx in xrange(0, datasize): 18 | line = f.readline() 19 | data.append(line) 20 | 21 | f.close() 22 | return data 23 | 24 | def loadFromFile(path, datasize): 25 | if path is None: 26 | return None, None 27 | 28 | print("Load from file %s" % path) 29 | f=open(path) 30 | data=[] 31 | label=[] 32 | for idx in xrange(0, datasize): 33 | line = f.readline().split() 34 | data.append(line[0]) 35 | label.append(line[1]) 36 | 37 | f.close() 38 | return data, label 39 | 40 | def load_video_image(file_path, input_height=128, input_width=None, output_height=128, output_width=None, 41 | crop_height=None, crop_width=None, is_random_crop=True, is_mirror=True, 42 | is_gray=False, scale=1.0, is_scale_back=False): 43 | 44 | if input_width is None: 45 | input_width = input_height 46 | if output_width is None: 47 | output_width = output_height 48 | if crop_width is None: 49 | crop_width = crop_height 50 | 51 | img = Image.open(file_path) 52 | if is_gray is False and img.mode is not 'RGB': 53 | img = img.convert('RGB') 54 | if is_gray and img.mode is not 'L': 55 | img = img.convert('L') 56 | 57 | if is_mirror and random.randint(0,1) is 0: 58 | img = ImageOps.mirror(img) 59 | 60 | if input_height is not None: 61 | img = img.resize((input_width, input_height),Image.BICUBIC) 62 | 63 | if crop_height is not None: 64 | [w, h] = img.size 65 | if is_random_crop: 66 | #print([w,cropSize]) 67 | cx1 = random.randint(0, w-crop_width) 68 | cx2 = w - crop_width - cx1 69 | cy1 = random.randint(0, h-crop_height) 70 | cy2 = h - crop_height - cy1 71 | else: 72 | cx2 = cx1 = int(round((w-crop_width)/2.)) 73 | cy2 = cy1 = int(round((h-crop_height)/2.)) 74 | img = ImageOps.crop(img, (cx1, cy1, cx2, cy2)) 75 | 76 | #print(scale) 77 | img = img.resize((output_width, output_height),Image.BICUBIC) 78 | img_lr = img.resize((int(output_width/scale),int(output_height/scale)),Image.BICUBIC) 79 | if is_scale_back: 80 | return img_lr.resize((output_width, output_height),Image.BICUBIC), img 81 | else: 82 | return img_lr, img 83 | 84 | 85 | class ImageDatasetFromFile(data.Dataset): 86 | def __init__(self, image_list, root_path, input_height=128, input_width=None, output_height=128, output_width=None, 87 | crop_height=None, crop_width=None, is_random_crop=True, is_mirror=True, 88 | is_gray=False, upscale=1.0, is_scale_back=False): 89 | super(ImageDatasetFromFile, self).__init__() 90 | 91 | self.image_filenames = image_list 92 | self.upscale = upscale 93 | self.is_random_crop = is_random_crop 94 | self.is_mirror = is_mirror 95 | self.input_height = input_height 96 | self.input_width = input_width 97 | self.output_height = output_height 98 | self.output_width = output_width 99 | self.root_path = root_path 100 | self.crop_height = crop_height 101 | self.crop_width = crop_width 102 | self.is_scale_back = is_scale_back 103 | self.is_gray = is_gray 104 | 105 | self.input_transform = transforms.Compose([ 106 | transforms.ToTensor() 107 | ]) 108 | 109 | def __getitem__(self, index): 110 | 111 | if self.is_mirror: 112 | is_mirror = random.randint(0,1) is 0 113 | else: 114 | is_mirror = False 115 | 116 | lr, hr = load_video_image(join(self.root_path, self.image_filenames[index]), 117 | self.input_height, self.input_width, self.output_height, self.output_width, 118 | self.crop_height, self.crop_width, self.is_random_crop, is_mirror, 119 | self.is_gray, self.upscale, self.is_scale_back) 120 | 121 | 122 | input = self.input_transform(lr) 123 | target = self.input_transform(hr) 124 | 125 | return input, target 126 | 127 | def __len__(self): 128 | return len(self.image_filenames) 129 | 130 | 131 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | import torch.utils.data 11 | import torchvision.datasets as dset 12 | from torch.utils.data import DataLoader 13 | from dataset import * 14 | import time 15 | import numpy as np 16 | import torchvision.utils as vutils 17 | from torch.autograd import Variable 18 | from networks import * 19 | from math import log10 20 | import torchvision 21 | import cv2 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--test', action='store_true', help='enables test during training') 25 | parser.add_argument('--mse_avg', action='store_true', help='enables mse avg') 26 | parser.add_argument('--num_layers_res', type=int, help='number of the layers in residual block', default=2) 27 | parser.add_argument('--nrow', type=int, help='number of the rows to save images', default=10) 28 | parser.add_argument('--trainfiles', default="path/celeba/train.list", type=str, help='the list of training files') 29 | parser.add_argument('--dataroot', default="path/celeba", type=str, help='path to dataset') 30 | parser.add_argument('--testfiles', default="path/test.list", type=str, help='the list of training files') 31 | parser.add_argument('--testroot', default="path/celeba", type=str, help='path to dataset') 32 | parser.add_argument('--trainsize', type=int, help='number of training data', default=162770) 33 | parser.add_argument('--testsize', type=int, help='number of testing data', default=19962) 34 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 35 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 36 | parser.add_argument('--test_batchSize', type=int, default=64, help='test batch size') 37 | parser.add_argument('--save_iter', type=int, default=10, help='the interval iterations for saving models') 38 | parser.add_argument('--test_iter', type=int, default=500, help='the interval iterations for testing') 39 | parser.add_argument('--cdim', type=int, default=3, help='the channel-size of the input image to network') 40 | parser.add_argument('--input_height', type=int, default=128, help='the height of the input image to network') 41 | parser.add_argument('--input_width', type=int, default=None, help='the width of the input image to network') 42 | parser.add_argument('--output_height', type=int, default=128, help='the height of the output image to network') 43 | parser.add_argument('--output_width', type=int, default=None, help='the width of the output image to network') 44 | parser.add_argument('--crop_height', type=int, default=None, help='the width of the output image to network') 45 | parser.add_argument('--crop_width', type=int, default=None, help='the width of the output image to network') 46 | parser.add_argument('--upscale', type=int, default=2, help='the depth of wavelet tranform') 47 | parser.add_argument('--scale_back', action='store_true', help='enables scale_back') 48 | parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for") 49 | parser.add_argument("--start_epoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 50 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') 51 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') 52 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9") 53 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 54 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 55 | parser.add_argument('--outf', default='results/', help='folder to output images') 56 | parser.add_argument('--manualSeed', type=int, help='manual seed') 57 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)") 58 | 59 | def main(): 60 | 61 | global opt, model 62 | opt = parser.parse_args() 63 | print(opt) 64 | 65 | try: 66 | os.makedirs(opt.outf) 67 | except OSError: 68 | pass 69 | 70 | if opt.manualSeed is None: 71 | opt.manualSeed = random.randint(1, 10000) 72 | print("Random Seed: ", opt.manualSeed) 73 | random.seed(opt.manualSeed) 74 | torch.manual_seed(opt.manualSeed) 75 | if opt.cuda: 76 | torch.cuda.manual_seed_all(opt.manualSeed) 77 | 78 | cudnn.benchmark = True 79 | 80 | if torch.cuda.is_available() and not opt.cuda: 81 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 82 | 83 | 84 | ngpu = int(opt.ngpu) 85 | nc = opt.cdim 86 | mag = int(math.pow(2, opt.upscale)) 87 | groups = mag ** 2 88 | if opt.scale_back: 89 | is_scale_back = True 90 | else: 91 | is_scale_back = False 92 | 93 | #--------------build models-------------------------- 94 | srnet = NetSR(opt.upscale, num_layers_res=opt.num_layers_res) 95 | if opt.pretrained: 96 | if os.path.isfile(opt.pretrained): 97 | print("=> loading model '{}'".format(opt.pretrained)) 98 | weights = torch.load(opt.pretrained) 99 | pretrained_dict = weights['model'].state_dict() 100 | model_dict = srnet.state_dict() 101 | # print(model_dict) 102 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 103 | model_dict.update(pretrained_dict) 104 | # 3. load the new state dict 105 | srnet.load_state_dict(model_dict) 106 | # srnet.load_state_dict(weights['model'].state_dict()) 107 | else: 108 | print("=> no model found at '{}'".format(opt.pretrained)) 109 | print(srnet) 110 | 111 | wavelet_dec = WaveletTransform(scale=opt.upscale, dec=True) 112 | wavelet_rec = WaveletTransform(scale=opt.upscale, dec=False) 113 | 114 | criterion_m = nn.MSELoss(size_average=True) 115 | 116 | if opt.cuda: 117 | srnet = srnet.cuda() 118 | wavelet_dec = wavelet_dec.cuda() 119 | wavelet_rec = wavelet_rec.cuda() 120 | criterion_m = criterion_m.cuda() 121 | 122 | 123 | optimizer_sr = optim.Adam(srnet.parameters(), lr=opt.lr, betas=(opt.momentum, 0.999), weight_decay=0.0005) 124 | 125 | #-----------------load dataset-------------------------- 126 | train_list, _ = loadFromFile(opt.trainfiles, opt.trainsize) 127 | train_set = ImageDatasetFromFile(train_list, opt.dataroot, 128 | input_height=opt.input_height, input_width=opt.input_width, 129 | output_height=opt.output_height, output_width=opt.output_width, 130 | crop_height=opt.crop_height, crop_width=opt.crop_width, 131 | is_random_crop=True, is_mirror=True, is_gray=False, 132 | upscale=mag, is_scale_back=is_scale_back) 133 | train_data_loader = torch.utils.data.DataLoader(train_set, batch_size=opt.batchSize, 134 | shuffle=True, num_workers=int(opt.workers)) 135 | 136 | test_list, _ = loadFromFile(opt.testfiles, opt.testsize) 137 | test_set = ImageDatasetFromFile(test_list, opt.testroot, 138 | input_height=opt.output_height, input_width=opt.output_width, 139 | output_height=opt.output_height, output_width=opt.output_width, 140 | crop_height=None, crop_width=None, 141 | is_random_crop=False, is_mirror=False, is_gray=False, 142 | upscale=mag, is_scale_back=is_scale_back) 143 | test_data_loader = torch.utils.data.DataLoader(test_set, batch_size=opt.test_batchSize, 144 | shuffle=False, num_workers=int(opt.workers)) 145 | 146 | 147 | 148 | start_time = time.time() 149 | srnet.train() 150 | #----------------Train by epochs-------------------------- 151 | for epoch in range(opt.start_epoch, opt.nEpochs + 1): 152 | if epoch%opt.save_iter == 0: 153 | save_checkpoint(srnet, epoch, 0, 'sr_') 154 | 155 | for iteration, batch in enumerate(train_data_loader, 0): 156 | #--------------test------------- 157 | if iteration % opt.test_iter is 0 and opt.test: 158 | srnet.eval() 159 | avg_psnr = 0 160 | for titer, batch in enumerate(test_data_loader,0): 161 | input, target = Variable(batch[0]), Variable(batch[1]) 162 | if opt.cuda: 163 | input = input.cuda() 164 | target = target.cuda() 165 | 166 | wavelets = forward_parallel(srnet, input, opt.ngpu) 167 | prediction = wavelet_rec(wavelets) 168 | mse = criterion_m(prediction, target) 169 | psnr = 10 * log10(1 / (mse.data[0]) ) 170 | avg_psnr += psnr 171 | 172 | save_images(prediction, "Epoch_{:03d}_Iter_{:06d}_{:02d}_o.jpg".format(epoch, iteration, titer), 173 | path=opt.outf, nrow=opt.nrow) 174 | 175 | 176 | print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(test_data_loader))) 177 | srnet.train() 178 | 179 | #--------------train------------ 180 | input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False) 181 | if opt.cuda: 182 | input = input.cuda() 183 | target = target.cuda() 184 | 185 | target_wavelets = wavelet_dec(target) 186 | 187 | batch_size = target.size(0) 188 | wavelets_lr = target_wavelets[:,0:3,:,:] 189 | wavelets_sr = target_wavelets[:,3:,:,:] 190 | 191 | wavelets_predict = forward_parallel(srnet, input, opt.ngpu) 192 | img_predict = wavelet_rec(wavelets_predict) 193 | 194 | 195 | loss_lr = loss_MSE(wavelets_predict[:,0:3,:,:], wavelets_lr, opt.mse_avg) 196 | loss_sr = loss_MSE(wavelets_predict[:,3:,:,:], wavelets_sr, opt.mse_avg) 197 | loss_textures = loss_Textures(wavelets_predict[:,3:,:,:], wavelets_sr) 198 | loss_img = loss_MSE(img_predict, target, opt.mse_avg) 199 | 200 | loss = loss_sr.mul(0.99) + loss_lr.mul(0.01) + loss_img.mul(0.1) + loss_textures.mul(1) 201 | 202 | optimizer_sr.zero_grad() 203 | loss.backward() 204 | optimizer_sr.step() 205 | 206 | info = "===> Epoch[{}]({}/{}): time: {:4.4f}:".format(epoch, iteration, len(train_data_loader), time.time()-start_time) 207 | info += "Rec: {:.4f}, {:.4f}, {:.4f}, Texture: {:.4f}".format(loss_lr.data[0], loss_sr.data[0], 208 | loss_img.data[0], loss_textures.data[0]) 209 | 210 | print(info) 211 | 212 | 213 | def forward_parallel(net, input, ngpu): 214 | if ngpu > 1: 215 | return nn.parallel.data_parallel(net, input, range(ngpu)) 216 | else: 217 | return net(input) 218 | 219 | def save_checkpoint(model, epoch, iteration, prefix=""): 220 | model_out_path = "model/" + prefix +"model_epoch_{}_iter_{}.pth".format(epoch, iteration) 221 | state = {"epoch": epoch ,"model": model} 222 | if not os.path.exists("model/"): 223 | os.makedirs("model/") 224 | 225 | torch.save(state, model_out_path) 226 | 227 | print("Checkpoint saved to {}".format(model_out_path)) 228 | 229 | def save_images(images, name, path, nrow=10): 230 | #print(images.size()) 231 | img = images.cpu() 232 | im = img.data.numpy().astype(np.float32) 233 | #print(im.shape) 234 | im = im.transpose(0,2,3,1) 235 | imsave(im, [nrow, int(math.ceil(im.shape[0]/float(nrow)))], os.path.join(path, name) ) 236 | 237 | def merge(images, size): 238 | #print(images.shape()) 239 | h, w = images.shape[1], images.shape[2] 240 | img = np.zeros((h * size[0], w * size[1], 3)) 241 | #print(img) 242 | for idx, image in enumerate(images): 243 | image = image * 255 244 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 245 | i = idx % size[1] 246 | j = idx // size[1] 247 | img[j*h:j*h+h, i*w:i*w+w, :] = image 248 | return img 249 | 250 | def imsave(images, size, path): 251 | img = merge(images, size) 252 | # print(img) 253 | return cv2.imwrite(path, img) 254 | 255 | if __name__ == "__main__": 256 | main() 257 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import pickle 5 | from torch.autograd import Variable 6 | 7 | def loss_MSE(x, y, size_average=False): 8 | z = x - y 9 | z2 = z * z 10 | if size_average: 11 | return z2.mean() 12 | else: 13 | return z2.sum().div(x.size(0)*2) 14 | 15 | def loss_Textures(x, y, nc=3, alpha=1.2, margin=0): 16 | xi = x.contiguous().view(x.size(0), -1, nc, x.size(2), x.size(3)) 17 | yi = y.contiguous().view(y.size(0), -1, nc, y.size(2), y.size(3)) 18 | 19 | xi2 = torch.sum(xi * xi, dim=2) 20 | yi2 = torch.sum(yi * yi, dim=2) 21 | 22 | out = nn.functional.relu(yi2.mul(alpha) - xi2 + margin) 23 | 24 | return torch.mean(out) 25 | 26 | 27 | class WaveletTransform(nn.Module): 28 | def __init__(self, scale=1, dec=True, params_path='wavelet_weights_c2.pkl', transpose=True): 29 | super(WaveletTransform, self).__init__() 30 | 31 | self.scale = scale 32 | self.dec = dec 33 | self.transpose = transpose 34 | 35 | ks = int(math.pow(2, self.scale) ) 36 | nc = 3 * ks * ks 37 | 38 | if dec: 39 | self.conv = nn.Conv2d(in_channels=3, out_channels=nc, kernel_size=ks, stride=ks, padding=0, groups=3, bias=False) 40 | else: 41 | self.conv = nn.ConvTranspose2d(in_channels=nc, out_channels=3, kernel_size=ks, stride=ks, padding=0, groups=3, bias=False) 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 45 | f = file(params_path,'rb') 46 | dct = pickle.load(f) 47 | f.close() 48 | m.weight.data = torch.from_numpy(dct['rec%d' % ks]) 49 | m.weight.requires_grad = False 50 | 51 | def forward(self, x): 52 | if self.dec: 53 | output = self.conv(x) 54 | if self.transpose: 55 | osz = output.size() 56 | #print(osz) 57 | output = output.view(osz[0], 3, -1, osz[2], osz[3]).transpose(1,2).contiguous().view(osz) 58 | else: 59 | if self.transpose: 60 | xx = x 61 | xsz = xx.size() 62 | xx = xx.view(xsz[0], -1, 3, xsz[2], xsz[3]).transpose(1,2).contiguous().view(xsz) 63 | output = self.conv(xx) 64 | return output 65 | 66 | class _Residual_Block(nn.Module): 67 | def __init__(self, inc=64, outc=64, groups=1): 68 | super(_Residual_Block, self).__init__() 69 | 70 | if inc is not outc: 71 | self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1, padding=0, groups=1, bias=False) 72 | else: 73 | self.conv_expand = None 74 | 75 | self.conv1 = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=3, stride=1, padding=1, groups=groups, bias=False) 76 | self.bn1 = nn.BatchNorm2d(outc) 77 | self.relu1 = nn.ReLU(inplace=True) 78 | self.conv2 = nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, stride=1, padding=1, groups=groups, bias=False) 79 | self.bn2 = nn.BatchNorm2d(outc) 80 | self.relu2 = nn.ReLU(inplace=True) 81 | 82 | def forward(self, x): 83 | if self.conv_expand is not None: 84 | identity_data = self.conv_expand(x) 85 | else: 86 | identity_data = x 87 | 88 | output = self.relu1(self.bn1(self.conv1(x))) 89 | output = self.conv2(output) 90 | output = self.relu2(self.bn2(torch.add(output,identity_data))) 91 | return output 92 | 93 | def make_layer(block, num_of_layer, inc=64, outc=64, groups=1): 94 | layers = [] 95 | layers.append(block(inc=inc, outc=outc, groups=groups)) 96 | for _ in range(1, num_of_layer): 97 | layers.append(block(inc=outc, outc=outc, groups=groups)) 98 | return nn.Sequential(*layers) 99 | 100 | class _Interim_Block(nn.Module): 101 | def __init__(self, inc=64, outc=64, groups=1): 102 | super(_Interim_Block, self).__init__() 103 | 104 | self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1, padding=0, groups=1, bias=False) 105 | self.conv1 = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=3, stride=1, padding=1, groups=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(outc) 107 | self.relu1 = nn.ReLU(inplace=True) 108 | self.conv2 = nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, stride=1, padding=1, groups=groups, bias=False) 109 | self.bn2 = nn.BatchNorm2d(outc) 110 | self.relu2 = nn.ReLU(inplace=True) 111 | 112 | def forward(self, x): 113 | identity_data = self.conv_expand(x) 114 | output = self.relu1(self.bn1(self.conv1(x))) 115 | output = self.conv2(output) 116 | output = self.relu2(self.bn2(torch.add(output,identity_data))) 117 | return output 118 | 119 | class NetSR(nn.Module): 120 | def __init__(self, scale=2, num_layers_res=2): 121 | super(NetSR, self).__init__() 122 | 123 | self.scale = int(scale) 124 | self.groups = int(math.pow(4, self.scale)) 125 | self.wavelet_c = wavelet_c = 32 126 | 127 | #----------input conv------------------- 128 | self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 129 | self.bn_input = nn.BatchNorm2d(64) 130 | self.relu_input = nn.ReLU(inplace=True) 131 | 132 | #----------residual------------------- 133 | self.residual = nn.Sequential( 134 | make_layer(_Residual_Block, num_layers_res, inc=64, outc=64), 135 | make_layer(_Residual_Block, num_layers_res, inc=64, outc=128), 136 | make_layer(_Residual_Block, num_layers_res, inc=128, outc=256), 137 | make_layer(_Residual_Block, num_layers_res, inc=256, outc=512), 138 | make_layer(_Residual_Block, num_layers_res, inc=512, outc=1024) 139 | ) 140 | 141 | #----------wavelet conv------------------- 142 | inc = 1024 143 | layer_num = 1 144 | if self.scale >= 0: 145 | g = 1 146 | self.interim_0 = _Interim_Block(inc, wavelet_c * g, g) 147 | self.wavelet_0 = make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g) 148 | self.predict_0 = nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3, stride=1, padding=1, 149 | groups=g, bias=True) 150 | 151 | if self.scale >= 1: 152 | g = 3 153 | self.interim_1 = _Interim_Block(inc, wavelet_c * g, g) 154 | self.wavelet_1 = make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g) 155 | self.predict_1 = nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3, stride=1, padding=1, 156 | groups=g, bias=True) 157 | 158 | if self.scale >= 2: 159 | g = 12 160 | self.interim_2 = _Interim_Block(inc, wavelet_c * g, g) 161 | self.wavelet_2 = make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g) 162 | self.predict_2 = nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3, stride=1, padding=1, 163 | groups=g, bias=True) 164 | 165 | if self.scale >= 3: 166 | g = 48 167 | self.interim_3 = _Interim_Block(inc, wavelet_c * g, g) 168 | self.wavelet_3 = make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g) 169 | self.predict_3 = nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3, stride=1, padding=1, 170 | groups=g, bias=True) 171 | 172 | if self.scale >= 4: 173 | g = 192 174 | self.interim_4 = _Interim_Block(inc, wavelet_c * g, g) 175 | self.wavelet_4 = make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g) 176 | self.predict_4 = nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3, stride=1, padding=1, 177 | groups=g, bias=True) 178 | 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | m.weight.data.normal_(0, math.sqrt(2. / n)) 183 | if m.bias is not None: 184 | m.bias.data.zero_() 185 | elif isinstance(m, nn.BatchNorm2d): 186 | m.weight.data.fill_(1) 187 | if m.bias is not None: 188 | m.bias.data.zero_() 189 | 190 | def forward(self, x): 191 | 192 | f = self.relu_input(self.bn_input(self.conv_input(x))) 193 | 194 | f = self.residual(f) 195 | 196 | if self.scale >= 0: 197 | out_0 = self.interim_0(f) 198 | out_0 = self.wavelet_0(out_0) 199 | out_0 = self.predict_0(out_0) 200 | out = out_0 201 | 202 | if self.scale >= 1: 203 | out_1 = self.interim_1(f) 204 | out_1 = self.wavelet_1(out_1) 205 | out_1 = self.predict_1(out_1) 206 | out = torch.cat((out, out_1), 1) 207 | 208 | 209 | if self.scale >= 2: 210 | out_2 = self.interim_2(f) 211 | out_2 = self.wavelet_2(out_2) 212 | out_2 = self.predict_2(out_2) 213 | out = torch.cat((out, out_2), 1) 214 | 215 | if self.scale >= 3: 216 | out_3 = self.interim_3(f) 217 | out_3 = self.wavelet_3(out_3) 218 | out_3 = self.predict_3(out_3) 219 | out = torch.cat((out, out_3), 1) 220 | 221 | if self.scale >= 4: 222 | out_4 = self.interim_4(f) 223 | out_4 = self.wavelet_4(out_4) 224 | out_4 = self.predict_4(out_4) 225 | out = torch.cat((out, out_4), 1) 226 | 227 | return out 228 | -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/WaveletSRNet/f0219900056c505143d9831b44a112453784b2a7/results.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # 3 | #$ -cwd 4 | #$ -j y 5 | #$ -N output_train_lstm 6 | #$ -S /bin/sh 7 | # 8 | 9 | CUDA_VISIBLE_DEVICES=1 python main.py --ngpu=1 --test --start_epoch=0 --test_iter=1000 --batchSize=64 --test_batchSize=32 --nrow=4 --upscale=3 --input_height=128 --output_height=128 --crop_height=128 --lr=2e-4 --nEpochs=500 --cuda -------------------------------------------------------------------------------- /wavelet_weights_c2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/WaveletSRNet/f0219900056c505143d9831b44a112453784b2a7/wavelet_weights_c2.pkl --------------------------------------------------------------------------------