├── README.md
├── correct_imgs.py
├── correction_func.py
├── downsample_imgs.py
├── estimate_correction.py
├── figs
│   ├── SR
│   │   ├── baboon_Gauss_std3.2_x4_s.png
│   │   ├── baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png
│   │   ├── bridge_Gauss_std1.8_x2_s.png
│   │   ├── bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png
│   │   ├── zebra_Gauss_std3.2_x4_s.png
│   │   └── zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png
│   └── blind_SR
│       ├── bird.png
│       ├── bird_.png
│       ├── bird_x2_corr_est.png
│       ├── butterfly.png
│       ├── butterfly_.png
│       ├── butterfly_x2_corr_est.png
│       ├── chip.png
│       ├── chip_LR.png
│       ├── chip_x4_corrected_est.png
│       ├── im_31.png
│       ├── im_31_x2_corr_est.png
│       ├── im_59.png
│       ├── im_59_x2_corr_est.png
│       ├── im_66.png
│       ├── im_66_x2_corr_est.png
│       ├── man2_Gauss_std3.2_x4_s.png
│       └── man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Correction-Filter
2 |
3 | Official implementation of "Correction Filter for Single Image Super-Resolution: Robustifying Off-the-Shelf Deep Super-Resolvers" (https://arxiv.org/abs/1912.00157), accepted to CVPR 2020 as an oral presentation.
4 |
5 | # Non-Blind
6 | 1. Downsample the images and group them into folders by the downsampling kernel that was used. `correct_imgs.py` only reads `.mat` files, each storing the LR image under the key `'img'` as an H x W x C array (the format written by `downsample_imgs.py`).
7 | 2. Define the sampling and reconstruction bases (`s` and `r`) in `correct_imgs.py` (lines 25-33).
8 | 3. Run: `python correct_imgs.py --in_dir <folder of the LR images> --out_dir <folder for the corrected images> --scale_factor <SR scale factor>`. Directory arguments should end with `/`; the optional `--eps` sets the Tikhonov regularization parameter (default 0).
9 | 4. Run any off-the-shelf deep SR network trained with `r` (usually bicubic) on the images saved to `out_dir`.
10 |
11 | Note that this code assumes that all images within a folder were sampled with the same kernel.
12 |
13 | # Blind
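14 | 1. Define the SR network in `estimate_correction.py` (lines 40-52). DBPN is used by default; its code (`dbpn_iterative.py`) and weights (`./models/DBPN-RES-MR64-3_<scale>x.pth`) are not included in this repository.
15 | 2. Run: `python estimate_correction.py --scale_factor <SR scale factor> --in_dir <folder of the LR images> --out_dir <folder for the corrected LR images>`. Both `.mat` and `.png` inputs are accepted, and directory arguments should end with `/`.
16 | 3. Run any off-the-shelf deep SR network trained with `r` (usually bicubic) on the images saved to `out_dir`.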
17 |
18 | # Citation:
19 |
20 | @ARTICLE{correction_filter,
21 | author = {{Abu Hussein}, Shady and {Tirer}, Tom and {Giryes}, Raja},
22 | title = "{Correction Filter for Single Image Super-Resolution: Robustifying Off-the-Shelf Deep Super-Resolvers}",
23 | journal = {In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
24 | year = "2020"
25 | }
26 |
27 | # Results
28 | ## Non-Blind Super-Resolution
29 | Non-blind super-resolution with a scale factor of 4 on a Gaussian model with std 4.5/sqrt(2) ≈ 3.2 (left: DBPN without correction; right: DBPN with the correction filter).
30 |
31 | Images: `figs/SR/baboon_Gauss_std3.2_x4_s*.png`, `figs/SR/zebra_Gauss_std3.2_x4_s*.png`
32 |
33 | Non-blind super-resolution with a scale factor of 2 on a Gaussian model with std 2.5/sqrt(2) ≈ 1.8 (left: DBPN without correction; right: DBPN with the correction filter).
34 |
35 | Images: `figs/SR/bridge_Gauss_std1.8_x2_s*.png`
36 |
37 | ## Blind Super-Resolution
38 | ### Synthetic Images
39 | Here we demonstrate the performance of our method on LR images synthesized by downsampling ground-truth images.
40 | #### Man image from Set14
41 | Blind super-resolution with a scale factor of 4 on a Gaussian model with std 4.5/sqrt(2) (left: DBPN without correction; right: DBPN with the estimated correction filter).
42 |
43 | Images: `figs/blind_SR/man2_Gauss_std3.2_x4_s*.png`
44 |
45 | #### Images from DIV2KRK dataset
46 |
47 | Blind super-resolution with a scale factor of 2 on images from the DIV2KRK dataset (http://www.wisdom.weizmann.ac.il/~vision/kernelgan/) (left: DBPN without correction; right: DBPN with the estimated correction filter).
48 |
49 | Images: `figs/blind_SR/im_31*.png`, `figs/blind_SR/im_59*.png`, `figs/blind_SR/im_66*.png`
50 |
51 | ## Real-World Super-Resolution
52 | Here we present results of our approach on real images for which no ground truth is available.
53 |
54 | ### Images from Set5 dataset
55 | Here we take images from Set5 and apply our blind SR algorithm (scale factor of 2) to them directly, without downsampling them first.
56 |
57 | On the left is DBPN without correction; on the right is DBPN with the estimated correction filter.
58 |
59 | Images: `figs/blind_SR/bird*.png`, `figs/blind_SR/butterfly*.png`
60 |
61 | ### Chip image
62 |
63 | Super-resolution with a scale factor of 4 on the well-known chip image. On the left is the original LR image, in the middle the result of DBPN applied directly, and on the right DBPN applied with the estimated correction filter.
64 |
65 | Images: `figs/blind_SR/chip_LR.png`, `figs/blind_SR/chip.png`, `figs/blind_SR/chip_x4_corrected_est.png`
--------------------------------------------------------------------------------
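For step 1 of the non-blind pipeline, `correct_imgs.py` picks up only `.mat` files and reads the LR image from the key `'img'` as a height x width x channels array before moving the channel axis to the front. The following is a minimal sketch for writing such inputs from your own NumPy arrays with SciPy; `save_lr_mat` and `load_lr_mat` are hypothetical helpers, not part of this repository, and replicating grayscale images to three channels is an assumption chosen to match what the scripts do after loading.

```python
import os
import tempfile

import numpy as np
import scipy.io as io


def save_lr_mat(path, lr_img):
    """Hypothetical helper: store an LR image the way correct_imgs.py expects (key 'img', H x W x C)."""
    lr_img = np.asarray(lr_img, dtype=np.float32)
    if lr_img.ndim == 2:
        # Assumption: replicate grayscale to 3 channels so the array stays 3-D after loading
        lr_img = np.repeat(lr_img[:, :, None], 3, axis=2)
    io.savemat(path, {'img': lr_img})


def load_lr_mat(path):
    """Hypothetical helper mirroring the loading step of correct_imgs.py: returns (1, C, H, W)."""
    y = np.moveaxis(io.loadmat(path)['img'], 2, 0)
    return y[None].astype(np.float32)


rng = np.random.default_rng(0)
img = rng.random((24, 32, 3))  # H=24, W=32, RGB values in [0, 1)
path = os.path.join(tempfile.mkdtemp(), 'example_x4.mat')
save_lr_mat(path, img)
y = load_lr_mat(path)
print(y.shape)  # (1, 3, 24, 32)
```

Storing the array as `float32` keeps the file small and matches the `.float()` conversion the scripts apply after loading.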
/correct_imgs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import scipy.io as io
6 | import torch
7 |
8 | import correction_func
9 | import utils
10 |
11 | # Expected input: LR images saved as '.mat' files under the key 'img' (H x W x C).
12 | # eps is the Tikhonov regularization parameter of the inverse in the correction filter.
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--in_dir', default='./input_x4/4.5/', type=str)
16 | parser.add_argument('--out_dir', default='./output/corrected_x4/', type=str)
17 | parser.add_argument('--opt_suffix', default='', type=str)
18 | parser.add_argument('--scale_factor', type=int, default=4)
19 | parser.add_argument('--eps', type=float, default=0)
20 | args = parser.parse_args()
21 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 | #############################################
24 |
25 | # Define the reconstruction basis here
26 | r = utils.get_bicubic(args.scale_factor).to(args.device)
27 | r = r/r.sum()
28 |
29 | # Define the sampling basis here
30 | sigma = 4.5/np.sqrt(2)
31 | s_size = 32 + args.scale_factor%2
32 | s = utils.get_gauss_flt(s_size, sigma).to(args.device)
33 | s = s/s.sum()
34 |
35 | #############################################
36 |
37 | os.makedirs(args.out_dir, exist_ok=True)
38 |
39 | # Only '.mat' files are processed; each stores the LR image under the key 'img' as H x W x C
40 | imgs = [f for f in os.listdir(args.in_dir) if os.path.isfile(os.path.join(args.in_dir, f)) and f.endswith('.mat')]
41 | imgs.sort()
42 |
43 | for img_in in imgs:
44 |     y = np.moveaxis(io.loadmat(os.path.join(args.in_dir, img_in))['img'], 2, 0)
45 |     y = torch.tensor(y.real).float().unsqueeze(0).to(args.device)
46 |
47 |     # The filter is built on the HR grid (scale*H, scale*W) and applied on the LR grid
48 |     Corr_flt = correction_func.Correction_Filter(s, args.scale_factor, (y.shape[2]*args.scale_factor, y.shape[3]*args.scale_factor), r=r, eps=args.eps, inv_type='Tikhonov')
49 |
50 |     # Replicate single-channel images to RGB
51 |     if y.shape[1] == 1:
52 |         y = y.repeat(1,3,1,1)
53 |     img = img_in[0:-4] + '_x%d_corr.png' %(args.scale_factor)
54 |
55 |     y_h = Corr_flt.correct_img(y)
56 |
57 |     utils.save_img_torch(y_h.real, os.path.join(args.out_dir, img[0:-4] + '_corrected.png'), clamp=True)
--------------------------------------------------------------------------------
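Lines 25-33 of `correct_imgs.py` fix the kernel sizes and the Gaussian width, and the same numbers explain the file names under `figs/`: `downsample_imgs.py` formats the std with `'%1.1f'`, so 4.5/sqrt(2) ≈ 3.182 becomes `std3.2` and 2.5/sqrt(2) ≈ 1.768 becomes `std1.8`. The short check below reproduces this arithmetic together with the kernel sizes (`32 + scale % 2` for `s`, and the default `4*scale + 8 + scale % 2` of `utils.get_bicubic` for `r`); the helper names are illustrative only.

```python
import numpy as np


def gauss_tag(std):
    # Same '%1.1f' formatting as the file names written by downsample_imgs.py
    return 'Gauss_std%1.1f' % std


def sampling_kernel_size(scale_factor):
    # correct_imgs.py: s_size = 32 + scale_factor % 2 (odd size for odd scale factors)
    return 32 + scale_factor % 2


def bicubic_kernel_size(scale_factor):
    # Default size used by utils.get_bicubic when no size is given
    return 4 * scale_factor + 8 + scale_factor % 2


for std in (4.5 / np.sqrt(2), 2.5 / np.sqrt(2)):
    print('%.4f' % std, gauss_tag(std))
# 3.1820 Gauss_std3.2
# 1.7678 Gauss_std1.8

for scale in (2, 3, 4):
    print(scale, sampling_kernel_size(scale), bicubic_kernel_size(scale))
# 2 32 16
# 3 33 21
# 4 32 24
```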
/correction_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft
3 | import numpy as np
4 | import utils
5 |
6 |
7 | class Correction_Filter():
8 |     def __init__(self, s, scale_factor, x_shape, eps=0, r=None, inv_type='naive'):
9 |         self.s = s.clone()
10 |         # Reconstruction kernel r: defaults to the bicubic kernel of the scale factor
11 |         if r is not None:
12 |             self.r = r.clone()
13 |         else:
14 |             self.r = utils.get_bicubic(scale_factor).float().to(s.device)
15 |         self.r = self.r/self.r.sum()
16 |         self.shape = x_shape  # HR grid size (scale*H, scale*W)
17 |         self.scale_factor = scale_factor
18 |         self.eps = eps
19 |         self.inv_type = inv_type
20 |         self.H = self.find_H(self.s, self.r)
21 |
22 |     def correct_img(self, y):
23 |         # Filter the LR image with the precomputed correction filter H (Fourier domain)
24 |         y_h = utils.fft_Filter_(y, self.H)
25 |         return y_h
26 |
27 |     def correct_img_(self, y, s):
28 |         # Recompute H for a new sampling kernel s, then filter y with it
29 |         self.H = self.find_H(s, self.r)
30 |         y_h = utils.fft_Filter_(y, self.H)
31 |         return y_h
32 |
33 |     def find_H(self, s, r):
34 |         R = utils.fft_torch(r, self.shape)
35 |         S = utils.fft_torch(s, self.shape)
36 |
37 |         # Half-pixel phase shift when the scale factor is even
38 |         half = 0.5*(not self.scale_factor%2)
39 |         R, S = utils.shift_by(R, half), utils.shift_by(S, half)
40 |
41 |         # Q = S^* R, decimated to the LR grid
42 |         Q = S.conj() * R
43 |         q = torch.fft.ifftn(Q, dim=(-2,-1))
44 |         q_d = q[:,:,0::self.scale_factor,0::self.scale_factor]
45 |         Q_d = torch.fft.fftn(q_d, dim=(-2,-1))
46 |
47 |         # R^* R, decimated the same way
48 |         RR = R.conj() * R
49 |         rr = torch.fft.ifftn(RR, dim=(-2,-1))
50 |         rr_d = rr[:,:,0::self.scale_factor,0::self.scale_factor]
51 |         RR_d = torch.fft.fftn(rr_d, dim=(-2,-1))
52 |
53 |         # Pseudo-invert S^* R ('naive' or 'Tikhonov', see utils.dagger)
54 |         Q_d_inv = utils.dagger(Q_d, self.eps, mode=self.inv_type)
55 |
56 |         # Correction filter H = (R^* R)(S^* R)^dagger
57 |         H = RR_d * Q_d_inv
58 |         return H
59 |
60 |
61 | def est_corr(img_name, y_full, R_dag, S_conj, args, log_file=None, ref_bic=None, s_target=None):
62 |     # Crop to the area with the most high-frequency texture
63 |     crop = min(args.crop, min(y_full.shape[2], y_full.shape[3]))
64 |     topleft_x, topleft_y = utils.crop_high_freq(y_full, crop, args.device)
65 |     print('crop (x,y) = (%d, %d)' %(topleft_x, topleft_y))
66 |
67 |     # s is modeled as the convolution of four learnable kernels: three 31x31 kernels
68 |     # initialized as a delta (the scale-1 bicubic kernel sampled on an integer grid)
69 |     # and one 32x32 kernel initialized as the bicubic kernel of the SR scale factor
70 |     def init_kernel(scale, size):
71 |         k = utils.get_bicubic(scale, size).to(args.device)
72 |         return (k/k.sum()).requires_grad_(True)
73 |
74 |     s_d_1 = init_kernel(1, (31, 31))
75 |     s_d_2 = init_kernel(1, (31, 31))
76 |     s_d_3 = init_kernel(1, (31, 31))
77 |     s_d_4 = init_kernel(args.scale_factor, (32, 32))
78 |     optimizer_sd = torch.optim.Adam([s_d_1, s_d_2, s_d_3, s_d_4], lr=args.lr_s)
79 |
80 |     objective = torch.nn.L1Loss()
81 |
82 |     def compose_s():
83 |         # Composite kernel s_c: product of the normalized kernels' spectra on the LR grid
84 |         shape = y_full.shape[2:4]
85 |         S_c = (utils.fft_torch(s_d_1/s_d_1.sum(), shape) * utils.fft_torch(s_d_2/s_d_2.sum(), shape) *
86 |                utils.fft_torch(s_d_3/s_d_3.sum(), shape) * utils.fft_torch(s_d_4/s_d_4.sum(), shape))
87 |         return torch.fft.ifftn(S_c, dim=(-2,-1))
88 |
89 |     with torch.no_grad():
90 |         corr_flt = Correction_Filter(compose_s(), args.scale_factor, (y_full.shape[2]*args.scale_factor, y_full.shape[3]*args.scale_factor), inv_type='Tikhonov', eps=0)
91 |
92 |     for itr in range(args.iterations):
93 |         optimizer_sd.zero_grad()
94 |
95 |         s_c = compose_s()
96 |
97 |         # Correct the whole LR image with the current s and crop the textured region
98 |         y_full_h = torch.abs(corr_flt.correct_img_(y_full, s_c)).float()
99 |         y_h = y_full_h[:,:,topleft_y:topleft_y+crop, topleft_x:topleft_x+crop]
100 |
101 |         # Super-resolve the slightly perturbed corrected crop
102 |         x_hat = R_dag(y_h + torch.randn_like(y_h)*args.per_std)
103 |
104 |         # Re-sample the SR estimate: flipped s_d_1..s_d_3 on the HR grid, then s_d_4 with decimation
105 |         x_hat1 = utils.fft_Filter_(x_hat, utils.fft_torch(utils.flip_torch(s_d_1)/s_d_1.sum(), s = x_hat.shape[2:4]))
106 |         x_hat2 = utils.fft_Filter_(x_hat1, utils.fft_torch(utils.flip_torch(s_d_2)/s_d_2.sum(), s = x_hat1.shape[2:4]))
107 |         x_hat3 = utils.fft_Filter_(x_hat2, utils.fft_torch(utils.flip_torch(s_d_3)/s_d_3.sum(), s = x_hat2.shape[2:4]))
108 |         y_hat = torch.abs(S_conj(x_hat3, s_d_4/s_d_4.sum()))
109 |
110 |         # L1 data consistency with the observed LR crop (borders shaved)
111 |         shave = ((crop - y_hat.shape[2])//2, (crop - y_hat.shape[3])//2)
112 |         y = y_full[:,:,topleft_y+shave[0]:topleft_y+crop-shave[0], topleft_x+shave[1]:topleft_x+crop-shave[1]]
113 |         consistency = objective(y_hat[:,:,2:-2, 2:-2], y[:,:,2:-2, 2:-2])
114 |         with torch.no_grad():
115 |             x_c, y_c = utils.get_center_of_mass(torch.roll(s_c.real, (s_c.shape[2]//2, s_c.shape[3]//2), dims=(-2,-1)), args.device)
116 |
117 |         # Relaxed l_0 (sparsity) prior on s
118 |         abs_s = torch.abs(s_c)
119 |         l0 = torch.mean( abs_s[abs_s > 0]**0.5 )
120 |         loss = consistency + args.lambda_l0*l0
121 |
122 |         loss.backward()
123 |         optimizer_sd.step()
124 |
125 |         with torch.no_grad():
126 |             if(args.save_trace and np.mod(itr, 10) == 0):
127 |                 utils.save_img_torch(y_full_h, args.out_dir + 'within_loop.png')
128 |
129 |             opt_s = s_c.clone()
130 |             s_norm = s_c/s_c.sum()
131 |             out_log = img_name[:-4] + \
132 |                 '| Itr = %d' %itr + \
133 |                 '| loss = %.7f' %(loss.item()) + \
134 |                 '| x_c, y_c = %.2f/%.2f, %.2f/%.2f' %(x_c, (s_c.shape[3] - 1)/2, y_c, (s_c.shape[2] - 1)/2)
135 |             if ref_bic is not None:
136 |                 out_log += '| PSNR bic = %.5f' %(-10*torch.log10( torch.mean( (y_full_h - ref_bic)**2 ) ))
137 |             if s_target is not None:
138 |                 l_s_test = torch.sum( torch.abs(s_norm - s_target)).item()
139 |                 out_log += '| SAE(s) = %.3f' %l_s_test
140 |             if log_file is not None:
141 |                 log_file.write(out_log + '\n')
142 |             print(out_log)
143 |     return opt_s
144 |
--------------------------------------------------------------------------------
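`Correction_Filter.find_H` builds the filter on the LR grid as H = (R^*R)_d · ((S^*R)_d)^†, where the subscript d denotes decimating the spatial-domain cross-correlation by the scale factor. The 1-D NumPy sketch below reproduces that construction and checks the property that motivates it: if the HR signal lies in the range of the reconstruction operator (x = R z), correcting y = S^* x yields R^* x, i.e. the LR signal an SR network trained with `r` expects. Assumptions: periodic 1-D signals, kernels stored circularly with their centre at index 0, a Gaussian `s` and a triangle (linear-interpolation) `r` picked for simplicity, and decimation at offset 0. `find_H` also multiplies both spectra by the same half-pixel phase for even scale factors; because that phase has unit modulus it cancels in S^*R and R^*R, so the sketch leaves it out. All function names are hypothetical.

```python
import numpy as np


def circular_kernel(profile, n):
    """Sample profile(t) at signed offsets t (index 0 is t = 0, index n-1 is t = -1); unit sum."""
    t = np.arange(n)
    t = np.where(t < n // 2, t, t - n).astype(float)
    k = profile(t)
    return k / k.sum()


def correction_filter_1d(s, r, alpha, eps=0.0):
    """H = (R^*R)_d * ((S^*R)_d)^dagger on the LR grid, with the Tikhonov pseudo-inverse."""
    S, R = np.fft.fft(s), np.fft.fft(r)
    Q_d = np.fft.fft(np.fft.ifft(np.conj(S) * R)[::alpha])   # S^*R, decimated in space
    RR_d = np.fft.fft(np.fft.ifft(np.conj(R) * R)[::alpha])  # R^*R, decimated in space
    return RR_d * np.conj(Q_d) / (np.abs(Q_d) ** 2 + eps ** 2)


def sample(x, k, alpha):
    """Sampling operator K^*: correlate x with k, then keep every alpha-th sample (offset 0)."""
    return np.fft.ifft(np.conj(np.fft.fft(k)) * np.fft.fft(x))[::alpha]


def reconstruct(z, k, alpha):
    """Reconstruction operator K: zero-stuff z by alpha, then convolve with k."""
    u = np.zeros(len(z) * alpha, dtype=complex)
    u[::alpha] = z
    return np.fft.ifft(np.fft.fft(k) * np.fft.fft(u))


alpha, n = 2, 64
s = circular_kernel(lambda t: np.exp(-t ** 2 / (2 * 1.5 ** 2)), n)     # Gaussian sampling kernel
r = circular_kernel(lambda t: np.maximum(1 - np.abs(t) / alpha, 0), n)  # triangle reconstruction kernel

x = reconstruct(np.random.default_rng(0).standard_normal(n // alpha), r, alpha)  # x in range(R)
H = correction_filter_1d(s, r, alpha)
y_corrected = np.fft.ifft(H * np.fft.fft(sample(x, s, alpha)))

print(np.allclose(y_corrected, sample(x, r, alpha)))     # True: corrected S^*x equals R^*x
print(np.allclose(correction_filter_1d(r, r, alpha), 1))  # True: H is the identity when s == r
```

Setting `eps > 0` trades exactness for stability where (S^*R)_d is small, which is the role of `--eps` in the scripts.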
/downsample_imgs.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import scipy.io as io
5 | import torch
6 | import torch.fft
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 |
10 | import utils
11 |
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 | # Scale factor and Gaussian sampling kernel s (same size convention as correct_imgs.py)
15 | scale_factor = 4
16 | std = 4.5/np.sqrt(2)
17 | s = utils.get_gauss_flt(32 + scale_factor % 2, std).to(device)
18 | s = s/s.sum()
19 |
20 | in_dir = '../../SR_testing_datasets/Set14/'
21 | out_dir = './input_x%d/' %scale_factor
22 | out_GT = './GT_x%d/' %scale_factor
23 | for d in (out_GT, out_dir + 'PNG/', out_dir + 'bicubic/', out_dir + 'Filters/'):
24 |     os.makedirs(d, exist_ok=True)
25 |
26 | imgs = [f for f in os.listdir(in_dir) if f.endswith('.png')]
27 | imgs.sort()
28 |
29 | for img in imgs:
30 |     I = utils.load_img_torch(in_dir + img, device)
31 |
32 |     # Crop the GT image so that both sides are divisible by the scale factor
33 |     if I.shape[2] % scale_factor:
34 |         I = I[:,:,:-(I.shape[2]%scale_factor),:]
35 |     if I.shape[3] % scale_factor:
36 |         I = I[:,:,:,:-(I.shape[3]%scale_factor)]
37 |     utils.save_img_torch(I, out_GT + img)
38 |
39 |     # LR image: filter with s and decimate
40 |     y = utils.fft_Down_(I, s, scale_factor)
41 |
42 |     # Save as PNG (for viewing) and as '.mat' (H x W x C, key 'img') for correct_imgs.py
43 |     y_np = np.moveaxis(np.array(torch.abs(y)[0,:].cpu()), 0, 2)
44 |     utils.save_img_torch(torch.abs(y), out_dir + 'PNG/' + img[:-4] + '_Gauss_std%1.1f_x%d_s.png'%(std, scale_factor))
45 |     io.savemat(out_dir + img[:-4] + '_Gauss_std%1.1f_x%d_s.mat' %(std, scale_factor), {'img': y_np})
46 |
47 |     # Reference: PIL bicubic downsampling of the same GT image
48 |     I_PIL = transforms.ToPILImage()(I[0,:].cpu())
49 |     W, H = I_PIL.size
50 |     I_PIL_bic_down = I_PIL.resize((W//scale_factor, H//scale_factor), Image.BICUBIC)
51 |     I_PIL_bic_down.save(out_dir + 'bicubic/' + img[:-4] + '_bicubic_down_PIL.png')
52 |
53 |     # Visualize s: centred and normalized to a maximum of 1
54 |     S = utils.fft_torch(s, y.shape[2:4])
55 |     s_ = torch.roll(torch.fft.ifftn(S, dim=(-2,-1)).real, (S.shape[2]//2, S.shape[3]//2), dims=(2,3))
56 |     utils.save_img_torch(s_/s_.max(), out_dir + 'Filters/' + img[:-4] + '_Gauss_std%1.1f_x%d_s.png' %(std, scale_factor))
57 |
--------------------------------------------------------------------------------
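Before filtering, `downsample_imgs.py` trims the HR image so that its height and width are multiples of `scale_factor`; the two `if` guards are needed because a slice ending at `-0` is empty rather than full. A NumPy sketch of the same crop written as a single slice (`crop_to_multiple` is a hypothetical helper; the same indexing also works on the `(N, C, H, W)` torch tensors the script uses):

```python
import numpy as np


def crop_to_multiple(img, scale_factor):
    """Hypothetical helper: crop an (N, C, H, W) array so H and W are multiples of scale_factor."""
    h, w = img.shape[2], img.shape[3]
    # h - h % scale_factor is h itself when already divisible, so no special case is needed
    return img[:, :, :h - h % scale_factor, :w - w % scale_factor]


hr = np.zeros((1, 3, 510, 383))
print(crop_to_multiple(hr, 4).shape)  # (1, 3, 508, 380)
print(crop_to_multiple(hr, 2).shape)  # (1, 3, 510, 382)
print(np.zeros(5)[:-0].shape)         # (0,): why the script guards each crop with `if`
```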
/estimate_correction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.fft
4 | import utils
5 | import os
6 | import correction_func
7 | import scipy.io as io
8 | import argparse
9 |
10 | # Fixed seeds and deterministic cuDNN for reproducible kernel estimates
11 | torch.backends.cudnn.deterministic = True
12 | np.random.seed(0)
13 | torch.manual_seed(0)
14 | torch.cuda.manual_seed(0)
15 |
16 | torch.backends.cudnn.enabled = True
17 | torch.backends.cudnn.benchmark = False
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--scale_factor', type=int, default=2)
21 | parser.add_argument('--iterations', type=int, default=500)
22 | parser.add_argument('--lr_s', type=float, default=1e-4)
23 | parser.add_argument('--out_dir', default='./output/estimated_x2/', type=str)
24 | parser.add_argument('--in_dir', default='./input_x2/', type=str)
25 | parser.add_argument('--opt_suffix', default='', type=str)
26 | parser.add_argument('--eps', type=float, default=0)
27 | parser.add_argument('--per_std', type=float, default=0.005)
28 | parser.add_argument('--lambda_l0', type=float, default=1)
29 | parser.add_argument('--crop', type=int, default=150)
30 | parser.add_argument('--save_trace', action='store_true')
31 | parser.add_argument('--suffix', default='', type=str)
32 |
33 | args = parser.parse_args()
34 |
35 | args.gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
36 | args.device = torch.device(args.gpu)
37 |
38 | #############################################
39 |
40 | # Import the SR network here (e.g. DBPN):
41 | from dbpn_iterative import Net as DBPNITER
42 |
43 | # Define the desired SR model here (e.g. DBPN):
44 | SR_model = DBPNITER(num_channels=3, base_filter=64, feat = 256, num_stages=3, scale_factor=args.scale_factor)
45 | SR_model = torch.nn.DataParallel(SR_model, device_ids=[args.gpu], output_device=args.device)
46 | state_dict = torch.load('./models/DBPN-RES-MR64-3_%dx.pth' %args.scale_factor, map_location=args.gpu)
47 | SR_model.load_state_dict(state_dict)
48 | SR_model = SR_model.module
49 | SR_model = SR_model.eval()
50 | r = utils.get_bicubic(args.scale_factor).to(args.device)
51 | r = r/r.sum()
52 | R_dag = lambda I: SR_model(I) + args.scale_factor**2 * torch.abs(utils.fft_Up_(I, r, args.scale_factor))
53 |
54 | #############################################
55 |
56 | os.makedirs(os.path.join(args.out_dir, 'Filters'), exist_ok=True)
57 |
58 | # Inputs may be '.mat' files (key 'img', H x W x C) or PNG images
59 | imgs = [f for f in os.listdir(args.in_dir) if os.path.isfile(os.path.join(args.in_dir, f)) and f.endswith(('.mat', '.png'))]
60 | imgs.sort()
61 |
62 | # S^*: filter with the flipped kernel s (i.e. correlate with s), then decimate
63 | S_conj = lambda I, s: utils.fft_Down_(I, utils.flip_torch(s), args.scale_factor)
64 |
65 | ref_bic = None
66 |
67 | for img_in in imgs:
68 |     if img_in.endswith('.mat'):
69 |         y_full = io.loadmat(os.path.join(args.in_dir, img_in))['img']
70 |         y_full = torch.tensor(np.moveaxis(y_full, 2, 0)).unsqueeze(0).to(args.device).float()
71 |     else:
72 |         y_full = utils.load_img_torch(os.path.join(args.in_dir, img_in), args.device)
73 |     # Grayscale -> RGB; drop the alpha channel of RGBA images
74 |     if y_full.shape[1] == 1:
75 |         y_full = y_full.repeat(1,3,1,1)
76 |     if y_full.shape[1] == 4:
77 |         y_full = y_full[:,:-1,:]
78 |
79 |     img = img_in[0:-4] + '_x%d_corrected.png' %(args.scale_factor)
80 |
81 |     with open(os.path.join(args.out_dir, 'Filters', img[:-4] + args.suffix + '.txt'), 'w') as log_file:
82 |         # Estimate the anti-aliasing (sampling) kernel with the proposed algorithm
83 |         opt_s = correction_func.est_corr(img, y_full, R_dag, S_conj, args, log_file, ref_bic=ref_bic)
84 |
85 |     # Apply the correction filter built from the estimated kernel
86 |     corr_flt = correction_func.Correction_Filter(opt_s/opt_s.sum(), args.scale_factor, (y_full.shape[2]*args.scale_factor, y_full.shape[3]*args.scale_factor), inv_type='Tikhonov', eps=args.eps)
87 |     y_h = corr_flt.correct_img(y_full)
88 |
89 |     # Save the corrected LR image and the estimated kernel (centred, max-normalized)
90 |     utils.save_img_torch(y_h.real, os.path.join(args.out_dir, img[0:-4] + args.suffix + '_est.png'))
91 |     utils.save_img_torch(torch.roll(opt_s.real, ((opt_s.shape[2]//2), (opt_s.shape[3]//2)), dims=(-2,-1))/opt_s.real.max(), os.path.join(args.out_dir, 'Filters', img[0:-4] + args.suffix + '_est_s.png'))
92 |
--------------------------------------------------------------------------------
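`estimate_correction.py` delegates the kernel search to `correction_func.est_corr`, which does not optimise `s` directly: it keeps four kernels `s_d_1`, ..., `s_d_4`, normalises each to unit sum, and forms the composite `s_c` by multiplying their FFTs (`utils.fft_torch`) on the LR grid, so the composite also has unit sum. The NumPy sketch below checks that this product of spectra is the circular convolution of the kernels. Unlike the repository, which stores kernels with their centre at index 0 (`utils.build_flt`), it assumes kernels given with their centre in the middle of the array; the binomial test kernels and the helper names are arbitrary choices for illustration.

```python
import numpy as np


def centred_spectrum(k, shape):
    """Zero-pad k (centre at index (h//2, w//2)) to `shape`, move its centre to (0, 0), then FFT."""
    h, w = k.shape
    padded = np.zeros(shape)
    padded[:h, :w] = k
    padded = np.roll(padded, (-(h // 2), -(w // 2)), axis=(0, 1))
    return np.fft.fft2(padded)


def compose(kernels, shape):
    """Composite kernel: product of the unit-sum kernels' spectra, back in the spatial domain."""
    spectrum = np.ones(shape, dtype=complex)
    for k in kernels:
        spectrum *= centred_spectrum(k / k.sum(), shape)
    return np.fft.ifft2(spectrum).real  # centre at (0, 0), wrapped around


a = np.array([1.0, 2.0, 1.0])
b = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
s_c = compose([np.outer(a, a), np.outer(b, b)], (16, 16))

ab = np.convolve(a, b) / (a.sum() * b.sum())  # 1-D full convolution, length 7, centre at index 3
expected = np.outer(ab, ab)
print(np.allclose(np.fft.fftshift(s_c)[5:12, 5:12], expected))  # True
print(np.isclose(s_c.sum(), 1.0))                               # True
```

Working in the Fourier domain keeps every factor differentiable and makes the composite's DC gain exactly the product of the factors' gains.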
/figs/SR/baboon_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/baboon_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/SR/baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/baboon_Gauss_std3.2_x4_s_x4_corr_corrected.png
--------------------------------------------------------------------------------
/figs/SR/bridge_Gauss_std1.8_x2_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/bridge_Gauss_std1.8_x2_s.png
--------------------------------------------------------------------------------
/figs/SR/bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/bridge_Gauss_std1.8_x2_s_x2_corr_corrected.png
--------------------------------------------------------------------------------
/figs/SR/zebra_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/zebra_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/SR/zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/SR/zebra_Gauss_std3.2_x4_s_x4_corr_corrected.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird_.png
--------------------------------------------------------------------------------
/figs/blind_SR/bird_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/bird_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly_.png
--------------------------------------------------------------------------------
/figs/blind_SR/butterfly_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/butterfly_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip_LR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip_LR.png
--------------------------------------------------------------------------------
/figs/blind_SR/chip_x4_corrected_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/chip_x4_corrected_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_31.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_31_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_31_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_59.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_59.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_59_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_59_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_66.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_66.png
--------------------------------------------------------------------------------
/figs/blind_SR/im_66_x2_corr_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/im_66_x2_corr_est.png
--------------------------------------------------------------------------------
/figs/blind_SR/man2_Gauss_std3.2_x4_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/man2_Gauss_std3.2_x4_s.png
--------------------------------------------------------------------------------
/figs/blind_SR/man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shadyabh/Correction-Filter/9d7ca79281aa9a621a5bd2cc5e4880adfc8696c5/figs/blind_SR/man2_Gauss_std3.2_x4_s_x4_corr_l0_est.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def flip(x):
9 |     # Plain flip of the two spatial axes
10 |     return x.flip([2,3])
11 |
12 |
13 | def flip_torch(x):
14 |     # Mirror a kernel about its centre, using the storage convention of build_flt
15 |     x_ = torch.flip(torch.roll(x, ((x.shape[2]//2), (x.shape[3]//2)), dims=(2,3)), dims=(2,3))
16 |     return torch.roll(x_, (- (x_.shape[2]//2), -(x_.shape[3]//2)), dims=(2,3))
17 |
18 |
19 | def flip_np(x):
20 |     # NumPy version of flip_torch for 2-D arrays
21 |     x_ = np.flip(np.roll(x, ((x.shape[0]//2), (x.shape[1]//2)), (0,1)))
22 |     return np.roll(x_, (- (x_.shape[0]//2), -(x_.shape[1]//2)), (0,1))
23 |
24 |
25 | def shift_by(H, shift):
26 |     # Multiply the spectrum H (N, C, Ny, Nx) by a linear phase, which shifts the
27 |     # corresponding signal by `shift` samples along both spatial axes
28 |     k_x = np.linspace(0, H.shape[3]-1, H.shape[3])
29 |     k_y = np.linspace(0, H.shape[2]-1, H.shape[2])
30 |
31 |     # Signed frequency indices: the upper half corresponds to negative frequencies
32 |     k_x[((k_x.shape[0] + 1)//2):] -= H.shape[3]
33 |     k_y[((k_y.shape[0] + 1)//2):] -= H.shape[2]
34 |
35 |     exp_x, exp_y = np.meshgrid(np.exp(-1j * 2* np.pi * k_x * shift / H.shape[3]), np.exp(-1j * 2* np.pi * k_y * shift / H.shape[2]))
36 |
37 |     exp_x_torch = (torch.tensor(np.real(exp_x)) + 1j*torch.tensor(np.imag(exp_x))).unsqueeze(0).unsqueeze(0).to(H.device)
38 |     exp_y_torch = (torch.tensor(np.real(exp_y)) + 1j*torch.tensor(np.imag(exp_y))).unsqueeze(0).unsqueeze(0).to(H.device)
39 |
40 |     return H * exp_x_torch * exp_y_torch
41 |
42 |
43 | def fft_torch(x, s=None, zero_centered=True):
44 |     # 2-D FFT of x zero-padded to s = (Ny, Nx). With zero_centered=True, x is a kernel
45 |     # stored with its centre at index 0 (wrapped around), and it stays centred there after padding.
46 |     __,__,H,W = x.shape
47 |     if s is None:
48 |         s = (H, W)
49 |     if zero_centered:
50 |         x_ = torch.roll(x, ((H//2), (W//2)), dims=(2,3))
51 |     else:
52 |         x_ = x
53 |     x_pad = torch.nn.functional.pad(x_, (0, s[1] - W, 0, s[0] - H))
54 |     if zero_centered:
55 |         x_pad_ = torch.roll(x_pad, (- (H//2), -(W//2)), dims=(2,3))
56 |     else:
57 |         x_pad_ = x_pad
58 |     return torch.fft.fftn(x_pad_, dim=(-2,-1))
59 |
60 |
61 | def _cubic(t, a=-0.5):
62 |     # 1-D cubic convolution kernel (Keys), evaluated element-wise
63 |     abs_t = np.abs(t)
64 |     out = np.zeros_like(t)
65 |     inner = abs_t <= 1
66 |     outer = (abs_t > 1) & (abs_t < 2)
67 |     out[inner] = (a+2)*abs_t[inner]**3 - (a+3)*abs_t[inner]**2 + 1
68 |     out[outer] = a*abs_t[outer]**3 - 5*a*abs_t[outer]**2 + 8*a*abs_t[outer] - 4*a
69 |     return out
70 |
71 |
72 | def bicubic_ker(x, y, a=-0.5):
73 |     # Separable 2-D bicubic kernel
74 |     return _cubic(x, a)*_cubic(y, a)
75 |
76 |
77 | def build_flt(f, size):
78 |     # Sample f on a grid centred at 0 (half-integer offsets for even sizes), then roll
79 |     # so that the centre lies at index 0
80 |     is_even_x = not size[1] % 2
81 |     is_even_y = not size[0] % 2
82 |
83 |     grid_x = np.linspace(-(size[1]//2 - is_even_x*0.5), (size[1]//2 - is_even_x*0.5), size[1])
84 |     grid_y = np.linspace(-(size[0]//2 - is_even_y*0.5), (size[0]//2 - is_even_y*0.5), size[0])
85 |
86 |     x, y = np.meshgrid(grid_x, grid_y)
87 |
88 |     h = f(x, y)
89 |     h = np.roll(h, (- (h.shape[0]//2), -(h.shape[1]//2)), (0,1))
90 |
91 |     return torch.tensor(h).float().unsqueeze(0).unsqueeze(0)
92 |
93 |
94 | def get_bicubic(scale, size=None):
95 |     # Bicubic kernel stretched by `scale`; default size is 4*scale + 8 + scale%2
96 |     f = lambda x,y: bicubic_ker(x/scale, y/scale)
97 |     if size:
98 |         h = build_flt(f, (size[0], size[1]))
99 |     else:
100 |         h = build_flt(f, (4*scale + 8 + scale%2, 4*scale + 8 + scale%2))
101 |     return h
102 |
103 |
104 | def get_box(supp, size=None):
105 |     # Box kernel of support supp = (Sy, Sx), wrapped around index 0
106 |     if size is None:
107 |         size = (supp[0]*2, supp[1]*2)
108 |
109 |     h = np.zeros(size)
110 |
111 |     h[0:supp[0]//2 , 0:supp[1]//2] = 1
112 |     h[0:supp[0]//2 , -(supp[1]//2):] = 1
113 |     h[-(supp[0]//2):, 0:supp[1]//2] = 1
114 |     h[-(supp[0]//2):, -(supp[1]//2):] = 1
115 |
116 |     return torch.tensor(h).float().unsqueeze(0).unsqueeze(0)
117 |
118 |
119 | def get_delta(size):
120 |     # Discrete delta (identity kernel)
121 |     h = torch.zeros(1,1,size,size)
122 |     h[0,0,0,0] = 1
123 |     return h
124 |
125 |
126 | def get_gauss_flt(flt_size, std):
127 |     # Unnormalized isotropic Gaussian kernel of size flt_size x flt_size
128 |     f = lambda x,y: np.exp( -(x**2 + y**2)/2/std**2 )
129 |     h = build_flt(f, (flt_size,flt_size))
130 |     return h
131 |
132 |
133 | def fft_Filter_(x, A):
134 |     # Filter x with the frequency response A (same spatial size as x)
135 |     X_fft = torch.fft.fftn(x, dim=(-2,-1))
136 |     HX = A * X_fft
137 |     return torch.fft.ifftn(HX, dim=(-2,-1))
138 |
139 |
140 | def fft_Down_(x, h, alpha):
141 |     # Circularly convolve x with h, then keep every alpha-th sample starting at (alpha-1)//2
142 |     X_fft = torch.fft.fftn(x, dim=(-2,-1))
143 |     H = fft_torch(h, s=X_fft.shape[2:4])
144 |     HX = H * X_fft
145 |     margin = (alpha - 1)//2
146 |     y = torch.fft.ifftn(HX, dim=(-2,-1))[:,:,margin:HX.shape[2]-margin:alpha, margin:HX.shape[3]-margin:alpha]
147 |     return y
148 |
149 |
150 | def fft_Up_(y, h, alpha):
151 |     # Zero-stuff y by alpha (samples placed at offset alpha//2), then circularly convolve with h
152 |     x = torch.zeros(y.shape[0], y.shape[1], y.shape[2]*alpha, y.shape[3]*alpha).to(y.device)
153 |     H = fft_torch(h, s=x.shape[2:4])
154 |     start = alpha//2
155 |     x[:,:,start::alpha, start::alpha] = y
156 |     X = torch.fft.fftn(x, dim=(-2,-1))
157 |     HX = H * X
158 |     return torch.fft.ifftn(HX, dim=(-2,-1))
159 |
160 |
161 | def zero_SV(H, eps):
162 |     # In place: zero the entries of H whose squared magnitude, relative to the maximum, is <= eps^2
163 |     H_real = H.real
164 |     H_imag = H.imag
165 |     abs_H2 = H_real**2 + H_imag**2
166 |     H[abs_H2/abs_H2.max() <= eps**2] = 0
167 |     return H
168 |
169 |
170 | def dagger(X, eps=0, mode='Tikhonov'):
171 |     # Element-wise pseudo-inverse of a frequency response X:
172 |     #   'naive':    1/X where |X|^2/max|X|^2 > eps^2, 0 elsewhere
173 |     #   'Tikhonov': conj(X) / (|X|^2 + eps^2)
174 |     real = X.real
175 |     imag = X.imag
176 |     abs2 = real**2 + imag**2
177 |     if mode == 'naive':
178 |         out = X.clone()
179 |         out[abs2/abs2.max() > eps**2] = 1/X[abs2/abs2.max() > eps**2]
180 |         out[abs2/abs2.max() <= eps**2] = 0
181 |         return out
182 |     if mode == 'Tikhonov':
183 |         return X.conj()/(abs2 + eps**2)
184 |     raise ValueError("mode must be 'naive' or 'Tikhonov', got %r" % (mode,))
185 |
186 |
187 | def load_img_torch(path, device):
188 |     # Load an image as a (1, C, H, W) float tensor (values in [0, 1] for 8-bit images)
189 |     I = Image.open(path)
190 |     I = transforms.ToTensor()(I).unsqueeze(0)
191 |     return I.to(device)
192 |
193 |
194 | def save_img_torch(I, path, clamp=True):
195 |     # Save the first image of an (N, C, H, W) tensor, optionally clamped to [0, 1]
196 |     if clamp:
197 |         img = torch.clamp(I, 0, 1)[0,:].detach().cpu()
198 |     else:
199 |         img = I[0,:].detach().cpu()
200 |     img = transforms.ToPILImage()(img)
201 |     img.save(path)
202 |
203 |
204 | def get_center_of_mass(s, device):
205 |     # Centre of mass (x_c, y_c) of a 2-D kernel s, in pixel coordinates
206 |     idx_x = torch.linspace(0, s.shape[3]-1, s.shape[3]).to(device)
207 |     idx_y = torch.linspace(0, s.shape[2]-1, s.shape[2]).to(device)
208 |     i_y, i_x = torch.meshgrid(idx_y, idx_x)
209 |
210 |     x_c = torch.sum(s*i_x/s.sum())
211 |     y_c = torch.sum(s*i_y/s.sum())
212 |
213 |     return x_c, y_c
214 |
215 |
216 | def crop_high_freq(I, crop_size, device):
217 |     # Top-left corner (x, y) of the crop_size x crop_size window with the largest mean
218 |     # absolute Laplacian response of the luminance; a centred window if the image is too small
219 |     if(crop_size >= I.shape[2] or crop_size >= I.shape[3]):
220 |         argmax_x = I.shape[3]//2 - crop_size//2
221 |         argmax_y = I.shape[2]//2 - crop_size//2
222 |     else:
223 |         filt = torch.tensor([[ 0,-1, 0],
224 |                              [-1, 4,-1],
225 |                              [ 0,-1, 0]]).float().to(device)
226 |         if(I.shape[1] > 1):
227 |             I_gray = 0.2126*I[:,0:1,:,:] + 0.7152*I[:,1:2,:,:] + 0.0722*I[:,2:3,:,:]
228 |         else:
229 |             I_gray = I
230 |         D = torch.abs(torch.nn.functional.conv2d(I_gray, filt.unsqueeze(0).unsqueeze(0)))
231 |         Avg = torch.nn.functional.avg_pool2d(D, crop_size, stride=1, padding=0, ceil_mode=False)
232 |         argmax = Avg.argmax()
233 |         argmax_y = argmax//Avg.shape[3]
234 |         argmax_x = argmax % Avg.shape[3]
235 |     return argmax_x, argmax_y
--------------------------------------------------------------------------------
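`utils.flip_torch` mirrors a kernel about its own centre under the storage convention of `utils.build_flt`. For odd sizes the centre sits at index 0, so the result is `x[-n mod N]`; for even sizes `build_flt` samples on a half-integer grid, the centre falls between indices N-1 and 0, and the roll-flip-roll reduces to a plain reversal. A NumPy re-implementation (the name `flip_centred` is hypothetical) that makes both cases visible:

```python
import numpy as np


def flip_centred(x):
    """NumPy mirror of utils.flip_torch on the last two axes: roll by N//2, flip, roll back."""
    h, w = x.shape[-2], x.shape[-1]
    x_ = np.flip(np.roll(x, (h // 2, w // 2), axis=(-2, -1)), axis=(-2, -1))
    return np.roll(x_, (-(h // 2), -(w // 2)), axis=(-2, -1))


odd = np.arange(5.0)
even = np.arange(4.0)
# 1-D views: odd length mirrors about index 0, even length about index -0.5
print(flip_centred(odd[None, :])[0])   # [0. 4. 3. 2. 1.]
print(flip_centred(even[None, :])[0])  # [3. 2. 1. 0.]
```

Because the operation is separable, a 3 x 4 kernel is mirrored about index 0 along its rows axis and simply reversed along its columns axis.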