├── demo
    ├── RMSE_SR.txt
    ├── RMSE_bicubic.txt
    ├── 123.png
    ├── 2SR.png
    ├── 3HR.png
    └── 1bicubic.png
├── data
    └── dicts
    │   ├── Dh_512_US3_L0.1_PS5.pkl
    │   ├── Dl_512_US3_L0.1_PS5.pkl
    │   ├── Dh_1024_US3_L0.1_PS5.pkl
    │   ├── Dh_2048_US3_L0.1_PS3.pkl
    │   ├── Dh_2048_US3_L0.1_PS5.pkl
    │   ├── Dl_1024_US3_L0.1_PS5.pkl
    │   ├── Dl_2048_US3_L0.1_PS3.pkl
    │   └── Dl_2048_US3_L0.1_PS5.pkl
├── patch_pruning.py
├── backprojection.py
├── rescale.py
├── rnd_smp_patch.py
├── README.md
├── dict_train.py
├── featuresign.py
├── sample_patches.py
├── ScSR.py
└── run.py


/demo/RMSE_SR.txt:
--------------------------------------------------------------------------------
1 | 6.653080195131090058e+00
2 | 
--------------------------------------------------------------------------------
/demo/RMSE_bicubic.txt:
--------------------------------------------------------------------------------
1 | 8.342828868563501032e+00
2 | 
--------------------------------------------------------------------------------
/demo/123.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/demo/123.png
--------------------------------------------------------------------------------
/demo/2SR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/demo/2SR.png
--------------------------------------------------------------------------------
/demo/3HR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/demo/3HR.png
--------------------------------------------------------------------------------
/demo/1bicubic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/demo/1bicubic.png
--------------------------------------------------------------------------------
/data/dicts/Dh_512_US3_L0.1_PS5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dh_512_US3_L0.1_PS5.pkl
--------------------------------------------------------------------------------
/data/dicts/Dl_512_US3_L0.1_PS5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dl_512_US3_L0.1_PS5.pkl
--------------------------------------------------------------------------------
/data/dicts/Dh_1024_US3_L0.1_PS5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dh_1024_US3_L0.1_PS5.pkl
--------------------------------------------------------------------------------
/data/dicts/Dh_2048_US3_L0.1_PS3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dh_2048_US3_L0.1_PS3.pkl
--------------------------------------------------------------------------------
/data/dicts/Dh_2048_US3_L0.1_PS5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dh_2048_US3_L0.1_PS5.pkl
--------------------------------------------------------------------------------
/data/dicts/Dl_1024_US3_L0.1_PS5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dl_1024_US3_L0.1_PS5.pkl
--------------------------------------------------------------------------------
/data/dicts/Dl_2048_US3_L0.1_PS3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dl_2048_US3_L0.1_PS3.pkl
--------------------------------------------------------------------------------
/data/dicts/Dl_2048_US3_L0.1_PS5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrunoVox/ScSR/HEAD/data/dicts/Dl_2048_US3_L0.1_PS5.pkl -------------------------------------------------------------------------------- /patch_pruning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def patch_pruning(Xh, Xl): 4 | pvars = np.var(Xh, axis=0) 5 | threshold = np.percentile(pvars, 10) 6 | idx = pvars > threshold 7 | # print(pvars) 8 | Xh = Xh[:, idx] 9 | Xl = Xl[:, idx] 10 | return Xh, Xl -------------------------------------------------------------------------------- /backprojection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import resize 3 | from scipy.signal import convolve2d 4 | import matplotlib.pyplot as plt 5 | 6 | def gauss2D(shape,sigma): 7 | """ 8 | 2D gaussian mask - should give the same result as MATLAB's 9 | fspecial('gaussian',[shape],[sigma]) 10 | """ 11 | m,n = [(ss-1.)/2. 
for ss in shape] 12 | y,x = np.ogrid[-m:m+1,-n:n+1] 13 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 14 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 15 | sumh = h.sum() 16 | if sumh != 0: 17 | h /= sumh 18 | return h 19 | 20 | def backprojection(img_hr, img_lr, maxIter): 21 | p = gauss2D((5, 5), 1) 22 | p = np.multiply(p, p) 23 | p = np.divide(p, np.sum(p)) 24 | 25 | for i in range(maxIter): 26 | img_lr_ds = resize(img_hr, img_lr.shape, anti_aliasing=1) 27 | img_diff = img_lr - img_lr_ds 28 | 29 | img_diff = resize(img_diff, img_hr.shape) 30 | img_hr += convolve2d(img_diff, p, 'same') 31 | return img_hr -------------------------------------------------------------------------------- /rescale.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import rescale 3 | from skimage.io import imread, imsave 4 | from os import listdir 5 | from tqdm import tqdm 6 | 7 | # Set train and val HR and LR paths 8 | train_hr_path = 'data/train_hr/' 9 | train_lr_path = 'data/train_lr/' 10 | val_hr_path = 'data/val_hr/' 11 | val_lr_path = 'data/val_lr/' 12 | 13 | numberOfImagesTrainHR = len(listdir(train_hr_path)) 14 | 15 | for i in tqdm(range(numberOfImagesTrainHR)): 16 | img_name = listdir(train_hr_path)[i] 17 | img = imread('{}{}'.format(train_hr_path, img_name)) 18 | new_img = rescale(img, (1/3), anti_aliasing=1) 19 | imsave('{}{}'.format(train_lr_path, img_name), new_img, quality=100) 20 | 21 | numberOfImagesValHR = len(listdir(val_hr_path)) 22 | 23 | for i in tqdm(range(numberOfImagesValHR)): 24 | img_name = listdir(val_hr_path)[i] 25 | img = imread('{}{}'.format(val_hr_path, img_name)) 26 | new_img = rescale(img, (1/3), anti_aliasing=1) 27 | imsave('{}{}'.format(val_lr_path, img_name), new_img, quality=100) -------------------------------------------------------------------------------- /rnd_smp_patch.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | from os import listdir 3 | from skimage.io import imread 4 | from sample_patches import sample_patches 5 | from tqdm import tqdm 6 | 7 | def rnd_smp_patch(img_path, patch_size, num_patch, upscale): 8 | img_dir = listdir(img_path) 9 | 10 | img_num = len(img_dir) 11 | nper_img = np.zeros((img_num, 1)) 12 | 13 | for i in tqdm(range(img_num)): 14 | img = imread('{}{}'.format(img_path, img_dir[i])) 15 | nper_img[i] = img.shape[0] * img.shape[1] 16 | 17 | nper_img = np.floor(nper_img * num_patch / np.sum(nper_img, axis=0)) 18 | 19 | for i in tqdm(range(img_num)): 20 | patch_num = int(nper_img[i]) 21 | img = imread('{}{}'.format(img_path, img_dir[i])) 22 | H, L = sample_patches(img, patch_size, patch_num, upscale) 23 | if i == 0: 24 | Xh = H 25 | Xl = L 26 | else: 27 | Xh = np.concatenate((Xh, H), axis=1) 28 | Xl = np.concatenate((Xl, L), axis=1) 29 | # print(Xh.shape) 30 | # patch_path = 31 | return Xh, Xl -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ScSR 2 | A Python implementation of the "Single Image Super-Resolution via Sparse Representation" paper. Done for educational purposes. 3 | 4 | # Requirements 5 | Sklearn 6 | 7 | Skimage 8 | 9 | spams 10 | 11 | tqdm 12 | 13 | # Usage 14 | Copy a dataset to "train_hr" folder. This step is not needed if you don't intend to learn your own dictionaries. 15 | 16 | Place some validation images in "val_hr" folder. 17 | 18 | Run rescale.py to create lower resolution images of Train and Val images. 19 | 20 | Open run.py and modify dictionary and parameter vars, if you want to. 21 | 22 | Execute run.py and check results in "data/results/". 23 | 24 | # Initial Results 25 | ![Bicubic interpolation; Super-Resolution; Original](/demo/123.png) 26 | 27 | Some optimizations for performance and result improvement are still required, but the code runs just fine in this initial state. 
28 | 29 | This is a Python adaptation of the mentioned paper and the author's site (http://www.ifp.illinois.edu/~jyang29/) contains the original Matlab code, which made the work easier. 30 | 31 | This implementation will be improved in the future. 32 | -------------------------------------------------------------------------------- /dict_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rnd_smp_patch import rnd_smp_patch 3 | from patch_pruning import patch_pruning 4 | from spams import trainDL 5 | import pickle 6 | 7 | # ======================================================================== 8 | # Demo codes for dictionary training by joint sparse coding 9 | # 10 | # Reference 11 | # J. Yang et al. Image super-resolution as sparse representation of raw 12 | # image patches. CVPR 2008. 13 | # J. Yang et al. Image super-resolution via sparse representation. IEEE 14 | # Transactions on Image Processing, Vol 19, Issue 11, pp2861-2873, 2010 15 | # 16 | # Jianchao Yang 17 | # ECE Department, University of Illinois at Urbana-Champaign 18 | # For any questions, send email to jyang29@uiuc.edu 19 | # ========================================================================= 20 | 21 | dict_size = 2048 # dictionary size 22 | lmbd = 0.1 # sparsity regularization 23 | patch_size = 3 # image patch size 24 | nSmp = 100000 # number of patches to sample 25 | upscale = 3 # upscaling factor 26 | 27 | train_img_path = 'data/train_hr/' # Set your training images dir 28 | 29 | # Randomly sample image patches 30 | Xh, Xl = rnd_smp_patch(train_img_path, patch_size, nSmp, upscale) 31 | 32 | # Prune patches with small variances 33 | Xh, Xl = patch_pruning(Xh, Xl) 34 | Xh = np.asfortranarray(Xh) 35 | Xl = np.asfortranarray(Xl) 36 | 37 | # Dictionary learning 38 | Dh = trainDL(Xh, K=dict_size, lambda1=lmbd, iter=100) 39 | Dl = trainDL(Xl, K=dict_size, lambda1=lmbd, iter=100) 40 | 41 | # Saving dictionaries to files 42 | 
with open('data/dicts/'+ 'Dh_' + str(dict_size) + '_US' + str(upscale) + '_L' + str(lmbd) + '_PS' + str(patch_size) + '.pkl', 'wb') as f: 43 | pickle.dump(Dh, f, pickle.HIGHEST_PROTOCOL) 44 | 45 | with open('data/dicts/'+ 'Dl_' + str(dict_size) + '_US' + str(upscale) + '_L' + str(lmbd) + '_PS' + str(patch_size) + '.pkl', 'wb') as f: 46 | pickle.dump(Dl, f, pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /featuresign.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | def fss_yang(lmbd, A, b): 5 | 6 | """ 7 | L1QP_FeatureSign solves nonnegative quadradic programming 8 | using Feature Sign. 9 | 10 | min 0.5*x'*A*x+b'*x+\lambda*|x| 11 | 12 | [net,control]=NNQP_FeatureSign(net,A,b,control) 13 | """ 14 | 15 | EPS = 1e-9 16 | x = np.zeros((A.shape[1], 1)) 17 | # print('X =', x.shape) 18 | grad = np.dot(A, x) + b 19 | # print('GRAD =', grad.shape) 20 | ma = np.amax(np.multiply(abs(grad), np.isin(x, 0).T), axis=0) 21 | mi = np.zeros(grad.shape[1]) 22 | for j in range(grad.shape[1]): 23 | for i in range(grad.shape[0]): 24 | if grad[i, j] == ma[j]: 25 | mi[j] = i 26 | break 27 | mi = mi.astype(int) 28 | # print(grad[mi]) 29 | while True: 30 | 31 | if np.all(grad[mi]) > lmbd + EPS: 32 | x[mi] = (lmbd - grad[mi]) / A[mi, mi] 33 | elif np.all(grad[mi]) < - lmbd - EPS: 34 | x[mi] = (- lmbd - grad[mi]) / A[mi, mi] 35 | else: 36 | if np.all(x == 0): 37 | break 38 | 39 | while True: 40 | 41 | a = np.where(x != 0) 42 | Aa = A[a, a] 43 | ba = b[a] 44 | xa = x[a] 45 | 46 | vect = -lmbd * np.sign(xa) - ba 47 | x_new = np.linalg.lstsq(Aa, vect) 48 | idx = np.where(x_new != 0) 49 | o_new = np.dot((vect[idx] / 2 + ba[idx]).T, x_new[idx]) + lmbd * np.sum(abs(x_new[idx])) 50 | 51 | s = np.where(np.multiply(xa, x_new) < 0) 52 | if np.all(s == 0): 53 | x[a] = x_new 54 | loss = o_new 55 | break 56 | x_min = x_new 57 | o_min = o_new 58 | d = 
x_new - xa 59 | t = np.divide(d, xa) 60 | for zd in s.T: 61 | x_s = xa - d / t[zd] 62 | x_s[zd] = 0 63 | idx = np.where(x_s == 0) 64 | o_s = np.dot((np.dot(Aa[idx, idx], x_s[idx]) / 2 + ba[idx]).T, x_s[idx] + lmbd * np.sum(abs(x_s[idx]))) 65 | if o_s < o_min: 66 | x_min = x_s 67 | o_min = o_s 68 | 69 | x[a] = x_min 70 | loss = o_min 71 | 72 | grad = np.dot(A, sparse.csc_matrix(x)) + b 73 | 74 | ma, mi = max(np.multiply(abs(grad), np.where(x == 0))) 75 | if ma <= lmbd + EPS: 76 | break 77 | 78 | return x 79 | 80 | -------------------------------------------------------------------------------- /sample_patches.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.color import rgb2gray 3 | from skimage.transform import resize, rescale 4 | from scipy.signal import convolve2d 5 | from tqdm import tqdm 6 | 7 | def sample_patches(img, patch_size, patch_num, upscale): 8 | if img.shape[2] == 3: 9 | hIm = rgb2gray(img) 10 | else: 11 | hIm = img 12 | 13 | # Generate low resolution counter parts 14 | lIm = rescale(hIm, 1 / upscale) 15 | lIm = resize(lIm, hIm.shape) 16 | nrow, ncol = hIm.shape 17 | 18 | x = np.random.permutation(range(nrow - 2 * patch_size)) + patch_size 19 | y = np.random.permutation(range(ncol - 2 * patch_size)) + patch_size 20 | 21 | X, Y = np.meshgrid(x, y) 22 | xrow = np.ravel(X, order='F') 23 | ycol = np.ravel(Y, order='F') 24 | 25 | if patch_num < len(xrow): 26 | xrow = xrow[0 : patch_num] 27 | ycol = ycol[0 : patch_num] 28 | 29 | patch_num = len(xrow) 30 | 31 | H = np.zeros((patch_size ** 2, len(xrow))) 32 | L = np.zeros((4 * patch_size ** 2, len(xrow))) 33 | 34 | # Compute the first and second order gradients 35 | hf1 = [[-1, 0, 1], ] * 3 36 | vf1 = np.transpose(hf1) 37 | 38 | lImG11 = convolve2d(lIm, hf1, 'same') 39 | lImG12 = convolve2d(lIm, vf1, 'same') 40 | 41 | hf2 = [[1, 0, -2, 0, 1], ] * 3 42 | vf2 = np.transpose(hf2) 43 | 44 | lImG21 = convolve2d(lIm, hf2, 'same') 45 | 
lImG22 = convolve2d(lIm, vf2, 'same') 46 | 47 | for i in tqdm(range(patch_num)): 48 | row = xrow[i] 49 | col = ycol[i] 50 | 51 | Hpatch = np.ravel(hIm[row : row + patch_size, col : col + patch_size], order='F') 52 | # Hpatch = np.reshape(Hpatch, (Hpatch.shape[0], 1)) 53 | 54 | Lpatch1 = np.ravel(lImG11[row : row + patch_size, col : col + patch_size], order='F') 55 | Lpatch1 = np.reshape(Lpatch1, (Lpatch1.shape[0], 1)) 56 | Lpatch2 = np.ravel(lImG12[row : row + patch_size, col : col + patch_size], order='F') 57 | Lpatch2 = np.reshape(Lpatch2, (Lpatch2.shape[0], 1)) 58 | Lpatch3 = np.ravel(lImG21[row : row + patch_size, col : col + patch_size], order='F') 59 | Lpatch3 = np.reshape(Lpatch3, (Lpatch3.shape[0], 1)) 60 | Lpatch4 = np.ravel(lImG22[row : row + patch_size, col : col + patch_size], order='F') 61 | Lpatch4 = np.reshape(Lpatch4, (Lpatch4.shape[0], 1)) 62 | 63 | Lpatch = np.concatenate((Lpatch1, Lpatch2, Lpatch3, Lpatch4), axis=1) 64 | Lpatch = np.ravel(Lpatch, order='F') 65 | 66 | if i == 0: 67 | HP = np.zeros((Hpatch.shape[0], 1)) 68 | LP = np.zeros((Lpatch.shape[0], 1)) 69 | # print(HP.shape) 70 | HP[:, i] = Hpatch - np.mean(Hpatch) 71 | LP[:, i] = Lpatch 72 | else: 73 | HP_temp = Hpatch - np.mean(Hpatch) 74 | HP_temp = np.reshape(HP_temp, (HP_temp.shape[0], 1)) 75 | HP = np.concatenate((HP, HP_temp), axis=1) 76 | LP_temp = Lpatch 77 | LP_temp = np.reshape(LP_temp, (LP_temp.shape[0], 1)) 78 | LP = np.concatenate((LP, LP_temp), axis=1) 79 | 80 | return HP, LP -------------------------------------------------------------------------------- /ScSR.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from sklearn.preprocessing import normalize 4 | from skimage.io import imread 5 | from skimage.color import rgb2ycbcr 6 | from skimage.transform import resize 7 | import pickle 8 | from featuresign import fss_yang 9 | from scipy.signal import convolve2d 10 | from tqdm import tqdm 11 
| 12 | def extract_lr_feat(img_lr): 13 | h, w = img_lr.shape 14 | img_lr_feat = np.zeros((h, w, 4)) 15 | 16 | # First order gradient filters 17 | hf1 = [[-1, 0, 1], ] * 3 18 | vf1 = np.transpose(hf1) 19 | 20 | img_lr_feat[:, :, 0] = convolve2d(img_lr, hf1, 'same') 21 | img_lr_feat[:, :, 1] = convolve2d(img_lr, vf1, 'same') 22 | 23 | # Second order gradient filters 24 | hf2 = [[1, 0, -2, 0, 1], ] * 3 25 | vf2 = np.transpose(hf2) 26 | 27 | img_lr_feat[:, :, 2] = convolve2d(img_lr, hf2, 'same') 28 | img_lr_feat[:, :, 3] = convolve2d(img_lr, vf2, 'same') 29 | 30 | return img_lr_feat 31 | 32 | def create_list_step(start, stop, step): 33 | list_step = [] 34 | for i in range(start, stop, step): 35 | list_step = np.append(list_step, i) 36 | return list_step 37 | 38 | def lin_scale(xh, us_norm): 39 | hr_norm = np.sqrt(np.sum(np.multiply(xh, xh))) 40 | 41 | if hr_norm > 0: 42 | s = us_norm * 1.2 / hr_norm 43 | xh = np.multiply(xh, s) 44 | return xh 45 | 46 | def ScSR(img_lr_y, size, upscale, Dh, Dl, lmbd, overlap): 47 | 48 | patch_size = 3 49 | 50 | img_us = resize(img_lr_y, size) 51 | img_us_height, img_us_width = img_us.shape 52 | img_hr = np.zeros(img_us.shape) 53 | cnt_matrix = np.zeros(img_us.shape) 54 | 55 | img_lr_y_feat = extract_lr_feat(img_hr) 56 | 57 | gridx = np.append(create_list_step(0, img_us_width - patch_size - 1, patch_size - overlap), img_us_width - patch_size - 1) 58 | gridy = np.append(create_list_step(0, img_us_height - patch_size - 1, patch_size - overlap), img_us_height - patch_size - 1) 59 | 60 | count = 0 61 | 62 | for m in tqdm(range(0, len(gridx))): 63 | for n in range(0, len(gridy)): 64 | count += 1 65 | xx = int(gridx[m]) 66 | yy = int(gridy[n]) 67 | 68 | us_patch = img_us[yy : yy + patch_size, xx : xx + patch_size] 69 | us_mean = np.mean(np.ravel(us_patch, order='F')) 70 | us_patch = np.ravel(us_patch, order='F') - us_mean 71 | us_norm = np.sqrt(np.sum(np.multiply(us_patch, us_patch))) 72 | 73 | feat_patch = img_lr_y_feat[yy : yy + patch_size, 
xx : xx + patch_size, :] 74 | feat_patch = np.ravel(feat_patch, order='F') 75 | feat_norm = np.sqrt(np.sum(np.multiply(feat_patch, feat_patch))) 76 | 77 | if feat_norm > 1: 78 | y = np.divide(feat_patch, feat_norm) 79 | else: 80 | y = feat_patch 81 | 82 | b = np.dot(np.multiply(Dl.T, -1), y) 83 | w = fss_yang(lmbd, Dl, b) 84 | 85 | hr_patch = np.dot(Dh, w) 86 | hr_patch = lin_scale(hr_patch, us_norm) 87 | 88 | hr_patch = np.reshape(hr_patch, (patch_size, -1)) 89 | hr_patch += us_mean 90 | 91 | img_hr[yy : yy + patch_size, xx : xx + patch_size] += hr_patch 92 | cnt_matrix[yy : yy + patch_size, xx : xx + patch_size] += 1 93 | 94 | index = np.where(cnt_matrix < 1)[0] 95 | img_hr[index] = img_us[index] 96 | 97 | cnt_matrix[index] = 1 98 | img_hr = np.divide(img_hr, cnt_matrix) 99 | 100 | return img_hr -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir, mkdir 3 | from os.path import isdir 4 | from skimage.io import imread, imsave 5 | from skimage.color import rgb2ycbcr, ycbcr2rgb 6 | from skimage.transform import resize 7 | from scipy.misc import imresize 8 | from tqdm import tqdm 9 | import pickle 10 | from ScSR import ScSR 11 | from backprojection import backprojection 12 | from sklearn.metrics import mean_squared_error 13 | from sklearn.preprocessing import normalize 14 | 15 | def normalize_signal(img, channel): 16 | if np.mean(img[:, :, channel]) * 255 > np.mean(img_lr_ori[:, :, channel]): 17 | ratio = np.mean(img_lr_ori[:, :, channel]) / (np.mean(img[:, :, channel]) * 255) 18 | img[:, :, channel] = np.multiply(img[:, :, channel], ratio) 19 | elif np.mean(img[:, :, channel]) * 255 < np.mean(img_lr_ori[:, :, channel]): 20 | ratio = np.mean(img_lr_ori[:, :, channel]) / (np.mean(img[:, :, channel]) * 255) 21 | img[:, :, channel] = np.multiply(img[:, :, channel], ratio) 22 | return img[:, :, channel] 23 
| 24 | def normalize_max(img): 25 | for m in range(img.shape[0]): 26 | for n in range(img.shape[1]): 27 | if img[m, n, 0] > 1: 28 | img[m, n, 0] = 1 29 | if img[m, n, 1] > 1: 30 | img[m, n, 1] = 1 31 | if img[m, n, 2] > 1: 32 | img[m, n, 2] = 1 33 | return img 34 | 35 | # Set which dictionary you want to use 36 | D_size = 2048 37 | US_mag = 3 38 | lmbd = 0.1 39 | patch_size= 3 40 | 41 | dict_name = str(D_size) + '_US' + str(US_mag) + '_L' + str(lmbd) + '_PS' + str(patch_size) 42 | 43 | with open('data/dicts/Dh_' + dict_name + '.pkl', 'rb') as f: 44 | Dh = pickle.load(f) 45 | Dh = normalize(Dh) 46 | with open('data/dicts/Dl_' + dict_name + '.pkl', 'rb') as f: 47 | Dl = pickle.load(f) 48 | Dl = normalize(Dl) 49 | 50 | ### SET PARAMETERS 51 | img_lr_dir = 'data/val_lr/' 52 | img_hr_dir = 'data/val_hr/' 53 | overlap = 1 54 | lmbd = 0.1 55 | upscale = 3 56 | maxIter = 100 57 | 58 | ### 59 | 60 | img_lr_file = listdir(img_lr_dir) 61 | 62 | for i in tqdm(range(len(img_lr_file))): 63 | # Read test image 64 | img_name = img_lr_file[i] 65 | img_name_dir = list(img_name) 66 | img_name_dir = np.delete(np.delete(np.delete(np.delete(img_name_dir, -1), -1), -1), -1) 67 | img_name_dir = ''.join(img_name_dir) 68 | if isdir('data/results/' + dict_name + '_' + img_name_dir) == False: 69 | new_dir = mkdir('{}{}'.format('data/results/' + dict_name + '_', img_name_dir)) 70 | img_lr = imread('{}{}'.format(img_lr_dir, img_name)) 71 | 72 | # Read and save ground truth image 73 | img_hr = imread('{}{}'.format(img_hr_dir, img_name)) 74 | imsave('{}{}{}{}'.format('data/results/' + dict_name + '_', img_name_dir, '/', '3HR.png'), img_hr, quality=100) 75 | img_hr_y = rgb2ycbcr(img_hr)[:, :, 0] 76 | 77 | # Change color space 78 | img_lr_ori = img_lr 79 | temp = img_lr 80 | img_lr = rgb2ycbcr(img_lr) 81 | img_lr_y = img_lr[:, :, 0] 82 | img_lr_cb = img_lr[:, :, 1] 83 | img_lr_cr = img_lr[:, :, 2] 84 | 85 | # Upscale chrominance to color SR images 86 | img_sr_cb = resize(img_lr_cb, 
(img_hr.shape[0], img_hr.shape[1]), order=0) 87 | img_sr_cr = resize(img_lr_cr, (img_hr.shape[0], img_hr.shape[1]), order=0) 88 | 89 | # Super Resolution via Sparse Representation 90 | img_sr_y = ScSR(img_lr_y, img_hr_y.shape, upscale, Dh, Dl, lmbd, overlap) 91 | img_sr_y = backprojection(img_sr_y, img_lr_y, maxIter) 92 | 93 | # Create colored SR images 94 | img_sr = np.stack((img_sr_y, img_sr_cb, img_sr_cr), axis=2) 95 | img_sr = ycbcr2rgb(img_sr) 96 | 97 | # Signal normalization 98 | for channel in range(len(img_sr.shape[2])): 99 | img_sr[:, :, channel] = normalize_signal(img_sr, channel) 100 | 101 | # Maximum pixel intensity normalization 102 | img_sr = normalize_max(img_sr) 103 | 104 | # Bicubic interpolation for reference 105 | img_bc = resize(img_lr_ori, (img_hr.shape[0], img_hr.shape[1])) 106 | imsave('{}{}{}{}'.format('data/results/' + dict_name + '_', img_name_dir, '/', '1bicubic.png'), img_bc, quality=100) 107 | img_bc_y = rgb2ycbcr(img_bc)[:, :, 0] 108 | 109 | # Compute RMSE for the illuminance 110 | rmse_bc_hr = np.sqrt(mean_squared_error(img_hr_y, img_bc_y)) 111 | rmse_bc_hr = np.zeros((1,)) + rmse_bc_hr 112 | rmse_sr_hr = np.sqrt(mean_squared_error(img_hr_y, img_sr_y)) 113 | rmse_sr_hr = np.zeros((1,)) + rmse_sr_hr 114 | np.savetxt('{}{}{}{}'.format('data/results/' + dict_name + '_', img_name_dir, '/', 'RMSE_bicubic.txt'), rmse_bc_hr) 115 | np.savetxt('{}{}{}{}'.format('data/results/' + dict_name + '_', img_name_dir, '/', 'RMSE_SR.txt'), rmse_sr_hr) 116 | 117 | imsave('{}{}{}{}'.format('data/results/' + dict_name + '_', img_name_dir, '/', '2SR.png'), img_sr, quality=100) --------------------------------------------------------------------------------