├── README.md
├── correct_imgs.py
├── correction_func.py
├── downsample_imgs.py
├── estimate_correction.py
├── figs
│   ├── SR
│   │   ├── baboon_Gauss_std3.2_x4_s.png
│   │   ├── baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png
│   │   ├── bridge_Gauss_std1.8_x2_s.png
│   │   ├── bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png
│   │   ├── zebra_Gauss_std3.2_x4_s.png
│   │   └── zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png
│   └── blind_SR
│       ├── bird.png
│       ├── bird_.png
│       ├── bird_x2_corr_est.png
│       ├── butterfly.png
│       ├── butterfly_.png
│       ├── butterfly_x2_corr_est.png
│       ├── chip.png
│       ├── chip_LR.png
│       ├── chip_x4_corrected_est.png
│       ├── im_31.png
│       ├── im_31_x2_corr_est.png
│       ├── im_59.png
│       ├── im_59_x2_corr_est.png
│       ├── im_66.png
│       ├── im_66_x2_corr_est.png
│       ├── man2_Gauss_std3.2_x4_s.png
│       └── man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Correction-Filter

The official implementation of "Correction Filter for Single Image Super-Resolution: Robustifying Off-the-Shelf Deep Super-Resolvers" (https://arxiv.org/abs/1912.00157, accepted to CVPR 2020 as an oral presentation).

# Non-Blind
1. Down-sample the images and put them into folders according to the down-sampling filter that was used (saving the LR images as '.mat' files is recommended).
2. Define the sampling and reconstruction bases (s and r) in "correct_imgs.py" (the marked block near the top of the file; see the example below).
3. Run: python correct_imgs.py --in_dir "directory of the LR images" --out_dir "directory where the corrected images are saved" --scale_factor "SR scale factor"
4. Run any off-the-shelf deep SR network trained using r (usually bicubic) on the images saved to out_dir.

Note that this code assumes that all images within a folder were sampled using the same kernel.
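For example, for a scale factor of 4 and LR images generated with a Gaussian kernel of std 4.5/sqrt(2) (the setting used in the results below), the block in "correct_imgs.py" reads:

```python
# Define the reconstruction basis here (bicubic, which most off-the-shelf SR networks are trained on)
r = utils.get_bicubic(args.scale_factor).to(args.device)
r = r/r.sum()

# Define the sampling basis here (the kernel the LR images were sampled with)
sigma = 4.5/np.sqrt(2)
s_size = 32 + args.scale_factor%2
s = utils.get_gauss_flt(s_size, sigma).to(args.device)
s = s/s.sum()
```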
# Blind
1. Define the SR network in "estimate_correction.py" (the marked block near the top of the file; see the example below).
2. Run: python estimate_correction.py --scale_factor "SR scale factor" --in_dir "directory of the LR images" --out_dir "directory where the corrected LR images are saved"
3. Run any off-the-shelf deep SR network trained using r (usually bicubic) on the images saved to out_dir.
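Step 1 corresponds to the marked block near the top of "estimate_correction.py". The repository uses DBPN as the example super-resolver and expects its pre-trained weights under ./models/ (abridged; see the full block in the file):

```python
# Import the SR network here (e.g. DBPN):
from dbpn_iterative import Net as DBPNITER

# Define the desired SR model here (e.g. DBPN):
SR_model = DBPNITER(num_channels=3, base_filter=64, feat=256, num_stages=3, scale_factor=args.scale_factor)
state_dict = torch.load('./models/DBPN-RES-MR64-3_%dx.pth' % args.scale_factor, map_location=args.gpu)
```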
# Citation:

    @inproceedings{correction_filter,
      author    = {{Abu Hussein}, Shady and {Tirer}, Tom and {Giryes}, Raja},
      title     = {Correction Filter for Single Image Super-Resolution: Robustifying Off-the-Shelf Deep Super-Resolvers},
      booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
      year      = {2020}
    }

# Results
## Non-Blind Super-Resolution
Non-blind super-resolution with a scale factor of 4, where the LR images were produced by a Gaussian kernel with std 4.5/sqrt(2) (left: DBPN without correction; right: DBPN with the correction filter).

Non-blind super-resolution with a scale factor of 2, where the LR images were produced by a Gaussian kernel with std 2.5/sqrt(2) (left: DBPN without correction; right: DBPN with the correction filter).

## Blind Super-Resolution
### Synthetic Images
Here we demonstrate the performance of our method on images that were down-sampled from their ground-truth images.

#### Man image from Set14
Blind super-resolution with a scale factor of 4, where the LR image was produced by a Gaussian kernel with std 4.5/sqrt(2) (left: DBPN without correction; right: DBPN with the estimated correction filter).

#### Images from the DIV2KRK dataset

Blind super-resolution with a scale factor of 2, tested on images from the DIV2KRK dataset (http://www.wisdom.weizmann.ac.il/~vision/kernelgan/) (left: DBPN without correction; right: DBPN with the estimated correction filter).

## Real-World Super-Resolution
Here we present the results of our approach on images for which no ground truth is available.

### Images from the Set5 dataset
Here we take images from Set5 and apply our blind SR algorithm (scale factor of 2) to them directly (without down-sampling them first).

Left: DBPN without correction; right: DBPN with the estimated correction filter.

### Chip image

Super-resolution with a scale factor of 4 on the famous chip image. Left: the original LR image; middle: DBPN applied directly; right: DBPN applied with the estimated correction filter.
--------------------------------------------------------------------------------
/correct_imgs.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import utils
import os
import correction_func
import scipy.io as io

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--in_dir', default='./input_x4/4.5/', type=str)
parser.add_argument('--out_dir', default='./output/corrected_x4/', type=str)
parser.add_argument('--opt_suffix', default='', type=str)
parser.add_argument('--scale_factor', type=int, default=4)
parser.add_argument('--eps', type=float, default=0)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#############################################

# Define the reconstruction basis here
r = utils.get_bicubic(args.scale_factor).to(args.device)
r = r/r.sum()

# Define the sampling basis here
sigma = 4.5/np.sqrt(2)
s_size = 32 + args.scale_factor%2
s = utils.get_gauss_flt(s_size, sigma).to(args.device)
s = s/s.sum()

#############################################

if not os.path.isdir(args.out_dir):
    os.makedirs(args.out_dir)

imgs = [f for f in os.listdir(args.in_dir) if os.path.isfile(os.path.join(args.in_dir, f)) and ('.mat' in f)]
imgs.sort()

for img_in in imgs:
    y = np.moveaxis(io.loadmat(args.in_dir + img_in)['img'], 2, 0)
    y = torch.tensor(y.real).float().unsqueeze(0).to(args.device)

    Corr_flt = correction_func.Correction_Filter(s, args.scale_factor, (y.shape[2]*args.scale_factor, y.shape[3]*args.scale_factor), r=r, eps=args.eps, inv_type='Tikhonov')

    if y.shape[1] == 1:
        y = y.repeat(1, 3, 1, 1)
    img = img_in[0:-4] + '_x%d_corr.png' % (args.scale_factor)

    y_h = Corr_flt.correct_img(y)

    utils.save_img_torch(y_h.real, args.out_dir + img[0:-4] + '_corrected.png', clamp=True)
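# Note (added for clarity): this script expects each '.mat' file to contain an
# H x W x C array under the key 'img', exactly as produced by downsample_imgs.py.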
--------------------------------------------------------------------------------
/correction_func.py:
--------------------------------------------------------------------------------
import torch
import torch.fft
import numpy as np
import utils

class Correction_Filter():
    def __init__(self, s, scale_factor, x_shape, eps=0, r=None, inv_type='naive'):
        self.s = s.clone()
        self.r = None
        if r is not None:
            self.r = r.clone()
        else:
            self.r = utils.get_bicubic(scale_factor).float().to(s.device)
        self.r = self.r/self.r.sum()
        self.shape = x_shape
        self.scale_factor = scale_factor
        self.eps = eps
        self.inv_type = inv_type
        self.H = self.find_H(self.s, self.r)

    def correct_img(self, y):
        y_h = utils.fft_Filter_(y, self.H)
        return y_h

    def correct_img_(self, y, s):
        self.H = self.find_H(s, self.r)
        y_h = utils.fft_Filter_(y, self.H)
        return y_h

    def find_H(self, s, r):
        R = utils.fft_torch(r, self.shape)
        S = utils.fft_torch(s, self.shape)

        # Compensate for the half-pixel shift introduced by even scale factors
        R, S = utils.shift_by(R, 0.5*(not self.scale_factor%2)), utils.shift_by(S, 0.5*(not self.scale_factor%2))

        # Find Q = S*R
        Q = S.conj() * R
        q = torch.fft.ifftn(Q, dim=(-2,-1))

        q_d = q[:,:,0::self.scale_factor,0::self.scale_factor]
        Q_d = torch.fft.fftn(q_d, dim=(-2,-1))

        # Find R*R
        RR = R.conj() * R
        rr = torch.fft.ifftn(RR, dim=(-2,-1))
        rr_d = rr[:,:,0::self.scale_factor,0::self.scale_factor]
        RR_d = torch.fft.fftn(rr_d, dim=(-2,-1))

        # Invert S*R
        Q_d_inv = utils.dagger(Q_d, self.eps, mode=self.inv_type)

        H = RR_d * Q_d_inv

        return H

def est_corr(img_name, y_full, R_dag, S_conj, args, log_file=None, ref_bic=None, s_target=None):
    # Crop to the area with the most high-frequency texture
    crop = min(args.crop, min(y_full.shape[2], y_full.shape[3]))
    topleft_x, topleft_y = utils.crop_high_freq(y_full, crop, args.device)
    print('crop (x,y) = (%d, %d)' % (topleft_x, topleft_y))

    # The sampling basis is parameterized as a composition of four filters,
    # initialized to normalized bicubic kernels
    init_sd1 = utils.get_bicubic(1, (31, 31)).to(args.device)
    init_sd1 = init_sd1/init_sd1.sum()
    s_d_1 = torch.autograd.Variable(init_sd1, requires_grad=True)

    init_sd2 = utils.get_bicubic(1, (31, 31)).to(args.device)
    init_sd2 = init_sd2/init_sd2.sum()
    s_d_2 = torch.autograd.Variable(init_sd2, requires_grad=True)

    init_sd3 = utils.get_bicubic(1, (31, 31)).to(args.device)
    init_sd3 = init_sd3/init_sd3.sum()
    s_d_3 = torch.autograd.Variable(init_sd3, requires_grad=True)

    init_sd4 = utils.get_bicubic(args.scale_factor, (32, 32)).to(args.device)
    init_sd4 = init_sd4/init_sd4.sum()
    s_d_4 = torch.autograd.Variable(init_sd4, requires_grad=True)
    optimizer_sd = torch.optim.Adam([{'params': s_d_1, 'lr': args.lr_s}, {'params': s_d_2, 'lr': args.lr_s}, {'params': s_d_3, 'lr': args.lr_s}, {'params': s_d_4, 'lr': args.lr_s}])

    objective = torch.nn.L1Loss()

    s_c = torch.fft.ifftn(utils.fft_torch(s_d_1/s_d_1.sum(), y_full.shape[2:4])*utils.fft_torch(s_d_2/s_d_2.sum(), y_full.shape[2:4])*
                          utils.fft_torch(s_d_3/s_d_3.sum(), y_full.shape[2:4])*utils.fft_torch(s_d_4/s_d_4.sum(), y_full.shape[2:4]), dim=(-2,-1))
    with torch.no_grad():
        corr_flt = Correction_Filter(s_c, args.scale_factor, (y_full.shape[2]*args.scale_factor, y_full.shape[3]*args.scale_factor), inv_type='Tikhonov', eps=0)

    for itr in range(args.iterations):
        optimizer_sd.zero_grad()

        # Compose the current estimate of s in the frequency domain
        s_c = torch.fft.ifftn(utils.fft_torch(s_d_1/s_d_1.sum(), y_full.shape[2:4])*utils.fft_torch(s_d_2/s_d_2.sum(), y_full.shape[2:4])*
                              utils.fft_torch(s_d_3/s_d_3.sum(), y_full.shape[2:4])*utils.fft_torch(s_d_4/s_d_4.sum(), y_full.shape[2:4]), dim=(-2,-1))

        y_full_h = torch.abs(corr_flt.correct_img_(y_full, s_c)).float()
        y_h = y_full_h[:,:,topleft_y:topleft_y+crop, topleft_x:topleft_x+crop]

        x_hat = R_dag(y_h + torch.randn_like(y_h)*args.per_std)

        x_hat1 = utils.fft_Filter_(x_hat, utils.fft_torch(utils.flip_torch(s_d_1)/s_d_1.sum(), s=x_hat.shape[2:4]))
        x_hat2 = utils.fft_Filter_(x_hat1, utils.fft_torch(utils.flip_torch(s_d_2)/s_d_2.sum(), s=x_hat1.shape[2:4]))
        x_hat3 = utils.fft_Filter_(x_hat2, utils.fft_torch(utils.flip_torch(s_d_3)/s_d_3.sum(), s=x_hat2.shape[2:4]))

        y_hat = torch.abs(S_conj(x_hat3, s_d_4/s_d_4.sum()))

        shave = ((crop - y_hat.shape[2])//2, (crop - y_hat.shape[3])//2)
        y = y_full[:,:,topleft_y+shave[0]:topleft_y+crop-shave[0], topleft_x+shave[1]:topleft_x+crop-shave[1]]
        consistency = objective(y_hat[:,:,2:-2, 2:-2], y[:,:,2:-2, 2:-2])
        with torch.no_grad():
            x_c, y_c = utils.get_center_of_mass(torch.roll(s_c.real, (s_c.shape[2]//2, s_c.shape[3]//2), dims=(-2,-1)), args.device)

        abs_s = torch.abs(s_c)
        l0 = torch.mean(abs_s[abs_s > 0]**0.5)  # Relaxed l_0
        loss = consistency + args.lambda_l0*l0

        loss.backward()
        optimizer_sd.step()

        with torch.no_grad():
            if args.save_trace and np.mod(itr, 10) == 0:
                utils.save_img_torch(y_full_h, args.out_dir + 'within_loop.png')

            opt_s = s_c.clone()
            s_norm = s_c/s_c.sum()
            out_log = img_name[:-4] + \
                '| Itr = %d' % itr + \
                '| loss = %.7f' % (loss.item()) + \
                '| x_c, y_c = %.2f/%.2f, %.2f/%.2f' % (x_c, (s_c.shape[3] - 1)/2, y_c, (s_c.shape[2] - 1)/2)
            if ref_bic is not None:
                out_log += '| PSNR bic = %.5f' % (-10*torch.log10(torch.mean((y_full_h - ref_bic)**2)))
            if s_target is not None:
                l_s_test = torch.sum(torch.abs(s_norm - s_target)).item()
                out_log += '| SAE(s) = %.3f' % l_s_test
            if log_file is not None:
                log_file.write(out_log + '\n')
            print(out_log)
    return opt_s
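# Note (added for clarity): find_H computes the correction filter in the frequency
# domain as
#     H = [conj(R)*R]_down * ( [conj(S)*R]_down )^+
# where R and S are the DFTs of the reconstruction and sampling bases, [.]_down
# denotes sub-sampling by the scale factor (performed on the spatial kernel), and
# ^+ is the regularized inverse implemented in utils.dagger. Filtering the LR image
# by H (correct_img) makes it consistent with a super-resolver trained on r.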
--------------------------------------------------------------------------------
/downsample_imgs.py:
--------------------------------------------------------------------------------
import utils
import torch
import torch.fft
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import scipy.io as io
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


scale_factor = 4
std = 4.5/np.sqrt(2)
s = utils.get_gauss_flt(32, std).to(device)
s = s/s.sum()

in_dir = '../../SR_testing_datasets/Set14/'
out_dir = './input_x%d/' % scale_factor
out_GT = './GT_x%d/' % scale_factor

# Make sure all output folders exist before writing into them
for d in [out_dir, out_GT, out_dir + '/PNG/', out_dir + '/bicubic/', out_dir + '/Filters/']:
    os.makedirs(d, exist_ok=True)

imgs = [f for f in os.listdir(in_dir) if '.png' in f]
imgs.sort()

for img in imgs:
    I = utils.load_img_torch(in_dir + img, device)

    # Crop so that the image dimensions are divisible by the scale factor
    if I.shape[2] % scale_factor:
        I = I[:,:,:-(I.shape[2]%scale_factor),:]
    if I.shape[3] % scale_factor:
        I = I[:,:,:,:-(I.shape[3]%scale_factor)]
    utils.save_img_torch(I, out_GT + img)

    y = utils.fft_Down_(I, s, scale_factor)

    y_np = np.moveaxis(np.array(torch.abs(y)[0,:].cpu()), 0, 2)
    utils.save_img_torch(torch.abs(y), out_dir + '/PNG/' + img[:-4] + '_Gauss_std%1.1f_x%d_s.png' % (std, scale_factor))
    io.savemat(out_dir + img[:-4] + '_Gauss_std%1.1f_x%d_s.mat' % (std, scale_factor), {'img': y_np})

    # Bicubic down-sampling reference (PIL)
    I_PIL = transforms.ToPILImage()(I[0,:].cpu())
    W, H = I_PIL.size
    I_PIL_bic_down = I_PIL.resize((W//scale_factor, H//scale_factor), Image.BICUBIC)
    I_PIL_bic_down.save(out_dir + '/bicubic/' + img[:-4] + '_bicubic_down_PIL.png')

    # Save the sampling kernel, rolled to the image center for display
    S = utils.fft_torch(s, y.shape[2:4])
    s_ = torch.roll(torch.fft.ifftn(S, dim=(-2,-1)).real, (S.shape[2]//2, S.shape[3]//2), dims=(2,3))
    utils.save_img_torch(s_/s_.max(), out_dir + '/Filters/' + img[:-4] + '_Gauss_std%1.1f_x%d_s.png' % (std, scale_factor))
--------------------------------------------------------------------------------
/estimate_correction.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.fft
import utils
import os
import correction_func
import scipy.io as io
import argparse

# Reproducibility settings
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

parser = argparse.ArgumentParser()
parser.add_argument('--scale_factor', type=int, default=2)
parser.add_argument('--iterations', type=int, default=500)
parser.add_argument('--lr_s', type=float, default=1e-4)
parser.add_argument('--out_dir', default='./output/estimated_x2/', type=str)
parser.add_argument('--in_dir', default='./input_x2/', type=str)
parser.add_argument('--opt_suffix', default='', type=str)
parser.add_argument('--eps', type=float, default=0)
parser.add_argument('--per_std', type=float, default=0.005)
parser.add_argument('--lambda_l0', type=float, default=1)
parser.add_argument('--crop', type=int, default=150)
parser.add_argument('--save_trace', action='store_true')
parser.add_argument('--suffix', default='', type=str)

args = parser.parse_args()

args.gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
args.device = torch.device(args.gpu)

#############################################

# Import the SR network here (e.g. DBPN):
from dbpn_iterative import Net as DBPNITER

# Define the desired SR model here (e.g. DBPN):
SR_model = DBPNITER(num_channels=3, base_filter=64, feat=256, num_stages=3, scale_factor=args.scale_factor)
# The checkpoint was saved from a DataParallel model, so wrap before loading and unwrap after
SR_model = torch.nn.DataParallel(SR_model)
state_dict = torch.load('./models/DBPN-RES-MR64-3_%dx.pth' % args.scale_factor, map_location=args.gpu)
SR_model.load_state_dict(state_dict)
SR_model = SR_model.module.to(args.device)
SR_model = SR_model.eval()
r = utils.get_bicubic(args.scale_factor).to(args.device)
r = r/r.sum()
# DBPN-RES predicts a residual, so add the (energy-compensated) bicubic up-sampling of the input
R_dag = lambda I: SR_model(I) + args.scale_factor**2 * torch.abs(utils.fft_Up_(I, r, args.scale_factor))

#############################################

if not os.path.isdir(args.out_dir):
    os.makedirs(args.out_dir)

imgs = [f for f in os.listdir(args.in_dir) if os.path.isfile(os.path.join(args.in_dir, f)) and (('.mat' in f) or ('.png' in f))]
imgs.sort()

S_conj = lambda I, s: utils.fft_Down_(I, utils.flip_torch(s), args.scale_factor)


ref_bic = None

for img_in in imgs:
    if '.mat' in img_in:
        y_full = io.loadmat(args.in_dir + img_in)['img']
        y_full = torch.tensor(np.moveaxis(y_full, 2, 0)).unsqueeze(0).to(args.device).float()
    else:
        y_full = utils.load_img_torch(args.in_dir + img_in, args.device)
    if y_full.shape[1] == 1:
        y_full = y_full.repeat(1, 3, 1, 1)
    if y_full.shape[1] == 4:
        y_full = y_full[:, :3, :, :]  # drop the alpha channel

    img = img_in[0:-4] + '_x%d_corrected.png' % (args.scale_factor)

    if not os.path.isdir(args.out_dir + '/Filters/'):
        os.makedirs(args.out_dir + '/Filters/')

    log_file = open(args.out_dir + '/Filters/' + img[:-4] + args.suffix + '.txt', 'w')

    # Estimate the anti-aliasing filter using the proposed algorithm
    opt_s = correction_func.est_corr(img, y_full, R_dag, S_conj, args, log_file, ref_bic=ref_bic)

    # Apply the correction filter using the estimated anti-aliasing filter
    corr_flt = correction_func.Correction_Filter(opt_s/opt_s.sum(), args.scale_factor, (y_full.shape[2]*args.scale_factor, y_full.shape[3]*args.scale_factor), inv_type='Tikhonov', eps=args.eps)
    y_h = corr_flt.correct_img(y_full)

    # Save the result
    utils.save_img_torch(y_h.real, args.out_dir + img[0:-4] + args.suffix + '_est.png')
    utils.save_img_torch(torch.roll(opt_s.real, ((opt_s.shape[2]//2), (opt_s.shape[3]//2)), dims=(-2,-1))/opt_s.real.max(), args.out_dir + '/Filters/' + img[0:-4] + args.suffix + '_est_s.png')
    log_file.close()
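# Note (added for clarity): est_corr parameterizes the unknown sampling basis s as
# a composition of four filters and minimizes, with Adam,
#     || y_hat(s) - y ||_1  +  lambda_l0 * mean_{i: |s_i| > 0} |s_i|^(1/2)
# where y_hat(s) is obtained by correcting y with the current s (correct_img_),
# super-resolving with R_dag, then re-blurring and re-sampling with s (S_conj).
# The second term is the relaxed l_0 prior that keeps the estimated kernel sparse.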
--------------------------------------------------------------------------------
/figs/SR/baboon_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/baboon_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/SR/baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png
--------------------------------------------------------------------------------
/figs/SR/bridge_Gauss_std1.8_x2_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/bridge_Gauss_std1.8_x2_s.png
--------------------------------------------------------------------------------
/figs/SR/bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png
--------------------------------------------------------------------------------
/figs/SR/zebra_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/zebra_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/SR/zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird_.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly_.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip_LR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip_LR.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip_x4_corrected_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip_x4_corrected_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_31.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_31_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_31_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_59.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_59.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_59_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_59_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_66.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_66.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_66_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_66_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/man2_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/man2_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/blind_SR/man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.fft
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

def flip(x):
    return x.flip([2,3])

def flip_torch(x):
    # Flip a zero-centered filter about its center
    x_ = torch.flip(torch.roll(x, ((x.shape[2]//2), (x.shape[3]//2)), dims=(2,3)), dims=(2,3))
    return torch.roll(x_, (-(x_.shape[2]//2), -(x_.shape[3]//2)), dims=(2,3))

def flip_np(x):
    x_ = np.flip(np.roll(x, ((x.shape[0]//2), (x.shape[1]//2)), (0,1)))
    return np.roll(x_, (-(x_.shape[0]//2), -(x_.shape[1]//2)), (0,1))

def shift_by(H, shift):
    # Apply a sub-pixel shift in the frequency domain (linear phase)
    k_x = np.linspace(0, H.shape[3]-1, H.shape[3])
    k_y = np.linspace(0, H.shape[2]-1, H.shape[2])

    k_x[((k_x.shape[0] + 1)//2):] -= H.shape[3]
    k_y[((k_y.shape[0] + 1)//2):] -= H.shape[2]

    exp_x, exp_y = np.meshgrid(np.exp(-1j * 2*np.pi * k_x * shift / H.shape[3]), np.exp(-1j * 2*np.pi * k_y * shift / H.shape[2]))

    exp_x_torch = (torch.tensor(np.real(exp_x)) + 1j*torch.tensor(np.imag(exp_x))).unsqueeze(0).unsqueeze(0).to(H.device)
    exp_y_torch = (torch.tensor(np.real(exp_y)) + 1j*torch.tensor(np.imag(exp_y))).unsqueeze(0).unsqueeze(0).to(H.device)

    return H * exp_x_torch * exp_y_torch

def fft_torch(x, s=None, zero_centered=True):
    # s = (Ny, Nx)
    __, __, H, W = x.shape
    if s is None:
        s = (H, W)
    if zero_centered:
        x_ = torch.roll(x, ((H//2), (W//2)), dims=(2,3))
    else:
        x_ = x
    x_pad = torch.nn.functional.pad(x_, (0, s[1] - W, 0, s[0] - H))
    if zero_centered:
        x_pad_ = torch.roll(x_pad, (-(H//2), -(W//2)), dims=(2,3))
    else:
        x_pad_ = x_pad
    return torch.fft.fftn(x_pad_, dim=(-2,-1))

def bicubic_ker(x, y, a=-0.5):
    # X:
    abs_phase = np.abs(x)
    abs_phase2 = abs_phase**2
    abs_phase3 = abs_phase**3
    out_x = np.zeros_like(x)
    out_x[abs_phase <= 1] = (a+2)*abs_phase3[abs_phase <= 1] - (a+3)*abs_phase2[abs_phase <= 1] + 1
    out_x[(abs_phase > 1) & (abs_phase < 2)] = a*abs_phase3[(abs_phase > 1) & (abs_phase < 2)] -\
                                               5*a*abs_phase2[(abs_phase > 1) & (abs_phase < 2)] +\
                                               8*a*abs_phase[(abs_phase > 1) & (abs_phase < 2)] - 4*a
    # Y:
    abs_phase = np.abs(y)
    abs_phase2 = abs_phase**2
    abs_phase3 = abs_phase**3
    out_y = np.zeros_like(y)
    out_y[abs_phase <= 1] = (a+2)*abs_phase3[abs_phase <= 1] - (a+3)*abs_phase2[abs_phase <= 1] + 1
    out_y[(abs_phase > 1) & (abs_phase < 2)] = a*abs_phase3[(abs_phase > 1) & (abs_phase < 2)] -\
                                               5*a*abs_phase2[(abs_phase > 1) & (abs_phase < 2)] +\
                                               8*a*abs_phase[(abs_phase > 1) & (abs_phase < 2)] - 4*a

    return out_x*out_y

def build_flt(f, size):
    # Evaluate f on a centered grid, then roll so the filter center sits at index (0, 0)
    is_even_x = not size[1] % 2
    is_even_y = not size[0] % 2

    grid_x = np.linspace(-(size[1]//2 - is_even_x*0.5), (size[1]//2 - is_even_x*0.5), size[1])
    grid_y = np.linspace(-(size[0]//2 - is_even_y*0.5), (size[0]//2 - is_even_y*0.5), size[0])

    x, y = np.meshgrid(grid_x, grid_y)

    h = f(x, y)
    h = np.roll(h, (-(h.shape[0]//2), -(h.shape[1]//2)), (0,1))

    return torch.tensor(h).float().unsqueeze(0).unsqueeze(0)

def get_bicubic(scale, size=None):
    f = lambda x, y: bicubic_ker(x/scale, y/scale)
    if size:
        h = build_flt(f, (size[0], size[1]))
    else:
        h = build_flt(f, (4*scale + 8 + scale%2, 4*scale + 8 + scale%2))
    return h
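# Note on filter layout (added for clarity): build_flt() rolls every kernel so
# that its spatial center sits at index (0, 0); fft_torch() assumes this
# zero-centered layout when zero-padding (zero_centered=True), so padding does
# not shift the kernel. For display, the scripts roll a kernel back to the image
# center, e.g.:
#
#   h = get_bicubic(2)
#   h_centered = torch.roll(h, (h.shape[2]//2, h.shape[3]//2), dims=(2, 3))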
def get_box(supp, size=None):
    if size is None:
        size = (supp[0]*2, supp[1]*2)

    h = np.zeros(size)

    # The box filter is stored zero-centered, so its support occupies the four corners
    h[0:supp[0]//2, 0:supp[1]//2] = 1
    h[0:supp[0]//2, -(supp[1]//2):] = 1
    h[-(supp[0]//2):, 0:supp[1]//2] = 1
    h[-(supp[0]//2):, -(supp[1]//2):] = 1

    return torch.tensor(h).float().unsqueeze(0).unsqueeze(0)

def get_delta(size):
    h = torch.zeros(1, 1, size, size)
    h[0, 0, 0, 0] = 1
    return h

def get_gauss_flt(flt_size, std):
    f = lambda x, y: np.exp(-(x**2 + y**2)/2/std**2)
    h = build_flt(f, (flt_size, flt_size))
    return h

def fft_Filter_(x, A):
    # Filter x by the transfer function A in the frequency domain
    X_fft = torch.fft.fftn(x, dim=(-2,-1))
    HX = A * X_fft
    return torch.fft.ifftn(HX, dim=(-2,-1))

def fft_Down_(x, h, alpha):
    # y = (h * x) sub-sampled by alpha
    X_fft = torch.fft.fftn(x, dim=(-2,-1))
    H = fft_torch(h, s=X_fft.shape[2:4])
    HX = H * X_fft
    margin = (alpha - 1)//2
    y = torch.fft.ifftn(HX, dim=(-2,-1))[:,:,margin:HX.shape[2]-margin:alpha, margin:HX.shape[3]-margin:alpha]
    return y

def fft_Up_(y, h, alpha):
    # Zero-fill up-sampling by alpha followed by filtering with h
    x = torch.zeros(y.shape[0], y.shape[1], y.shape[2]*alpha, y.shape[3]*alpha).to(y.device)
    H = fft_torch(h, s=x.shape[2:4])
    start = alpha//2
    x[:,:,start::alpha, start::alpha] = y
    X = torch.fft.fftn(x, dim=(-2,-1))
    HX = H * X
    return torch.fft.ifftn(HX, dim=(-2,-1))

def zero_SV(H, eps):
    # Zero out frequency components whose relative magnitude is below eps
    H_real = H.real
    H_imag = H.imag
    abs_H2 = H_real**2 + H_imag**2
    H[abs_H2/abs_H2.max() <= eps**2] = 0
    return H

def dagger(X, eps=0, mode='Tikhonov'):
    # Pseudo-inverse of a transfer function: hard-thresholded ('naive') or Tikhonov-regularized
    real = X.real
    imag = X.imag
    abs2 = real**2 + imag**2
    if mode == 'naive':
        out = X.clone()
        out[abs2/abs2.max() > eps**2] = 1/X[abs2/abs2.max() > eps**2]
        out[abs2/abs2.max() <= eps**2] = 0
        return out
    if mode == 'Tikhonov':
        return X.conj()/(abs2 + eps**2)

def load_img_torch(dir, device):
    I = Image.open(dir)
    I = transforms.ToTensor()(I).unsqueeze(0)
    return I.to(device)

def save_img_torch(I, dir, clamp=True):
    if clamp:
        img = torch.clamp(I, 0, 1)[0,:].detach().cpu()
    else:
        img = I[0,:].detach().cpu()
    img = transforms.ToPILImage()(img)
    img.save(dir)

def get_center_of_mass(s, device):
    idx_x = torch.linspace(0, s.shape[3]-1, s.shape[3]).to(device)
    idx_y = torch.linspace(0, s.shape[2]-1, s.shape[2]).to(device)
    i_y, i_x = torch.meshgrid(idx_y, idx_x)

    x_c = torch.sum(s*i_x/s.sum())
    y_c = torch.sum(s*i_y/s.sum())

    return x_c, y_c

def crop_high_freq(I, crop_size, device):
    # Find the crop_size x crop_size window with the largest average Laplacian response
    if crop_size >= I.shape[2] or crop_size >= I.shape[3]:
        argmax_x = I.shape[3]//2 - crop_size//2
        argmax_y = I.shape[2]//2 - crop_size//2
    else:
        filt = torch.tensor([[ 0,-1, 0],
                             [-1, 4,-1],
                             [ 0,-1, 0]]).float().to(device)
        if I.shape[1] > 1:
            # Rec. 709 luma weights
            I_gray = 0.2126*I[:,0:1,:,:] + 0.7152*I[:,1:2,:,:] + 0.0722*I[:,2:3,:,:]
        else:
            I_gray = I
        D = torch.abs(torch.nn.functional.conv2d(I_gray, filt.unsqueeze(0).unsqueeze(0)))
        Avg = torch.nn.functional.avg_pool2d(D, crop_size, stride=1, padding=0, ceil_mode=False)
        argmax = Avg.argmax()
        argmax_y = argmax//Avg.shape[3]
        argmax_x = argmax % Avg.shape[3]
    return argmax_x, argmax_y
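# --- Optional smoke test (a hypothetical usage sketch, not part of the original
# file): a minimal, self-contained check of the filter builders and the
# FFT-based down/up-sampling, assuming only this module and PyTorch. ---
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scale = 2
    x = torch.rand(1, 3, 64, 64).to(device)

    # Gaussian sampling basis, normalized to unit sum
    s = get_gauss_flt(32, 2.5/np.sqrt(2)).to(device)
    s = s/s.sum()
    y = fft_Down_(x, s, scale)            # -> (1, 3, 32, 32), complex
    assert y.shape[2:] == (32, 32)

    # Bicubic reconstruction basis, up-sampling back to the original size
    r = get_bicubic(scale).to(device)
    r = r/r.sum()
    x_up = fft_Up_(y.real, r, scale)      # -> (1, 3, 64, 64), complex
    assert x_up.shape[2:] == (64, 64)
    print('utils.py smoke test passed')
--------------------------------------------------------------------------------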