├── .gitattributes
├── IQA_pytorch
│   ├── CW_SSIM.py
│   ├── DISTS.py
│   ├── FSIM.py
│   ├── GMSD.py
│   ├── LPIPSvgg.py
│   ├── MAD.py
│   ├── MS_SSIM.py
│   ├── NLPD.py
│   ├── SSIM.py
│   ├── SteerPyrComplex.py
│   ├── SteerPyrSpace.py
│   ├── SteerPyrUtils.py
│   ├── VIF.py
│   ├── VIFs.py
│   ├── VSI.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CW_SSIM.cpython-37.pyc
│   │   ├── DISTS.cpython-37.pyc
│   │   ├── MAD.cpython-37.pyc
│   │   ├── SteerPyrComplex.cpython-37.pyc
│   │   ├── SteerPyrSpace.cpython-37.pyc
│   │   ├── SteerPyrUtils.cpython-37.pyc
│   │   ├── VIF.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── images
│   │   ├── r0.png
│   │   └── r1.png
│   ├── utils.py
│   └── weights
│       ├── DISTS.pt
│       └── LPIPSvgg.pt
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   └── recover.py
├── images
│   ├── diagram.svg
│   ├── r0.png
│   └── r1.png
├── requirements.txt
└── setup.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/IQA_pytorch/CW_SSIM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from torchvision import transforms
6 | from .utils import fspecial_gauss
7 | from .SteerPyrComplex import SteerablePyramid
8 | import math
9 | 
10 | 
11 | class CW_SSIM(torch.nn.Module):
12 |     '''
13 |     This is a PyTorch implementation of the Complex-Wavelet
14 |     Structural SIMilarity (CW-SSIM) index.
15 | 
16 |     M. P. Sampat, Z. Wang, S. Gupta, A. C. Bovik, M. K. Markey.
17 |     "Complex Wavelet Structural Similarity: A New Image Similarity Index",
18 |     IEEE Transactions on Image Processing, 18(11), 2385-401, 2009.
19 | 
20 |     Matlab version:
21 |     https://www.mathworks.com/matlabcentral/fileexchange/43017-complex-wavelet-structural-similarity-index-cw-ssim
22 |     '''
23 |     def __init__(self, imgSize=[256,256], channels=3, level=4, ori=8, device = torch.device("cuda")):
24 |         assert imgSize[0]==imgSize[1]
25 |         super(CW_SSIM, self).__init__()
26 |         self.ori = ori
27 |         self.level = level
28 |         self.channels = channels
29 |         self.win7 = (torch.ones(channels,1,7,7)/(7*7)).to(device)
30 |         s = imgSize[0]/2**(level-1)
31 |         self.w = fspecial_gauss(s-7+1, s/4, 1).to(device)
32 |         self.SP = SteerablePyramid(imgSize=imgSize, K=ori, N=level, hilb=True,device=device)
33 | 
34 |     def abs(self, x):
35 |         return torch.sqrt(x[:,0,...]**2+x[:,1,...]**2+1e-12)
36 | 
37 |     def conj(self, x, y):
38 |         a = x[:,0,...]
39 |         b = x[:,1,...]
40 |         c = y[:,0,...]
41 |         d = -y[:,1,...]
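  |         # With x = a + i*b and d = -Im(y), the return below computes the
  |         # complex product x * conj(y): real part a*c - b*d, imaginary part
  |         # b*c + a*d, stacked back into the (real, imag) channel layout.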
42 |         return torch.stack((a*c-b*d,b*c+a*d),dim=1)
43 | 
44 |     def conv2d_complex(self, x, win, groups = 1):
45 |         real = F.conv2d(x[:,0,...], win, groups = groups)# - F.conv2d(x[:,1], win, groups = groups)
46 |         imaginary = F.conv2d(x[:,1,...], win, groups = groups)# + F.conv2d(x[:,0], win, groups = groups)
47 |         return torch.stack((real,imaginary),dim=1)
48 | 
49 |     def cw_ssim(self, x, y):
50 |         cw_x = self.SP(x)
51 |         cw_y = self.SP(y)
52 |         bandind = self.level
53 |         band_cssim = []
54 |         for i in range(self.ori):
55 | 
56 |             band1 = cw_x[bandind][:,:,:,i,:,:]
57 |             band2 = cw_y[bandind][:,:,:,i,:,:]
58 |             corr = self.conj(band1,band2)
59 |             corr_band = self.conv2d_complex(corr, self.win7, groups = self.channels)
60 |             varr = (self.abs(band1))**2+(self.abs(band2))**2
61 |             varr_band = F.conv2d(varr, self.win7, stride=1, padding=0, groups = self.channels)
62 |             cssim_map = (2*self.abs(corr_band) + 1e-12)/(varr_band + 1e-12)
63 |             band_cssim.append((cssim_map*self.w.repeat(cssim_map.shape[0],1,1,1)).sum([2,3]).mean(1))
64 | 
65 |         return torch.stack(band_cssim,dim=1).mean(1)
66 | 
67 |     def forward(self, x, y, as_loss=True):
68 |         assert x.shape == y.shape
69 |         x = x * 255
70 |         y = y * 255
71 |         if as_loss:
72 |             score = self.cw_ssim(x, y)
73 |             return 1 - score.mean()
74 |         else:
75 |             with torch.no_grad():
76 |                 score = self.cw_ssim(x, y)
77 |             return score
78 | 
79 | if __name__ == '__main__':
80 |     from PIL import Image
81 |     import argparse
82 |     from utils import prepare_image
83 | 
84 |     parser = argparse.ArgumentParser()
85 |     parser.add_argument('--ref', type=str, default='images/r0.png')
86 |     parser.add_argument('--dist', type=str, default='images/r1.png')
87 |     args = parser.parse_args()
88 | 
89 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90 | 
91 |     ref = prepare_image(Image.open(args.ref).convert("L"),repeatNum=1).to(device)
92 |     dist = prepare_image(Image.open(args.dist).convert("L"),repeatNum=1).to(device)
93 |     dist.requires_grad_(True)
94 | 
95 |     model = CW_SSIM(imgSize=[256,256], channels=1, level=4, ori=8, device=device)
96 | 
97 |     score = model(dist, ref, as_loss=False)
98 |     print('score: %.4f' % score.item())
99 |     # score: 0.9561
100 | 
--------------------------------------------------------------------------------
/IQA_pytorch/DISTS.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import torch
5 | from torchvision import models,transforms
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import inspect
9 | from .utils import downsample
10 | 
11 | class L2pooling(nn.Module):
12 |     def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
13 |         super(L2pooling, self).__init__()
14 |         self.padding = (filter_size - 2 )//2
15 |         self.stride = stride
16 |         self.channels = channels
17 |         a = np.hanning(filter_size)[1:-1]
18 |         # a = torch.hann_window(5,periodic=False)
19 |         g = torch.Tensor(a[:,None]*a[None,:])
20 |         g = g/torch.sum(g)
21 |         self.register_buffer('filter', g[None,None,:,:].repeat((self.channels,1,1,1)))
22 | 
23 |     def forward(self, input):
24 |         input = input**2
25 |         out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
26 |         return (out+1e-12).sqrt()
27 | 
28 | class DISTS(torch.nn.Module):
29 |     '''
30 |     Refer to https://github.com/dingkeyan93/DISTS
31 |     '''
32 |     def __init__(self, channels=3, load_weights=True):
33 |         assert channels == 3
34 |         super(DISTS, self).__init__()
35 |         vgg_pretrained_features = 
models.vgg16(pretrained=True).features 36 | self.stage1 = torch.nn.Sequential() 37 | self.stage2 = torch.nn.Sequential() 38 | self.stage3 = torch.nn.Sequential() 39 | self.stage4 = torch.nn.Sequential() 40 | self.stage5 = torch.nn.Sequential() 41 | for x in range(0,4): 42 | self.stage1.add_module(str(x), vgg_pretrained_features[x]) 43 | self.stage2.add_module(str(4), L2pooling(channels=64)) 44 | for x in range(5, 9): 45 | self.stage2.add_module(str(x), vgg_pretrained_features[x]) 46 | self.stage3.add_module(str(9), L2pooling(channels=128)) 47 | for x in range(10, 16): 48 | self.stage3.add_module(str(x), vgg_pretrained_features[x]) 49 | self.stage4.add_module(str(16), L2pooling(channels=256)) 50 | for x in range(17, 23): 51 | self.stage4.add_module(str(x), vgg_pretrained_features[x]) 52 | self.stage5.add_module(str(23), L2pooling(channels=512)) 53 | for x in range(24, 30): 54 | self.stage5.add_module(str(x), vgg_pretrained_features[x]) 55 | 56 | for param in self.parameters(): 57 | param.requires_grad = False 58 | 59 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1,-1,1,1)) 60 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1,-1,1,1)) 61 | 62 | self.chns = [3,64,128,256,512,512] 63 | self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns),1,1))) 64 | self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns),1,1))) 65 | self.alpha.data.normal_(0.1,0.01) 66 | self.beta.data.normal_(0.1,0.01) 67 | if load_weights: 68 | weights = torch.load(os.path.abspath(os.path.join(inspect.getfile(DISTS),'..','weights/DISTS.pt'))) 69 | self.alpha.data = weights['alpha'] 70 | self.beta.data = weights['beta'] 71 | 72 | def forward_once(self, x): 73 | h = (x-self.mean)/self.std 74 | h = self.stage1(h) 75 | h_relu1_2 = h 76 | h = self.stage2(h) 77 | h_relu2_2 = h 78 | h = self.stage3(h) 79 | h_relu3_3 = h 80 | h = self.stage4(h) 81 | h_relu4_3 = h 82 | h = self.stage5(h) 83 | h_relu5_3 = h 84 | return [x,h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] 85 | 86 | def forward(self, x, y, as_loss=True, resize = True): 87 | assert x.shape == y.shape 88 | if resize: 89 | x, y = downsample(x, y) 90 | if as_loss: 91 | feats0 = self.forward_once(x) 92 | feats1 = self.forward_once(y) 93 | else: 94 | with torch.no_grad(): 95 | feats0 = self.forward_once(x) 96 | feats1 = self.forward_once(y) 97 | dist1 = 0 98 | dist2 = 0 99 | c1 = 1e-6 100 | c2 = 1e-6 101 | w_sum = self.alpha.sum() + self.beta.sum() 102 | alpha = torch.split(self.alpha/w_sum, self.chns, dim=1) 103 | beta = torch.split(self.beta/w_sum, self.chns, dim=1) 104 | for k in range(len(self.chns)): 105 | x_mean = feats0[k].mean([2,3], keepdim=True) 106 | y_mean = feats1[k].mean([2,3], keepdim=True) 107 | S1 = (2*x_mean*y_mean+c1)/(x_mean**2+y_mean**2+c1) 108 | dist1 = dist1+(alpha[k]*S1).sum(1,keepdim=True) 109 | 110 | x_var = ((feats0[k]-x_mean)**2).mean([2,3], keepdim=True) 111 | y_var = ((feats1[k]-y_mean)**2).mean([2,3], keepdim=True) 112 | xy_cov = (feats0[k]*feats1[k]).mean([2,3],keepdim=True) - x_mean*y_mean 113 | S2 = (2*xy_cov+c2)/(x_var+y_var+c2) 114 | dist2 = dist2+(beta[k]*S2).sum(1,keepdim=True) 115 | 116 | score = 1 - (dist1+dist2).squeeze() 117 | if as_loss: 118 | return score.mean() 119 | else: 120 | return score 121 | 122 | 123 | if __name__ == '__main__': 124 | from PIL import Image 125 | import argparse 126 | from utils import prepare_image 127 | 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--ref', type=str, default='images/r0.png') 
130 | parser.add_argument('--dist', type=str, default='images/r1.png') 131 | args = parser.parse_args() 132 | 133 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 134 | 135 | ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device) 136 | dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device) 137 | 138 | model = DISTS().to(device) 139 | # print_network(model) 140 | 141 | score = model(ref, dist, as_loss=False) 142 | print('score: %.4f' % score.item()) 143 | # score: 0.3347 144 | 145 | -------------------------------------------------------------------------------- /IQA_pytorch/FSIM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | from torchvision import models,transforms 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import inspect 9 | from numpy.fft import fft2, ifft2, fftshift, ifftshift 10 | import math 11 | from .utils import abs, real, imag, downsample 12 | 13 | def lowpassfilter(size, cutoff, n): 14 | """ 15 | Constructs a low-pass Butterworth filter: 16 | f = 1 / (1 + (w/cutoff)^2n) 17 | usage: f = lowpassfilter(sze, cutoff, n) 18 | where: size is a tuple specifying the size of filter to construct 19 | [rows cols]. 20 | cutoff is the cutoff frequency of the filter 0 - 0.5 21 | n is the order of the filter, the higher n is the sharper 22 | the transition is. (n must be an integer >= 1). Note 23 | that n is doubled so that it is always an even integer. 24 | The frequency origin of the returned filter is at the corners. 25 | """ 26 | 27 | if cutoff < 0. or cutoff > 0.5: 28 | raise Exception('cutoff must be between 0 and 0.5') 29 | elif n % 1: 30 | raise Exception('n must be an integer >= 1') 31 | if len(size) == 1: 32 | rows = cols = size 33 | else: 34 | rows, cols = size 35 | 36 | if (cols % 2): 37 | xvals = np.arange(-(cols - 1) / 2., 38 | ((cols - 1) / 2.) + 1) / float(cols - 1) 39 | else: 40 | xvals = np.arange(-cols / 2., cols / 2.) / float(cols) 41 | 42 | if (rows % 2): 43 | yvals = np.arange(-(rows - 1) / 2., 44 | ((rows - 1) / 2.) + 1) / float(rows - 1) 45 | else: 46 | yvals = np.arange(-rows / 2., rows / 2.) / float(rows) 47 | 48 | x, y = np.meshgrid(xvals, yvals, sparse=True) 49 | radius = np.sqrt(x * x + y * y) 50 | 51 | return ifftshift(1. / (1. + (radius / cutoff) ** (2. * n))) 52 | 53 | def filtergrid(rows, cols): 54 | 55 | # Set up u1 and u2 matrices with ranges normalised to +/- 0.5 56 | u1, u2 = np.meshgrid(np.linspace(-0.5, 0.5, cols, endpoint=(cols % 2)), 57 | np.linspace(-0.5, 0.5, rows, endpoint=(rows % 2)), 58 | sparse=True) 59 | 60 | # Quadrant shift to put 0 frequency at the top left corner 61 | u1 = ifftshift(u1) 62 | u2 = ifftshift(u2) 63 | 64 | # Compute frequency values as a radius from centre (but quadrant shifted) 65 | radius = np.sqrt(u1 * u1 + u2 * u2) 66 | 67 | return radius, u1, u2 68 | 69 | def phasecong2(im): 70 | nscale = 4 71 | norient = 4 72 | minWaveLength = 6 73 | mult = 2 74 | sigmaOnf = 0.55 75 | dThetaOnSigma = 1.2 76 | k = 2.0 77 | epsilon = .0001 78 | thetaSigma = np.pi/norient/dThetaOnSigma 79 | 80 | _, _, rows,cols = im.shape 81 | imagefft = torch.rfft(im,2,onesided=False) 82 | 83 | lp = lowpassfilter((rows,cols),.45,15) 84 | 85 | radius, _, _ = filtergrid(rows, cols) 86 | radius[0, 0] = 1. 87 | logGaborList = [] 88 | logGaborDenom = 2. * np.log(sigmaOnf) ** 2. 89 | for s in range(nscale): 90 | wavelength = minWaveLength * mult ** s 91 | fo = 1. 
/ wavelength # Centre frequency of filter 92 | logRadOverFo = (np.log(radius / fo)) 93 | logGabor = np.exp(-(logRadOverFo * logRadOverFo) / logGaborDenom) 94 | logGabor *= lp # Apply the low-pass filter 95 | logGabor[0, 0] = 0. # Undo the radius fudge 96 | logGaborList.append(logGabor) 97 | 98 | # Matrix of radii 99 | cy = np.floor(rows/2) 100 | cx = np.floor(cols/2) 101 | y, x = np.mgrid[0:rows, 0:cols] 102 | y = (y-cy)/rows 103 | x = (x-cx)/cols 104 | radius = np.sqrt(x**2 + y**2) 105 | theta = np.arctan2(-y, x) 106 | radius = ifftshift(radius) # Quadrant shift radius and theta so that filters 107 | theta = ifftshift(theta) # are constructed with 0 frequency at the corners. 108 | radius[0,0] = 1 109 | sintheta = np.sin(theta) 110 | costheta = np.cos(theta) 111 | 112 | spreadList = [] 113 | for o in np.arange(norient): 114 | angl = o*np.pi/norient # Filter angle. 115 | ds = sintheta * math.cos(angl) - costheta * math.sin(angl) # Difference in sine. 116 | dc = costheta * math.cos(angl) + sintheta * math.sin(angl) # Difference in cosine. 117 | dtheta = np.abs(np.arctan2(ds,dc)) # Absolute angular distance. 118 | # dtheta = np.minimum(dtheta*NumberAngles/2, math.pi) 119 | spread = np.exp((-dtheta**2) / (2 * thetaSigma**2)); # Calculate the angular 120 | spreadList.append(spread) 121 | 122 | ifftFilterArray = [[],[],[],[]] 123 | filterArray = [[],[],[],[]] 124 | for o in np.arange(norient): 125 | for s in np.arange(nscale): 126 | filter = logGaborList[s] * spreadList[o] 127 | filterArray[o].append(torch.from_numpy(filter).reshape(1,1,rows,cols).float().to(im.device)) 128 | ifftFilt = np.real(ifft2(filter))*math.sqrt(rows*cols) 129 | ifftFilterArray[o].append(torch.from_numpy(ifftFilt).reshape(1,1,rows,cols).float().to(im.device)) 130 | 131 | EnergyAll = 0 132 | AnAll = 0 133 | for o in np.arange(norient): 134 | sumE_ThisOrient = 0 135 | sumO_ThisOrient = 0 136 | sumAn_ThisOrient = 0 137 | Energy = 0 138 | MatrixEOList = [] 139 | for s in np.arange(nscale): 140 | filter = filterArray[o][s] 141 | c = imagefft * filter.unsqueeze(-1).repeat(1,1,1,1,2) 142 | MatrixEO = torch.ifft(imagefft * filter.unsqueeze(-1).repeat(1,1,1,1,2), 2) 143 | MatrixEOList.append(MatrixEO) 144 | 145 | An = abs(MatrixEO) # Amplitude of even & odd filter response. 146 | sumAn_ThisOrient = sumAn_ThisOrient + An # Sum of amplitude responses. 147 | sumE_ThisOrient = sumE_ThisOrient + real(MatrixEO) # Sum of even filter convolution results. 148 | sumO_ThisOrient = sumO_ThisOrient + imag(MatrixEO) # Sum of odd filter convolution results. 
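  | # The even/odd sums accumulated above feed the phase-congruency energy
  | # below: Energy = sum_s (E*MeanE + O*MeanO - |E*MeanO - O*MeanE|), i.e.
  | # each scale's response projected onto the mean phase direction; after the
  | # noise threshold T is subtracted, ResultPC = EnergyAll / AnAll.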
149 | 150 | if s == 0: 151 | EM_n = torch.sum(filter**2,dim=[1,2,3]) 152 | maxAn = An 153 | else: 154 | maxAn = torch.max(maxAn,An) 155 | 156 | XEnergy = torch.sqrt(sumE_ThisOrient**2 + sumO_ThisOrient**2+1e-12) + epsilon 157 | MeanE = sumE_ThisOrient / XEnergy 158 | MeanO = sumO_ThisOrient / XEnergy 159 | for s in np.arange(nscale): 160 | EO = MatrixEOList[s] 161 | E = real(EO) 162 | O = imag(EO) 163 | Energy = Energy + E*MeanE + O*MeanO - torch.abs(E*MeanO - O*MeanE) 164 | 165 | meanE2n = torch.median((abs(MatrixEOList[0])**2).view(im.shape[0],-1),dim=1)[0] / -math.log(0.5) 166 | 167 | noisePower = meanE2n/EM_n 168 | EstSumAn2 = 0 169 | for s in np.arange(nscale): 170 | EstSumAn2 = EstSumAn2 + ifftFilterArray[o][s]**2 171 | EstSumAiAj = 0 172 | for si in np.arange(nscale-1): 173 | for sj in np.arange(si+1,nscale): 174 | EstSumAiAj = EstSumAiAj + ifftFilterArray[o][si]*ifftFilterArray[o][sj] 175 | sumEstSumAn2 = torch.sum(EstSumAn2,dim=[1,2,3]) 176 | sumEstSumAiAj = torch.sum(EstSumAiAj,dim=[1,2,3]) 177 | 178 | EstNoiseEnergy2 = 2*noisePower*sumEstSumAn2 + 4*noisePower*sumEstSumAiAj 179 | 180 | tau = torch.sqrt(EstNoiseEnergy2/2+1e-12) 181 | EstNoiseEnergySigma = torch.sqrt( (2-math.pi/2)*tau**2 +1e-12) 182 | T = tau*math.sqrt(math.pi/2) + k*EstNoiseEnergySigma 183 | T = T/1.7 184 | Energy = F.relu(Energy - T.view(-1,1,1,1)) 185 | 186 | EnergyAll = EnergyAll + Energy 187 | AnAll = AnAll + sumAn_ThisOrient 188 | 189 | ResultPC = EnergyAll / AnAll 190 | return ResultPC 191 | 192 | def fsim(imageRef, imageDis): 193 | 194 | channels = imageRef.shape[1] 195 | if channels == 3: 196 | Y1 = (0.299 * imageRef[:,0,:,:] + 0.587 * imageRef[:,1,:,:] + 0.114 * imageRef[:,2,:,:]).unsqueeze(1) 197 | Y2 = (0.299 * imageDis[:,0,:,:] + 0.587 * imageDis[:,1,:,:] + 0.114 * imageDis[:,2,:,:]).unsqueeze(1) 198 | I1 = (0.596 * imageRef[:,0,:,:] - 0.274 * imageRef[:,1,:,:] - 0.322 * imageRef[:,2,:,:]).unsqueeze(1) 199 | I2 = (0.596 * imageDis[:,0,:,:] - 0.274 * imageDis[:,1,:,:] - 0.322 * imageDis[:,2,:,:]).unsqueeze(1) 200 | Q1 = (0.211 * imageRef[:,0,:,:] - 0.523 * imageRef[:,1,:,:] + 0.312 * imageRef[:,2,:,:]).unsqueeze(1) 201 | Q2 = (0.211 * imageDis[:,0,:,:] - 0.523 * imageDis[:,1,:,:] + 0.312 * imageDis[:,2,:,:]).unsqueeze(1) 202 | Y1, Y2 = downsample(Y1, Y2) 203 | I1, I2 = downsample(I1, I2) 204 | Q1, Q2 = downsample(Q1, Q2) 205 | elif channels == 1: 206 | Y1, Y2 = downsample(imageRef, imageDis) 207 | else: 208 | raise ValueError('channels error') 209 | 210 | PC1 = phasecong2(Y1) 211 | PC2 = phasecong2(Y2) 212 | 213 | dx = torch.Tensor([[3, 0, -3], [10, 0, -10], [3, 0, -3]]).float()/16 214 | dy = torch.Tensor([[3, 10, 3], [0, 0, 0], [-3, -10, -3]]).float()/16 215 | dx = dx.reshape(1,1,3,3).to(imageRef.device) 216 | dy = dy.reshape(1,1,3,3).to(imageRef.device) 217 | IxY1 = F.conv2d(Y1, dx, stride=1, padding =1) 218 | IyY1 = F.conv2d(Y1, dy, stride=1, padding =1) 219 | gradientMap1 = torch.sqrt(IxY1**2 + IyY1**2+1e-12) 220 | IxY2 = F.conv2d(Y2, dx, stride=1, padding =1) 221 | IyY2 = F.conv2d(Y2, dy, stride=1, padding =1) 222 | gradientMap2 = torch.sqrt(IxY2**2 + IyY2**2+1e-12) 223 | 224 | T1 = 0.85 225 | T2 = 160 226 | PCSimMatrix = (2 * PC1 * PC2 + T1) / (PC1**2 + PC2**2 + T1) 227 | gradientSimMatrix = (2*gradientMap1*gradientMap2 + T2)/(gradientMap1**2 + gradientMap2**2 + T2) 228 | PCm = torch.max(PC1, PC2) 229 | SimMatrix = gradientSimMatrix * PCSimMatrix * PCm 230 | FSIM_val = torch.sum(SimMatrix,dim=[1,2,3]) / torch.sum(PCm,dim=[1,2,3]) 231 | if channels==1: 232 | return FSIM_val 233 | 234 | T3 = 200 
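  | # T3 above and T4 below stabilise the chromatic similarity terms for the
  | # I and Q channels; SimMatrixC weights the grayscale FSIM map by
  | # (|ISimMatrix * QSimMatrix|)^0.03 to give the colour variant FSIMc.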
235 | T4 = 200 236 | ISimMatrix = (2 * I1 * I2 + T3) / (I1**2 + I2**2 + T3) 237 | QSimMatrix = (2 * Q1 * Q2 + T4) / (Q1**2 + Q2**2 + T4) 238 | 239 | SimMatrixC = gradientSimMatrix * PCSimMatrix * PCm * \ 240 | torch.sign(gradientSimMatrix) * ((torch.abs(ISimMatrix * QSimMatrix)+1e-12) ** 0.03) 241 | 242 | return torch.sum(SimMatrixC,dim=[1,2,3]) / torch.sum(PCm,dim=[1,2,3]) 243 | 244 | class FSIM(torch.nn.Module): 245 | # Refer to https://sse.tongji.edu.cn/linzhang/IQA/FSIM/FSIM.htm 246 | 247 | def __init__(self, channels=3): 248 | super(FSIM, self).__init__() 249 | 250 | def forward(self, y, x, as_loss=True): 251 | assert x.shape == y.shape 252 | x = x * 255 253 | y = y * 255 254 | if as_loss: 255 | score = fsim(x, y) 256 | return 1 - score.mean() 257 | else: 258 | with torch.no_grad(): 259 | score = fsim(x, y) 260 | return score 261 | 262 | 263 | if __name__ == '__main__': 264 | from PIL import Image 265 | import argparse 266 | from utils import prepare_image 267 | 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument('--ref', type=str, default='images/r0.png') 270 | parser.add_argument('--dist', type=str, default='images/r1.png') 271 | args = parser.parse_args() 272 | 273 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 274 | 275 | ref = prepare_image(Image.open(args.ref).convert("RGB"), repeatNum = 1).to(device) 276 | dist = prepare_image(Image.open(args.dist).convert("RGB"), repeatNum = 1).to(device) 277 | 278 | model = FSIM(channels=3).to(device) 279 | 280 | score = model(dist, ref, as_loss=False) 281 | print('score: %.4f' % score.item()) 282 | # score: 0.7843 283 | 284 | -------------------------------------------------------------------------------- /IQA_pytorch/GMSD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from torchvision import transforms 6 | 7 | class GMSD(nn.Module): 8 | # Refer to http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm 9 | 10 | def __init__(self, channels=3): 11 | super(GMSD, self).__init__() 12 | self.channels = channels 13 | dx = (torch.Tensor([[1,0,-1],[1,0,-1],[1,0,-1]])/3.).unsqueeze(0).unsqueeze(0).repeat(channels,1,1,1) 14 | dy = (torch.Tensor([[1,1,1],[0,0,0],[-1,-1,-1]])/3.).unsqueeze(0).unsqueeze(0).repeat(channels,1,1,1) 15 | self.dx = nn.Parameter(dx, requires_grad=False) 16 | self.dy = nn.Parameter(dy, requires_grad=False) 17 | self.aveKernel = nn.Parameter(torch.ones(channels,1,2,2)/4., requires_grad=False) 18 | 19 | def gmsd(self, img1, img2, T=170): 20 | Y1 = F.conv2d(img1, self.aveKernel, stride=2, padding =0, groups = self.channels) 21 | Y2 = F.conv2d(img2, self.aveKernel, stride=2, padding =0, groups = self.channels) 22 | 23 | IxY1 = F.conv2d(Y1, self.dx, stride=1, padding =1, groups = self.channels) 24 | IyY1 = F.conv2d(Y1, self.dy, stride=1, padding =1, groups = self.channels) 25 | gradientMap1 = torch.sqrt(IxY1**2 + IyY1**2+1e-12) 26 | 27 | IxY2 = F.conv2d(Y2, self.dx, stride=1, padding =1, groups = self.channels) 28 | IyY2 = F.conv2d(Y2, self.dy, stride=1, padding =1, groups = self.channels) 29 | gradientMap2 = torch.sqrt(IxY2**2 + IyY2**2+1e-12) 30 | 31 | quality_map = (2*gradientMap1*gradientMap2 + T)/(gradientMap1**2+gradientMap2**2 + T) 32 | score = torch.std(quality_map.view(quality_map.shape[0],-1),dim=1) 33 | return score 34 | 35 | def forward(self, y, x, as_loss=True): 36 | assert x.shape == y.shape 37 | x = x * 255 38 | y = y * 255 39 | if 
as_loss: 40 | score = self.gmsd(x, y) 41 | return score.mean() 42 | else: 43 | with torch.no_grad(): 44 | score = self.gmsd(x, y) 45 | return score 46 | 47 | if __name__ == '__main__': 48 | from PIL import Image 49 | import argparse 50 | from utils import prepare_image 51 | 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--ref', type=str, default='images/r0.png') 54 | parser.add_argument('--dist', type=str, default='images/r1.png') 55 | args = parser.parse_args() 56 | 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | 59 | ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device) 60 | dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device) 61 | 62 | model = GMSD().to(device) 63 | 64 | score = model(ref, dist, as_loss=False) 65 | print('score: %.4f' % score.item()) 66 | # score: 0.1907 67 | 68 | -------------------------------------------------------------------------------- /IQA_pytorch/LPIPSvgg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | from torchvision import models,transforms 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import inspect 9 | 10 | 11 | class LPIPSvgg(torch.nn.Module): 12 | def __init__(self, channels=3): 13 | # Refer to https://github.com/richzhang/PerceptualSimilarity 14 | 15 | assert channels == 3 16 | super(LPIPSvgg, self).__init__() 17 | vgg_pretrained_features = models.vgg16(pretrained=True).features 18 | self.stage1 = torch.nn.Sequential() 19 | self.stage2 = torch.nn.Sequential() 20 | self.stage3 = torch.nn.Sequential() 21 | self.stage4 = torch.nn.Sequential() 22 | self.stage5 = torch.nn.Sequential() 23 | for x in range(0,4): 24 | self.stage1.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(4, 9): 26 | self.stage2.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(9, 16): 28 | self.stage3.add_module(str(x), vgg_pretrained_features[x]) 29 | for x in range(16, 23): 30 | self.stage4.add_module(str(x), vgg_pretrained_features[x]) 31 | for x in range(23, 30): 32 | self.stage5.add_module(str(x), vgg_pretrained_features[x]) 33 | 34 | for param in self.parameters(): 35 | param.requires_grad = False 36 | 37 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1,-1,1,1)) 38 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1,-1,1,1)) 39 | 40 | self.chns = [64,128,256,512,512] 41 | self.weights = torch.load(os.path.abspath(os.path.join(inspect.getfile(LPIPSvgg),'..','weights/LPIPSvgg.pt'))) 42 | self.weights = list(self.weights.items()) 43 | 44 | def forward_once(self, x): 45 | h = (x-self.mean)/self.std 46 | h = self.stage1(h) 47 | h_relu1_2 = h 48 | h = self.stage2(h) 49 | h_relu2_2 = h 50 | h = self.stage3(h) 51 | h_relu3_3 = h 52 | h = self.stage4(h) 53 | h_relu4_3 = h 54 | h = self.stage5(h) 55 | h_relu5_3 = h 56 | outs = [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] 57 | for k in range(len(outs)): 58 | outs[k] = F.normalize(outs[k]) 59 | return outs 60 | 61 | def forward(self, x, y, as_loss=True): 62 | assert x.shape == y.shape 63 | if as_loss: 64 | feats0 = self.forward_once(x) 65 | feats1 = self.forward_once(y) 66 | else: 67 | with torch.no_grad(): 68 | feats0 = self.forward_once(x) 69 | feats1 = self.forward_once(y) 70 | score = 0 71 | for k in range(len(self.chns)): 72 | score = score + (self.weights[k][1]*(feats0[k]-feats1[k])**2).mean([2,3]).sum(1) 73 | if as_loss: 74 | return 
score.mean() 75 | else: 76 | return score 77 | 78 | if __name__ == '__main__': 79 | from PIL import Image 80 | import argparse 81 | from utils import prepare_image 82 | 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--ref', type=str, default='images/r0.png') 85 | parser.add_argument('--dist', type=str, default='images/r1.png') 86 | args = parser.parse_args() 87 | 88 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 89 | 90 | ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device) 91 | dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device) 92 | 93 | model = LPIPSvgg().to(device) 94 | 95 | score = model(ref, dist, as_loss=False) 96 | print('score: %.4f' % score.item()) 97 | # score: 0.5435 98 | 99 | -------------------------------------------------------------------------------- /IQA_pytorch/MAD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | from torchvision import models,transforms 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import inspect 9 | from numpy.fft import fft2, ifft2, fftshift, ifftshift 10 | import math 11 | from .utils import abs, real, imag, downsample, batch_fftshift2d, batch_ifftshift2d 12 | 13 | MAX = nn.MaxPool2d((2,2), stride=1, padding=1) 14 | 15 | def extract_patches_2d(img, patch_shape=[64, 64], step=[27,27], batch_first=False, keep_last_patch=False): 16 | patch_H, patch_W = patch_shape[0], patch_shape[1] 17 | if(img.size(2) Ci_thrsh) & (Ci_dst > (C_slope * (Ci_ref - Ci_thrsh) + Cd_thrsh) ) 138 | idx2 = (Ci_ref <= Ci_thrsh) & (Ci_dst > Cd_thrsh) 139 | 140 | msk = Ci_ref.clone() 141 | msk = msk.masked_fill(~idx1,0) 142 | msk = msk.masked_fill(~idx2,0) 143 | msk[idx1] = Ci_dst[idx1] - (C_slope * (Ci_ref[idx1]-Ci_thrsh) + Cd_thrsh) 144 | msk[idx2] = Ci_dst[idx2] - Cd_thrsh 145 | 146 | win = torch.ones( (1,1,BSIZE, BSIZE) ).repeat(C,1,1,1).to(ref.device) / BSIZE**2 147 | xx = (ref_img-dst_img)**2 148 | # p = (BSIZE-1)//2 149 | # xx = F.pad(xx,(p,p,p,p),'reflect') 150 | lmse = F.conv2d(xx, win, stride=4, padding =0, groups = C) 151 | 152 | mp = msk * lmse 153 | # mp2 = mp[:,:, BSIZE+1:-BSIZE-1, BSIZE+1:-BSIZE-1] 154 | B, C, H, W = mp.shape 155 | return torch.norm( mp.reshape(B,C,-1) , dim=2 ) / math.sqrt( H*W ) * 200 156 | 157 | def gaborconvolve(im): 158 | 159 | nscale = 5 #Number of wavelet scales. 160 | norient = 4 #Number of filter orientations. 161 | minWaveLength = 3 #Wavelength of smallest scale filter. 162 | mult = 3 #Scaling factor between successive filters. 163 | sigmaOnf = 0.55 #Ratio of the standard deviation of the 164 | wavelength = [minWaveLength,minWaveLength*mult,minWaveLength*mult**2, minWaveLength*mult**3, minWaveLength*mult**4] 165 | dThetaOnSigma = 1.5 #Ratio of angular interval between filter orientations 166 | 167 | B, C, rows, cols = im.shape 168 | imagefft = torch.rfft(im,2, onesided=False) # Fourier transform of image 169 | 170 | # Pre-compute to speed up filter construction 171 | x = np.ones((rows,1)) * np.arange(-cols/2.,(cols/2.))/(cols/2.) 172 | y = np.dot(np.expand_dims(np.arange(-rows/2.,(rows/2.)),1) , np.ones((1,cols))/(rows/2.)) 173 | radius = np.sqrt(x**2 + y**2) # Matrix values contain *normalised* radius from centre. 174 | radius[int(np.round(rows/2+1)),int(np.round(cols/2+1))] = 1 # Get rid of the 0 radius value in the middle 175 | radius = np.log(radius+1e-12) 176 | 177 | theta = np.arctan2(-y,x) # Matrix values contain polar angle. 
178 | # (note -ve y is used to give +ve 179 | # anti-clockwise angles) 180 | sintheta = np.sin(theta) 181 | costheta = np.cos(theta) 182 | 183 | thetaSigma = math.pi/norient/dThetaOnSigma # Calculate the standard deviation of the 184 | 185 | logGabors = [] 186 | for s in range(nscale): # For each scale. 187 | # Construct the filter - first calculate the radial filter component. 188 | fo = 1.0/wavelength[s] # Centre frequency of filter. 189 | rfo = fo/0.5 # Normalised radius from centre of frequency plane 190 | # corresponding to fo. 191 | tmp = -(2 * np.log(sigmaOnf)**2) 192 | tmp2= np.log(rfo) 193 | logGabors.append(np.exp( (radius-tmp2)**2 /tmp)) 194 | logGabors[s][int(np.round(rows/2)), int(np.round(cols/2))]=0 195 | 196 | 197 | E0 = [[],[],[],[]] 198 | for o in range(norient): # For each orientation. 199 | angl = o*math.pi/norient # Calculate filter angle. 200 | 201 | ds = sintheta * np.cos(angl) - costheta * np.sin(angl) # Difference in sine. 202 | dc = costheta * np.cos(angl) + sintheta * np.sin(angl) # Difference in cosine. 203 | dtheta = np.abs(np.arctan2(ds,dc)) # Absolute angular distance. 204 | spread = np.exp((-dtheta**2) / (2 * thetaSigma**2)) # Calculate the angular filter component. 205 | 206 | for s in range(nscale): # For each scale. 207 | 208 | filter = fftshift(logGabors[s] * spread) 209 | filter = torch.from_numpy(filter).reshape(1,1,rows,cols,1).repeat(1,C,1,1,2).to(im.device) 210 | # c = imagefft * filter 211 | e0 = torch.ifft( imagefft * filter, 2 ) 212 | E0[o].append(e0) 213 | 214 | return E0 215 | 216 | def lo_index(ref, dst): 217 | gabRef = gaborconvolve( ref ) 218 | gabDst = gaborconvolve( dst ) 219 | s = [0.5/13.25, 0.75/13.25, 1/13.25, 5/13.25, 6/13.25] 220 | 221 | BSIZE = 16 222 | mp = 0 223 | for gb_i in range(4): 224 | for gb_j in range(5): 225 | stdref, skwref, krtref = ical_stat( abs( gabRef[gb_i][gb_j] ) ) 226 | stddst, skwdst, krtdst = ical_stat( abs( gabDst[gb_i][gb_j] ) ) 227 | mp = mp + s[gb_i] * ( torch.abs( stdref - stddst ) + 2*torch.abs( skwref - skwdst ) + torch.abs( krtref - krtdst ) ) 228 | 229 | # mp2 = mp[:,:, BSIZE+1:-BSIZE-1, BSIZE+1:-BSIZE-1] 230 | B, C, rows, cols = mp.shape 231 | return torch.norm( mp.reshape(B,C,-1) , dim=2 ) / np.sqrt(rows * cols) 232 | 233 | 234 | def mad(ref, dst): 235 | HI = hi_index(ref, dst) 236 | LO = lo_index(ref, dst) 237 | thresh1 = 2.55 238 | thresh2 = 3.35 239 | b1 = math.exp(-thresh1/thresh2) 240 | b2 = 1 / (math.log(10)*thresh2) 241 | sig = 1 / ( 1 + b1*HI**b2 ) 242 | MAD = LO**(1-sig) * HI**(sig) 243 | return MAD.mean(1) 244 | 245 | class MAD(torch.nn.Module): 246 | # Refer to http://vision.eng.shizuoka.ac.jp/mod/page/view.php?id=23 247 | 248 | def __init__(self, channels=3): 249 | super(MAD, self).__init__() 250 | 251 | def forward(self, y, x, as_loss=True): 252 | assert x.shape == y.shape 253 | x = x * 255 254 | y = y * 255 255 | if as_loss: 256 | score = mad(x, y) 257 | return score.mean() 258 | else: 259 | with torch.no_grad(): 260 | score = mad(x, y) 261 | return score 262 | 263 | 264 | if __name__ == '__main__': 265 | from PIL import Image 266 | import argparse 267 | from utils import prepare_image 268 | 269 | parser = argparse.ArgumentParser() 270 | parser.add_argument('--ref', type=str, default='images/r0.png') 271 | parser.add_argument('--dist', type=str, default='images/r1.png') 272 | args = parser.parse_args() 273 | 274 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 275 | 276 | ref = prepare_image(Image.open(args.ref).convert("L"), repeatNum = 1).to(device) 277 | dist = 
prepare_image(Image.open(args.dist).convert("L"), repeatNum = 1).to(device)
278 | 
279 |     model = MAD(channels=1).to(device)
280 | 
281 |     score = model(dist, ref, as_loss=False)
282 |     print('score: %.4f' % score.item())
283 |     # score: 168
--------------------------------------------------------------------------------
/IQA_pytorch/MS_SSIM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from torchvision import transforms
6 | from .utils import fspecial_gauss
7 | from .SSIM import ssim
8 | 
9 | def ms_ssim(X, Y, win):
10 |     if not X.shape == Y.shape:
11 |         raise ValueError('Input images must have the same dimensions.')
12 | 
13 |     weights = torch.FloatTensor(
14 |         [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)
15 | 
16 |     levels = weights.shape[0]
17 |     mcs = []
18 |     for _ in range(levels):
19 |         ssim_val, cs = ssim(X, Y, win=win, get_cs=True)
20 |         mcs.append(cs)
21 |         padding = (X.shape[2] % 2, X.shape[3] % 2)
22 |         X = F.avg_pool2d(X, kernel_size=2, padding=padding)
23 |         Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)
24 | 
25 |     mcs = torch.stack(mcs, dim=0)
26 |     msssim_val = torch.prod((mcs[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_val ** weights[-1]), dim=0)
27 |     return msssim_val
28 | 
29 | class MS_SSIM(torch.nn.Module):
30 |     def __init__(self, channels=3):
31 |         super(MS_SSIM, self).__init__()
32 |         self.win = fspecial_gauss(11, 1.5, channels)
33 | 
34 |     def forward(self, X, Y, as_loss=True):
35 |         assert X.shape == Y.shape
36 |         if as_loss:
37 |             score = ms_ssim(X, Y, win=self.win)
38 |             return 1 - score.mean()
39 |         else:
40 |             with torch.no_grad():
41 |                 score = ms_ssim(X, Y, win=self.win)
42 |             return score
43 | 
44 | if __name__ == '__main__':
45 |     from PIL import Image
46 |     import argparse
47 |     from utils import prepare_image
48 | 
49 |     parser = argparse.ArgumentParser()
50 |     parser.add_argument('--ref', type=str, default='images/r0.png')
51 |     parser.add_argument('--dist', type=str, default='images/r1.png')
52 |     args = parser.parse_args()
53 | 
54 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
55 | 
56 |     ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device)
57 |     dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device)
58 | 
59 |     model = MS_SSIM(channels=3)
60 | 
61 |     score = model(dist, ref, as_loss=False)
62 |     print('score: %.4f' % score.item())
63 |     # score: 0.8524
64 | 
--------------------------------------------------------------------------------
/IQA_pytorch/NLPD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | LAPLACIAN_FILTER = np.array([[0.0025, 0.0125, 0.0200, 0.0125, 0.0025],
7 |                              [0.0125, 0.0625, 0.1000, 0.0625, 0.0125],
8 |                              [0.0200, 0.1000, 0.1600, 0.1000, 0.0200],
9 |                              [0.0125, 0.0625, 0.1000, 0.0625, 0.0125],
10 |                              [0.0025, 0.0125, 0.0200, 0.0125, 0.0025]],
11 |                             dtype=np.float32)
12 | 
13 | class NLPD(nn.Module):
14 |     """
15 |     Normalised Laplacian pyramid distance.
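  |     Each pyramid level is divisively normalised by a local, filtered
  |     estimate of its own amplitude, and the distance is the per-level RMS
  |     difference between the two normalised pyramids, averaged over levels.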
16 | Refer to https://www.cns.nyu.edu/pub/eero/laparra16a-preprint.pdf 17 | https://github.com/alexhepburn/nlpd-tensorflow 18 | """ 19 | def __init__(self, channels=3, k=6, filt=None): 20 | super(NLPD, self).__init__() 21 | if filt is None: 22 | filt = np.reshape(np.tile(LAPLACIAN_FILTER, (channels, 1, 1)), 23 | (channels, 1, 5, 5)) 24 | self.k = k 25 | self.channels = channels 26 | self.filt = nn.Parameter(torch.Tensor(filt), requires_grad=False) 27 | self.dn_filts, self.sigmas = self.DN_filters() 28 | self.pad_one = nn.ReflectionPad2d(1) 29 | self.pad_two = nn.ReflectionPad2d(2) 30 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', 31 | align_corners=True) 32 | 33 | def DN_filters(self): 34 | sigmas = [0.0248, 0.0185, 0.0179, 0.0191, 0.0220, 0.2782] 35 | dn_filts = [] 36 | dn_filts.append(torch.Tensor(np.reshape([[0, 0.1011, 0], 37 | [0.1493, 0, 0.1460], 38 | [0, 0.1015, 0.]]*self.channels, 39 | (self.channels, 1, 3, 3)).astype(np.float32))) 40 | 41 | dn_filts.append(torch.Tensor(np.reshape([[0, 0.0757, 0], 42 | [0.1986, 0, 0.1846], 43 | [0, 0.0837, 0]]*self.channels, 44 | (self.channels, 1, 3, 3)).astype(np.float32))) 45 | 46 | dn_filts.append(torch.Tensor(np.reshape([[0, 0.0477, 0], 47 | [0.2138, 0, 0.2243], 48 | [0, 0.0467, 0]]*self.channels, 49 | (self.channels, 1, 3, 3)).astype(np.float32))) 50 | 51 | dn_filts.append(torch.Tensor(np.reshape([[0, 0, 0], 52 | [0.2503, 0, 0.2616], 53 | [0, 0, 0]]*self.channels, 54 | (self.channels, 1, 3, 3)).astype(np.float32))) 55 | 56 | dn_filts.append(torch.Tensor(np.reshape([[0, 0, 0], 57 | [0.2598, 0, 0.2552], 58 | [0, 0, 0]]*self.channels, 59 | (self.channels, 1, 3, 3)).astype(np.float32))) 60 | 61 | dn_filts.append(torch.Tensor(np.reshape([[0, 0, 0], 62 | [0.2215, 0, 0.0717], 63 | [0, 0, 0]]*self.channels, 64 | (self.channels, 1, 3, 3)).astype(np.float32))) 65 | dn_filts = nn.ParameterList([nn.Parameter(x, requires_grad=False) 66 | for x in dn_filts]) 67 | sigmas = nn.ParameterList([nn.Parameter(torch.Tensor(np.array(x)), 68 | requires_grad=False) for x in sigmas]) 69 | return dn_filts, sigmas 70 | 71 | def pyramid(self, im): 72 | out = [] 73 | J = im 74 | pyr = [] 75 | for i in range(0, self.k): 76 | I = F.conv2d(self.pad_two(J), self.filt, stride=2, padding=0, 77 | groups=self.channels) 78 | I_up = self.upsample(I) 79 | I_up_conv = F.conv2d(self.pad_two(I_up), self.filt, stride=1, 80 | padding=0, groups=self.channels) 81 | if J.size() != I_up_conv.size(): 82 | I_up_conv = F.interpolate(I_up_conv, [J.size(2), J.size(3)]) 83 | out = J - I_up_conv 84 | out_conv = F.conv2d(self.pad_one(torch.abs(out)), self.dn_filts[i], 85 | stride=1, groups=self.channels) 86 | out_norm = out / (self.sigmas[i]+out_conv) 87 | pyr.append(out_norm) 88 | J = I 89 | return pyr 90 | 91 | def nlpd(self, x1, x2): 92 | y1 = self.pyramid(x1) 93 | y2 = self.pyramid(x2) 94 | total = [] 95 | for z1, z2 in zip(y1, y2): 96 | diff = (z1 - z2) ** 2 97 | sqrt = torch.sqrt(torch.mean(diff, (1, 2, 3))) 98 | total.append(sqrt) 99 | score = torch.stack(total,dim=1).mean(1) 100 | return score 101 | 102 | def forward(self, y, x, as_loss=True): 103 | assert x.shape == y.shape 104 | if as_loss: 105 | score = self.nlpd(x, y) 106 | return score.mean() 107 | else: 108 | with torch.no_grad(): 109 | score = self.nlpd(x, y) 110 | return score 111 | 112 | if __name__ == '__main__': 113 | from PIL import Image 114 | import argparse 115 | from utils import prepare_image 116 | 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument('--ref', type=str, default='images/r0.png') 
119 | parser.add_argument('--dist', type=str, default='images/r1.png') 120 | args = parser.parse_args() 121 | 122 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 123 | 124 | ref = prepare_image(Image.open(args.ref).convert("L")).to(device) 125 | dist = prepare_image(Image.open(args.dist).convert("L")).to(device) 126 | 127 | model = NLPD(channels=1).to(device) 128 | 129 | score = model(dist, ref, as_loss=False) 130 | print('score: %.4f' % score.item()) 131 | # score: 0.4016 -------------------------------------------------------------------------------- /IQA_pytorch/SSIM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from torchvision import transforms 6 | from .utils import fspecial_gauss 7 | 8 | def gaussian_filter(input, win): 9 | out = F.conv2d(input, win, stride=1, padding=0, groups=input.shape[1]) 10 | return out 11 | 12 | def ssim(X, Y, win, get_ssim_map=False, get_cs=False, get_weight=False): 13 | C1 = 0.01**2 14 | C2 = 0.03**2 15 | 16 | win = win.to(X.device) 17 | 18 | mu1 = gaussian_filter(X, win) 19 | mu2 = gaussian_filter(Y, win) 20 | mu1_sq = mu1.pow(2) 21 | mu2_sq = mu2.pow(2) 22 | mu1_mu2 = mu1 * mu2 23 | sigma1_sq = gaussian_filter(X * X, win) - mu1_sq 24 | sigma2_sq = gaussian_filter(Y * Y, win) - mu2_sq 25 | sigma12 = gaussian_filter(X * Y, win) - mu1_mu2 26 | 27 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) 28 | cs_map = F.relu(cs_map) #force the ssim response to be nonnegative to avoid negative results. 29 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 30 | ssim_val = ssim_map.mean([1,2,3]) 31 | 32 | if get_weight: 33 | weights = torch.log((1+sigma1_sq/C2)*(1+sigma2_sq/C2)) 34 | return ssim_map, weights 35 | 36 | if get_ssim_map: 37 | return ssim_map 38 | 39 | if get_cs: 40 | return ssim_val, cs_map.mean([1,2,3]) 41 | 42 | return ssim_val 43 | 44 | class SSIM(torch.nn.Module): 45 | def __init__(self, channels=3): 46 | 47 | super(SSIM, self).__init__() 48 | self.win = fspecial_gauss(11, 1.5, channels) 49 | 50 | def forward(self, X, Y, as_loss=True): 51 | assert X.shape == Y.shape 52 | if as_loss: 53 | score = ssim(X, Y, win=self.win) 54 | return 1 - score.mean() 55 | else: 56 | with torch.no_grad(): 57 | score = ssim(X, Y, win=self.win) 58 | return score 59 | 60 | if __name__ == '__main__': 61 | from PIL import Image 62 | import argparse 63 | from utils import prepare_image 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--ref', type=str, default='images/r0.png') 67 | parser.add_argument('--dist', type=str, default='images/r1.png') 68 | args = parser.parse_args() 69 | 70 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 71 | 72 | ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device) 73 | dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device) 74 | 75 | model = SSIM(channels=3) 76 | 77 | score = model(dist, ref, as_loss=False) 78 | print('score: %.4f' % score.item()) 79 | # score: 0.6717 80 | 81 | # model = SSIM(channels=1) 82 | # score = 0 83 | # for i in range(3): 84 | # ref1 = ref[:,i,:,:].unsqueeze(1) 85 | # dist1= dist[:,i,:,:].unsqueeze(1) 86 | # score = score + model(ref1, dist1).item() 87 | # print('score: %.4f' % score) 88 | -------------------------------------------------------------------------------- /IQA_pytorch/SteerPyrComplex.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .SteerPyrUtils import * 4 | 5 | 6 | class SteerablePyramid(nn.Module): 7 | # refer to https://github.com/LabForComputationalVision/pyrtools 8 | # https://github.com/olivierhenaff/steerablePyramid 9 | def __init__(self, imgSize=[256,256], K=4, N=4, hilb=True, includeHF=True, device=torch.device("cuda")): 10 | super(SteerablePyramid, self).__init__() 11 | assert imgSize[0]==imgSize[1] 12 | size = [ imgSize[0], imgSize[1]//2 + 1 ] 13 | # self.imgSize = imgSize 14 | self.hl0 = HL0_matrix( size ).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device) 15 | 16 | self.l = [] 17 | self.b = [] 18 | self.s = [] 19 | 20 | self.K = K 21 | self.N = N 22 | self.hilb = hilb 23 | self.includeHF = includeHF 24 | 25 | self.indF = [ freq_shift( size[0], True, device ) ] 26 | self.indB = [ freq_shift( size[0], False, device ) ] 27 | 28 | 29 | for n in range( self.N ): 30 | 31 | l = L_matrix_cropped( size ).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device) 32 | b = B_matrix( K, size ).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device) 33 | s = S_matrix( K, size ).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device) 34 | 35 | self.l.append( l.div_(4) ) 36 | self.b.append( b ) 37 | self.s.append( s ) 38 | 39 | size = [ l.size(-2), l.size(-1) ] 40 | 41 | self.indF.append( freq_shift( size[0], True, device ) ) 42 | self.indB.append( freq_shift( size[0], False, device ) ) 43 | 44 | 45 | def forward(self, x): 46 | 47 | fftfull = torch.rfft(x,2) 48 | xreal = fftfull[... , 0] 49 | xim = fftfull[... ,1] 50 | x = torch.cat((xreal.unsqueeze(1), xim.unsqueeze(1)), 1 ).unsqueeze( -3 ) 51 | x = torch.index_select( x, -2, self.indF[0] ) 52 | 53 | x = self.hl0 * x 54 | h0f = x.select( -3, 0 ).unsqueeze( -3 ) 55 | l0f = x.select( -3, 1 ).unsqueeze( -3 ) 56 | lf = l0f 57 | 58 | output = [] 59 | 60 | for n in range( self.N ): 61 | 62 | bf = self.b[n] * lf 63 | lf = self.l[n] * central_crop( lf ) 64 | if self.hilb: 65 | hbf = self.s[n] * torch.cat( (bf.narrow(1,1,1), -bf.narrow(1,0,1)), 1 ) 66 | bf = torch.cat( ( bf , hbf ), -3 ) 67 | if self.includeHF and n == 0: 68 | bf = torch.cat( ( h0f, bf ), -3 ) 69 | 70 | output.append( bf ) 71 | 72 | output.append( lf ) 73 | 74 | for n in range( len( output ) ): 75 | output[n] = torch.index_select( output[n], -2, self.indB[n] ) 76 | sig_size = [output[n].shape[-2],(output[n].shape[-1]-1)*2] 77 | output[n] = torch.stack((output[n].select(1,0), output[n].select(1,1)),-1) 78 | output[n] = torch.irfft( output[n], 2, signal_sizes = sig_size) 79 | 80 | if self.includeHF: 81 | output.insert( 0, output[0].narrow( -3, 0, 1 ) ) 82 | output[1] = output[1].narrow( -3, 1, output[1].size(-3)-1 ) 83 | 84 | for n in range( len( output ) ): 85 | if self.hilb: 86 | if ((not self.includeHF) or 0 < n) and n < len(output)-1: 87 | nfeat = output[n].size(-3)//2 88 | o1 = output[n].narrow( -3, 0, nfeat ).unsqueeze(1) 89 | o2 = -output[n].narrow( -3, nfeat, nfeat ).unsqueeze(1) 90 | output[n] = torch.cat( (o2, o1), 1 ) 91 | else: 92 | output[n] = output[n].unsqueeze(1) 93 | 94 | for n in range( len( output ) ): 95 | if n>0: 96 | output[n] = output[n]*(2**(n-1)) 97 | 98 | return output 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 104 | 105 | import numpy as np 106 | from PIL import Image 107 | imgSize = (256,256) 108 | network = SteerablePyramid(imgSize=imgSize, K=8, N=4, device=device) 109 | # x = 
torch.rand((1,1,imgSize,imgSize),requires_grad=True, device=device) 110 | 111 | x = torch.from_numpy(np.array(Image.open('images/r1.png').convert("L"))).float().unsqueeze(0).unsqueeze(0) 112 | # x = x.permute(2,0,1).unsqueeze(0)#.unsqueeze(0) 113 | x = (x).to(device)#.repeat(4,1,1,1) 114 | x = x[:,:,:,:256] 115 | x.requires_grad_(True) 116 | 117 | y = network(x) 118 | c = y[1][0][0][0][0] 119 | c0 = y[1][0][1][0][0] 120 | c1 = y[2][0][0][0][0] 121 | c2 = y[2][0][1][0][0] 122 | c3 = y[3][0][0][0][0] 123 | c4 = y[3][0][1][0][0] 124 | c5 = y[4][0][0][0][0] 125 | c6 = y[4][0][1][0][0] 126 | c7 = y[5][0][0][0][0] 127 | c = 0 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /IQA_pytorch/SteerPyrSpace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from torchvision import transforms 7 | from .utils import fspecial_gauss 8 | from .SteerPyrUtils import sp5_filters 9 | 10 | def corrDn(image, filt, step=1, channels=1): 11 | 12 | filt_ = torch.from_numpy(filt).float().unsqueeze(0).unsqueeze(0).repeat(channels,1,1,1).to(image.device) 13 | p = (filt_.shape[2]-1)//2 14 | image = F.pad(image, (p,p,p,p),'reflect') 15 | img = F.conv2d(image, filt_, stride=step, padding=0, groups = channels) 16 | return img 17 | 18 | def SteerablePyramidSpace(image, height=4, order=5, channels=1): 19 | num_orientations = order + 1 20 | filters = sp5_filters() 21 | 22 | hi0 = corrDn(image, filters['hi0filt'], step=1, channels=channels) 23 | pyr_coeffs = [] 24 | pyr_coeffs.append(hi0) 25 | lo = corrDn(image, filters['lo0filt'], step=1, channels=channels) 26 | for _ in range(height): 27 | bfiltsz = int(np.floor(np.sqrt(filters['bfilts'].shape[0]))) 28 | for b in range(num_orientations): 29 | filt = filters['bfilts'][:, b].reshape(bfiltsz, bfiltsz).T 30 | band = corrDn(lo, filt, step=1, channels=channels) 31 | pyr_coeffs.append(band) 32 | lo = corrDn(lo, filters['lofilt'], step=2, channels=channels) 33 | 34 | pyr_coeffs.append(lo) 35 | return pyr_coeffs 36 | 37 | 38 | if __name__ == '__main__': 39 | from PIL import Image 40 | import argparse 41 | from utils import prepare_image 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--ref', type=str, default='images/r0.png') 45 | parser.add_argument('--dist', type=str, default='images/r1.png') 46 | args = parser.parse_args() 47 | 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | dist = prepare_image(Image.open(args.dist).convert("L"),repeatNum=1).to(device) 51 | x = SteerablePyramidSpace(dist*255,channels=1) 52 | c = 0 53 | 54 | -------------------------------------------------------------------------------- /IQA_pytorch/SteerPyrUtils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | def L( r ): 8 | if r <= math.pi / 4: 9 | return 2 10 | elif r >= math.pi / 2: 11 | return 0 12 | else: 13 | return 2 * math.cos( math.pi / 2 * math.log( 4 * r / math.pi ) / math.log( 2 ) ) 14 | 15 | def H( r ): 16 | if r <= math.pi / 4: 17 | return 0 18 | elif r >= math.pi / 2: 19 | return 1 20 | else: 21 | return math.cos( math.pi / 2 * math.log( 2 * r / math.pi ) / math.log( 2 ) ) 22 | 23 | def G( t, k, K ): 24 | 25 | t0 = math.pi * k / K 26 | aK = 2**(K-1) * math.factorial(K-1) / 
math.sqrt( K * math.factorial( 2 * (K-1) ) ) 27 | 28 | if (t - t0) > (math.pi/2): 29 | return G( t - math.pi, k, K ) 30 | elif (t - t0 ) < (-math.pi/2): 31 | return G( t + math.pi, k, K ) 32 | else: 33 | return aK * (math.cos( t - t0 ))**(K-1) 34 | 35 | def S( t, k, K ): 36 | 37 | t0 = math.pi * k / K 38 | dt = abs(t-t0) 39 | 40 | if dt < math.pi/2: 41 | return 1 42 | elif dt == math.pi/2: 43 | return 0 44 | else: 45 | return -1 46 | 47 | def L0( r ): 48 | return L( r/2 ) / 2 49 | 50 | def H0( r ): 51 | return H( r/2 ) 52 | 53 | def polar_map( s ): 54 | 55 | x = torch.linspace( 0, math.pi, s[1] ).view( 1, s[1] ).expand( s ) 56 | if s[0] % 2 == 0 : 57 | y = torch.linspace( -math.pi, math.pi, s[0]+1 ).narrow(0,1,s[0]) 58 | else: 59 | y = torch.linspace( -math.pi, math.pi, s[0] ) 60 | y = y.view( s[0], 1 ).expand( s ).mul( -1 ) 61 | 62 | r = ( x**2 + y**2 ).sqrt() 63 | t = torch.atan2( y, x ) 64 | 65 | return r, t 66 | 67 | def S_matrix( K, s ): 68 | 69 | _, t = polar_map( s ) 70 | sm = torch.Tensor( K, s[0], s[1] ) 71 | 72 | for k in range( K ): 73 | for i in range( s[0] ): 74 | for j in range( s[1] ): 75 | sm[k][i][j] = S( t[i][j], k, K ) 76 | 77 | return sm 78 | 79 | def G_matrix( K, s ): 80 | 81 | _, t = polar_map( s ) 82 | g = torch.Tensor( K, s[0], s[1] ) 83 | 84 | for k in range( K ): 85 | for i in range( s[0] ): 86 | for j in range( s[1] ): 87 | g[k][i][j] = G( t[i][j], k, K ) 88 | 89 | return g 90 | 91 | def B_matrix( K, s ): 92 | 93 | g = G_matrix( K, s ) 94 | 95 | r, _ = polar_map( s ) 96 | h = r.apply_( H ).unsqueeze(0) 97 | 98 | return h * g 99 | 100 | def L_matrix( s ): 101 | 102 | r, _ = polar_map( s ) 103 | 104 | return r.apply_( L ) 105 | 106 | def LB_matrix( K, s ): 107 | 108 | l = L_matrix( s ).unsqueeze(0) 109 | b = B_matrix( K, s ) 110 | 111 | return torch.cat( (l,b), 0 ) 112 | 113 | def HL0_matrix( s ): 114 | 115 | r, _ = polar_map( s ) 116 | h = r.clone().apply_( H0 ).view( 1, s[0], s[1] ) 117 | l = r.clone().apply_( L0 ).view( 1, s[0], s[1] ) 118 | 119 | return torch.cat( ( h, l ), 0 ) 120 | 121 | def central_crop( x ): 122 | 123 | ns = [ x.size(-2)//2 , x.size(-1)//2 + 1 ] 124 | 125 | return x.narrow( -2, ns[1]-1, ns[0] ).narrow( -1, 0, ns[1] ) 126 | 127 | def cropped_size( s ): 128 | 129 | return [ s[0]//2 , s[1]//2 + 1 ] 130 | 131 | def L_matrix_cropped( s ): 132 | 133 | l = L_matrix( s ) 134 | 135 | ns = cropped_size( s ) 136 | 137 | return l.narrow( 0, ns[1]-1, ns[0] ).narrow( 1, 0, ns[1] ) 138 | 139 | def freq_shift( imgSize, fwd, device ): 140 | ind = torch.LongTensor( imgSize ).to(device) 141 | sgn = 1 142 | if fwd: 143 | sgn = -1 144 | for i in range( imgSize ): 145 | ind[i] = (i + sgn*((imgSize-1)//2) ) % imgSize 146 | 147 | return Variable( ind ) 148 | 149 | 150 | ########## 151 | def sp5_filters(): 152 | filters = {} 153 | filters['harmonics'] = np.array([1, 3, 5]) 154 | filters['mtx'] = ( 155 | np.array([[0.3333, 0.2887, 0.1667, 0.0000, -0.1667, -0.2887], 156 | [0.0000, 0.1667, 0.2887, 0.3333, 0.2887, 0.1667], 157 | [0.3333, -0.0000, -0.3333, -0.0000, 0.3333, -0.0000], 158 | [0.0000, 0.3333, 0.0000, -0.3333, 0.0000, 0.3333], 159 | [0.3333, -0.2887, 0.1667, -0.0000, -0.1667, 0.2887], 160 | [-0.0000, 0.1667, -0.2887, 0.3333, -0.2887, 0.1667]])) 161 | filters['hi0filt'] = ( 162 | np.array([[-0.00033429, -0.00113093, -0.00171484, 163 | -0.00133542, -0.00080639, -0.00133542, 164 | -0.00171484, -0.00113093, -0.00033429], 165 | [-0.00113093, -0.00350017, -0.00243812, 166 | 0.00631653, 0.01261227, 0.00631653, 167 | -0.00243812, -0.00350017, -0.00113093], 168 | 
[-0.00171484, -0.00243812, -0.00290081, 169 | -0.00673482, -0.00981051, -0.00673482, 170 | -0.00290081, -0.00243812, -0.00171484], 171 | [-0.00133542, 0.00631653, -0.00673482, 172 | -0.07027679, -0.11435863, -0.07027679, 173 | -0.00673482, 0.00631653, -0.00133542], 174 | [-0.00080639, 0.01261227, -0.00981051, 175 | -0.11435863, 0.81380200, -0.11435863, 176 | -0.00981051, 0.01261227, -0.00080639], 177 | [-0.00133542, 0.00631653, -0.00673482, 178 | -0.07027679, -0.11435863, -0.07027679, 179 | -0.00673482, 0.00631653, -0.00133542], 180 | [-0.00171484, -0.00243812, -0.00290081, 181 | -0.00673482, -0.00981051, -0.00673482, 182 | -0.00290081, -0.00243812, -0.00171484], 183 | [-0.00113093, -0.00350017, -0.00243812, 184 | 0.00631653, 0.01261227, 0.00631653, 185 | -0.00243812, -0.00350017, -0.00113093], 186 | [-0.00033429, -0.00113093, -0.00171484, 187 | -0.00133542, -0.00080639, -0.00133542, 188 | -0.00171484, -0.00113093, -0.00033429]])) 189 | filters['lo0filt'] = ( 190 | np.array([[0.00341614, -0.01551246, -0.03848215, -0.01551246, 191 | 0.00341614], 192 | [-0.01551246, 0.05586982, 0.15925570, 0.05586982, 193 | -0.01551246], 194 | [-0.03848215, 0.15925570, 0.40304148, 0.15925570, 195 | -0.03848215], 196 | [-0.01551246, 0.05586982, 0.15925570, 0.05586982, 197 | -0.01551246], 198 | [0.00341614, -0.01551246, -0.03848215, -0.01551246, 199 | 0.00341614]])) 200 | filters['lofilt'] = ( 201 | 2 * np.array([[0.00085404, -0.00244917, -0.00387812, -0.00944432, 202 | -0.00962054, -0.00944432, -0.00387812, -0.00244917, 203 | 0.00085404], 204 | [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 205 | 0.01002988, 0.00410600, -0.00661117, -0.00523281, 206 | -0.00244917], 207 | [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 208 | 0.03981393, 0.03277038, 0.01396746, -0.00661117, 209 | -0.00387812], 210 | [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 211 | 0.08169618, 0.06426333, 0.03277038, 0.00410600, 212 | -0.00944432], 213 | [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 214 | 0.10096540, 0.08169618, 0.03981393, 0.01002988, 215 | -0.00962054], 216 | [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 217 | 0.08169618, 0.06426333, 0.03277038, 0.00410600, 218 | -0.00944432], 219 | [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 220 | 0.03981393, 0.03277038, 0.01396746, -0.00661117, 221 | -0.00387812], 222 | [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 223 | 0.01002988, 0.00410600, -0.00661117, -0.00523281, 224 | -0.00244917], 225 | [0.00085404, -0.00244917, -0.00387812, -0.00944432, 226 | -0.00962054, -0.00944432, -0.00387812, -0.00244917, 227 | 0.00085404]])) 228 | filters['bfilts'] = ( 229 | np.array([[0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 230 | 0.00496194, 0.00277643, -0.00986904, -0.00893064, 231 | 0.01189859, 0.02755155, 0.01189859, -0.00893064, 232 | -0.00986904, -0.01021852, -0.03075356, -0.08226445, 233 | -0.11732297, -0.08226445, -0.03075356, -0.01021852, 234 | 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 235 | 0.00000000, 0.00000000, 0.01021852, 0.03075356, 0.08226445, 236 | 0.11732297, 0.08226445, 0.03075356, 0.01021852, 0.00986904, 237 | 0.00893064, -0.01189859, -0.02755155, -0.01189859, 238 | 0.00893064, 0.00986904, -0.00277643, -0.00496194, 239 | -0.01026699, -0.01455399, -0.01026699, -0.00496194, 240 | -0.00277643], 241 | [-0.00343249, -0.00640815, -0.00073141, 0.01124321, 242 | 0.00182078, 0.00285723, 0.01166982, -0.00358461, 243 | -0.01977507, -0.04084211, -0.00228219, 0.03930573, 244 | 0.01161195, 0.00128000, 0.01047717, 0.01486305, 245 | 
-0.04819057, -0.12227230, -0.05394139, 0.00853965, 246 | -0.00459034, 0.00790407, 0.04435647, 0.09454202, 247 | -0.00000000, -0.09454202, -0.04435647, -0.00790407, 248 | 0.00459034, -0.00853965, 0.05394139, 0.12227230, 249 | 0.04819057, -0.01486305, -0.01047717, -0.00128000, 250 | -0.01161195, -0.03930573, 0.00228219, 0.04084211, 251 | 0.01977507, 0.00358461, -0.01166982, -0.00285723, 252 | -0.00182078, -0.01124321, 0.00073141, 0.00640815, 253 | 0.00343249], 254 | [0.00343249, 0.00358461, -0.01047717, -0.00790407, 255 | -0.00459034, 0.00128000, 0.01166982, 0.00640815, 256 | 0.01977507, -0.01486305, -0.04435647, 0.00853965, 257 | 0.01161195, 0.00285723, 0.00073141, 0.04084211, 0.04819057, 258 | -0.09454202, -0.05394139, 0.03930573, 0.00182078, 259 | -0.01124321, 0.00228219, 0.12227230, -0.00000000, 260 | -0.12227230, -0.00228219, 0.01124321, -0.00182078, 261 | -0.03930573, 0.05394139, 0.09454202, -0.04819057, 262 | -0.04084211, -0.00073141, -0.00285723, -0.01161195, 263 | -0.00853965, 0.04435647, 0.01486305, -0.01977507, 264 | -0.00640815, -0.01166982, -0.00128000, 0.00459034, 265 | 0.00790407, 0.01047717, -0.00358461, -0.00343249], 266 | [-0.00277643, 0.00986904, 0.01021852, -0.00000000, 267 | -0.01021852, -0.00986904, 0.00277643, -0.00496194, 268 | 0.00893064, 0.03075356, -0.00000000, -0.03075356, 269 | -0.00893064, 0.00496194, -0.01026699, -0.01189859, 270 | 0.08226445, -0.00000000, -0.08226445, 0.01189859, 271 | 0.01026699, -0.01455399, -0.02755155, 0.11732297, 272 | -0.00000000, -0.11732297, 0.02755155, 0.01455399, 273 | -0.01026699, -0.01189859, 0.08226445, -0.00000000, 274 | -0.08226445, 0.01189859, 0.01026699, -0.00496194, 275 | 0.00893064, 0.03075356, -0.00000000, -0.03075356, 276 | -0.00893064, 0.00496194, -0.00277643, 0.00986904, 277 | 0.01021852, -0.00000000, -0.01021852, -0.00986904, 278 | 0.00277643], 279 | [-0.01166982, -0.00128000, 0.00459034, 0.00790407, 280 | 0.01047717, -0.00358461, -0.00343249, -0.00285723, 281 | -0.01161195, -0.00853965, 0.04435647, 0.01486305, 282 | -0.01977507, -0.00640815, -0.00182078, -0.03930573, 283 | 0.05394139, 0.09454202, -0.04819057, -0.04084211, 284 | -0.00073141, -0.01124321, 0.00228219, 0.12227230, 285 | -0.00000000, -0.12227230, -0.00228219, 0.01124321, 286 | 0.00073141, 0.04084211, 0.04819057, -0.09454202, 287 | -0.05394139, 0.03930573, 0.00182078, 0.00640815, 288 | 0.01977507, -0.01486305, -0.04435647, 0.00853965, 289 | 0.01161195, 0.00285723, 0.00343249, 0.00358461, 290 | -0.01047717, -0.00790407, -0.00459034, 0.00128000, 291 | 0.01166982], 292 | [-0.01166982, -0.00285723, -0.00182078, -0.01124321, 293 | 0.00073141, 0.00640815, 0.00343249, -0.00128000, 294 | -0.01161195, -0.03930573, 0.00228219, 0.04084211, 295 | 0.01977507, 0.00358461, 0.00459034, -0.00853965, 296 | 0.05394139, 0.12227230, 0.04819057, -0.01486305, 297 | -0.01047717, 0.00790407, 0.04435647, 0.09454202, 298 | -0.00000000, -0.09454202, -0.04435647, -0.00790407, 299 | 0.01047717, 0.01486305, -0.04819057, -0.12227230, 300 | -0.05394139, 0.00853965, -0.00459034, -0.00358461, 301 | -0.01977507, -0.04084211, -0.00228219, 0.03930573, 302 | 0.01161195, 0.00128000, -0.00343249, -0.00640815, 303 | -0.00073141, 0.01124321, 0.00182078, 0.00285723, 304 | 0.01166982]]).T) 305 | return filters -------------------------------------------------------------------------------- /IQA_pytorch/VIF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from 
torchvision import transforms 6 | from .utils import fspecial_gauss 7 | from .SteerPyrSpace import SteerablePyramidSpace 8 | import math 9 | 10 | 11 | class VIF(torch.nn.Module): 12 | # Refer to https://live.ece.utexas.edu/research/Quality/VIF.htm 13 | 14 | def __init__(self, channels=3, level=4, ori=6, device = torch.device("cuda")): 15 | 16 | super(VIF, self).__init__() 17 | self.ori = ori-1 18 | self.level = level 19 | self.channels = channels 20 | self.M=3 21 | self.subbands=[4, 7, 10, 13, 16, 19, 22, 25] 22 | self.sigma_nsq=0.4 23 | self.tol = 1e-12 24 | 25 | def corrDn(self, image, filt, step=1, channels=1,start=[0,0],end=[0,0]): 26 | 27 | filt_ = torch.from_numpy(filt).float().unsqueeze(0).unsqueeze(0).repeat(channels,1,1,1).to(image.device) 28 | p = (filt_.shape[2]-1)//2 29 | image = F.pad(image, (p,p,p,p),'reflect') 30 | img = F.conv2d(image, filt_, stride=1, padding=0, groups = channels) 31 | img = img[:,:,start[0]:end[0]:step,start[1]:end[1]:step] 32 | return img 33 | 34 | def vifsub_est_M(self, org, dist): 35 | 36 | g_all = [] 37 | vv_all = [] 38 | for i in range(len(self.subbands)): 39 | sub=self.subbands[i]-1 40 | y=org[sub] 41 | yn=dist[sub] 42 | 43 | lev=np.ceil((sub-1)/6) 44 | winsize=int(2**lev+1) 45 | win = np.ones((winsize,winsize)) 46 | 47 | newsizeX=int(np.floor(y.shape[2]/self.M)*self.M) 48 | newsizeY=int(np.floor(y.shape[3]/self.M)*self.M) 49 | y=y[:,:,:newsizeX,:newsizeY] 50 | yn=yn[:,:,:newsizeX,:newsizeY] 51 | 52 | winstart=[int(1*np.floor(self.M/2)),int(1*np.floor(self.M/2))] 53 | winend=[int(y.shape[2]-np.ceil(self.M/2))+1,int(y.shape[3]-np.ceil(self.M/2))+1] 54 | 55 | mean_x = self.corrDn(y,win/(winsize**2),step=self.M, channels=self.channels,start=winstart,end=winend) 56 | mean_y = self.corrDn(yn,win/(winsize**2),step=self.M, channels=self.channels,start=winstart,end=winend) 57 | cov_xy = self.corrDn(y*yn, win, step=self.M, channels=self.channels,start=winstart,end=winend) - (winsize**2)*mean_x*mean_y 58 | ss_x = self.corrDn(y**2,win, step=self.M, channels=self.channels,start=winstart,end=winend) - (winsize**2)*mean_x**2 59 | ss_y = self.corrDn(yn**2,win, step=self.M, channels=self.channels,start=winstart,end=winend) - (winsize**2)*mean_y**2 60 | 61 | ss_x = F.relu(ss_x) 62 | ss_y = F.relu(ss_y) 63 | 64 | g = cov_xy/(ss_x+self.tol) 65 | vv = (ss_y - g*cov_xy)/(winsize**2) 66 | 67 | g = g.masked_fill(ss_x < self.tol,0) 68 | vv [ss_x < self.tol] = ss_y [ss_x < self.tol] 69 | ss_x = ss_x.masked_fill(ss_x < self.tol,0) 70 | 71 | g = g.masked_fill(ss_y < self.tol,0) 72 | vv = vv.masked_fill(ss_y < self.tol,0) 73 | 74 | vv[g<0]=ss_y[g<0] 75 | g = F.relu(g) 76 | 77 | vv = vv.masked_fill(vv < self.tol, self.tol) 78 | 79 | g_all.append(g) 80 | vv_all.append(vv) 81 | return g_all, vv_all 82 | 83 | def refparams_vecgsm(self, org): 84 | ssarr, l_arr, cu_arr = [], [], [] 85 | for i in range(len(self.subbands)): 86 | sub=self.subbands[i]-1 87 | y=org[sub] 88 | M = self.M 89 | newsizeX=int(np.floor(y.shape[2]/M)*M) 90 | newsizeY=int(np.floor(y.shape[3]/M)*M) 91 | y=y[:,:,:newsizeX,:newsizeY] 92 | B,C,H,W = y.shape 93 | 94 | temp=[] 95 | for j in range(M): 96 | for k in range(M): 97 | temp.append(y[:,:,k:H-(M-k)+1, j:W-(M-j)+1].reshape(B,C,-1)) 98 | temp = torch.stack(temp,dim=3) 99 | mcu = torch.mean(temp,dim=2).unsqueeze(2).repeat(1,1,temp.shape[2],1) 100 | cu=torch.matmul((temp-mcu).permute(0,1,3,2),temp-mcu)/temp.shape[2] 101 | 102 | temp=[] 103 | for j in range(M): 104 | for k in range(M): 105 | temp.append(y[:,:,k:H+1:M, j:W+1:M].reshape(B,C,-1)) 106 | temp = 
torch.stack(temp,dim=2) 107 | ss=torch.matmul(torch.pinverse(cu),temp) 108 | # ss = torch.matmul(torch.pinverse(cu),temp) 109 | ss=torch.sum(ss*temp,dim=2)/(M*M) 110 | ss=ss.reshape(B,C,H//M,W//M) 111 | v,_ = torch.symeig(cu,eigenvectors=True) 112 | l_arr.append(v) 113 | ssarr.append(ss) 114 | cu_arr.append(cu) 115 | 116 | return ssarr, l_arr, cu_arr 117 | 118 | def vif(self, x, y): 119 | sp_x = SteerablePyramidSpace(x, height=self.level, order=self.ori, channels=self.channels)[::-1] 120 | sp_y = SteerablePyramidSpace(y, height=self.level, order=self.ori, channels=self.channels)[::-1] 121 | g_all, vv_all = self.vifsub_est_M(sp_y, sp_x) 122 | ss_arr, l_arr, cu_arr = self.refparams_vecgsm(sp_y) 123 | num, den = [], [] 124 | 125 | for i in range(len(self.subbands)): 126 | sub=self.subbands[i] 127 | g=g_all[i] 128 | vv=vv_all[i] 129 | ss=ss_arr[i] 130 | lamda = l_arr[i] 131 | neigvals=lamda.shape[2] 132 | lev=np.ceil((sub-1)/6) 133 | winsize=2**lev+1 134 | offset=(winsize-1)/2 135 | offset=int(np.ceil(offset/self.M)) 136 | 137 | _,_,H,W = g.shape 138 | g= g[:,:,offset:H-offset,offset:W-offset] 139 | vv=vv[:,:,offset:H-offset,offset:W-offset] 140 | ss=ss[:,:,offset:H-offset,offset:W-offset] 141 | 142 | temp1=0 143 | temp2=0 144 | for j in range(neigvals): 145 | cc = lamda[:,:,j].unsqueeze(2).unsqueeze(3) 146 | temp1=temp1+torch.sum(torch.log2(1+g*g*ss*cc/(vv+self.sigma_nsq)),dim=[2,3]) 147 | temp2=temp2+torch.sum(torch.log2(1+ss*cc/(self.sigma_nsq)),dim=[2,3]) 148 | num.append(temp1.mean(1)) 149 | den.append(temp2.mean(1)) 150 | 151 | return torch.stack(num,dim=1).sum(1)/(torch.stack(den,dim=1).sum(1)+1e-12) 152 | 153 | def forward(self, y, x, as_loss=True): 154 | assert x.shape == y.shape 155 | x = x * 255 156 | y = y * 255 157 | if as_loss: 158 | score = self.vif(x, y) 159 | return 1 - score.mean() 160 | else: 161 | with torch.no_grad(): 162 | score = self.vif(x, y) 163 | return score 164 | 165 | if __name__ == '__main__': 166 | from PIL import Image 167 | import argparse 168 | from utils import prepare_image 169 | 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument('--ref', type=str, default='images/r0.png') 172 | parser.add_argument('--dist', type=str, default='images/r1.png') 173 | args = parser.parse_args() 174 | 175 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 176 | 177 | ref = prepare_image(Image.open(args.ref).convert("L"),repeatNum=1).to(device) 178 | dist = prepare_image(Image.open(args.dist).convert("L"),repeatNum=1).to(device) 179 | dist.requires_grad_(True) 180 | model = VIF(channels=1) 181 | 182 | score = model(dist, ref, as_loss=False) 183 | print('score: %.4f' % score.item()) 184 | # score: 0.1804 185 | -------------------------------------------------------------------------------- /IQA_pytorch/VIFs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from .utils import fspecial_gauss 6 | 7 | class VIFs(nn.Module): 8 | def __init__(self, channels=3): 9 | super(VIFs, self).__init__() 10 | '''spatial domain VIF 11 | https://live.ece.utexas.edu/research/Quality/VIF.htm 12 | ''' 13 | self.channels = channels 14 | self.eps = 1e-10 15 | 16 | def vif(self, img1, img2): 17 | num = 0 18 | den = 0 19 | sigma_nsq=2 20 | channels = self.channels 21 | eps = self.eps 22 | for scale in range(1,5): 23 | N = 2**(4-scale+1)+1 24 | win = fspecial_gauss(N,N/5,channels).to(img1.device) 25 | 26 | if scale > 1: 27 | 
img1 = F.conv2d(img1, win, padding =0, groups = channels) 28 | img2 = F.conv2d(img2, win, padding =0, groups = channels) 29 | img1 = img1[:,:,0::2,0::2] 30 | img2 = img2[:,:,0::2,0::2] 31 | 32 | mu1 = F.conv2d(img1, win, padding =0, groups = channels) 33 | mu2 = F.conv2d(img2, win, padding =0, groups = channels) 34 | mu1_sq = mu1*mu1 35 | mu2_sq = mu2*mu2 36 | mu1_mu2 = mu1*mu2 37 | sigma1_sq = F.conv2d(img1*img1, win, padding =0, groups = channels) - mu1_sq 38 | sigma2_sq = F.conv2d(img2*img2, win, padding =0, groups = channels) - mu2_sq 39 | sigma12 = F.conv2d(img1*img2, win, padding =0, groups = channels) - mu1_mu2 40 | 41 | sigma1_sq = F.relu(sigma1_sq) 42 | sigma2_sq = F.relu(sigma2_sq) 43 | 44 | g = sigma12/(sigma1_sq+eps) 45 | sv_sq = sigma2_sq-g*sigma12 46 | sigma1_sq = F.relu(sigma1_sq-eps) 47 | 48 | g = g.masked_fill(sigma2_sq 0.04045).type(torch.FloatTensor).to(device) 40 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask 41 | 42 | rgb_to_xyz = torch.tensor([ 43 | # X Y Z 44 | [0.412453, 0.212671, 0.019334], # R 45 | [0.357580, 0.715160, 0.119193], # G 46 | [0.180423, 0.072169, 0.950227], # B 47 | ]).type(torch.FloatTensor).to(device)#.unsqueeze(0).repeat(B,1,1) 48 | 49 | xyz_pixels = torch.matmul(rgb_pixels, rgb_to_xyz) 50 | 51 | # XYZ to Lab 52 | t = torch.tensor([1/0.950456, 1.0, 1/1.088754]).type(torch.FloatTensor).to(device) 53 | xyz_normalized_pixels = (xyz_pixels / t) 54 | 55 | epsilon = 6.0/29.0 56 | linear_mask = (xyz_normalized_pixels <= (epsilon**3)).type(torch.FloatTensor).to(device) 57 | exponential_mask = (xyz_normalized_pixels > (epsilon**3)).type(torch.FloatTensor).to(device) 58 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4.0/29.0) * linear_mask + ((xyz_normalized_pixels).clamp_(min=eps).pow(1/3)) * exponential_mask 59 | # convert to lab 60 | fxfyfz_to_lab = torch.tensor([ 61 | # l a b 62 | [ 0.0, 500.0, 0.0], # fx 63 | [116.0, -500.0, 200.0], # fy 64 | [ 0.0, 0.0, -200.0], # fz 65 | ]).type(torch.FloatTensor).to(device)#.unsqueeze(0).repeat(B,1,1) 66 | lab_pixels = torch.matmul(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor([-16.0, 0.0, 0.0]).type(torch.FloatTensor).to(device) 67 | 68 | return torch.reshape(lab_pixels.permute(0,2,1), srgb.shape) 69 | 70 | def SDSP(img,sigmaF,omega0,sigmaD,sigmaC): 71 | B, C, rows, cols = img.shape 72 | 73 | lab = rgb_to_lab_NCHW(img/255) 74 | LChannel, AChannel, BChannel = lab[:,0,:,:].unsqueeze(1),lab[:,1,:,:].unsqueeze(1),lab[:,2,:,:].unsqueeze(1) 75 | LFFT = torch.rfft(LChannel,2,onesided=False) 76 | AFFT = torch.rfft(AChannel,2,onesided=False) 77 | BFFT = torch.rfft(BChannel,2,onesided=False) 78 | 79 | LG = logGabor(rows,cols,omega0,sigmaF) 80 | LG = torch.from_numpy(LG).reshape(1, 1, rows,cols,1).repeat(B,1,1,1,2).float().to(img.device) 81 | 82 | FinalLResult = real(torch.ifft(LFFT*LG,2)) 83 | FinalAResult = real(torch.ifft(AFFT*LG,2)) 84 | FinalBResult = real(torch.ifft(BFFT*LG,2)) 85 | 86 | SFMap = torch.sqrt(FinalLResult**2 + FinalAResult**2 + FinalBResult**2+eps) 87 | 88 | coordinateMtx = torch.from_numpy(np.arange(0,rows)).float().reshape(1,1,rows,1).repeat(B,1,1,cols).to(img.device) 89 | centerMtx = torch.ones_like(coordinateMtx)*rows/2 90 | coordinateMty = torch.from_numpy(np.arange(0,cols)).float().reshape(1,1,1,cols).repeat(B,1,rows,1).to(img.device) 91 | centerMty = torch.ones_like(coordinateMty)*cols/2 92 | SDMap = torch.exp(-((coordinateMtx - centerMtx)**2+(coordinateMty - centerMty)**2)/(sigmaD**2)) 93 | 94 | normalizedA = 
spatial_normalize(AChannel) 95 | 96 | normalizedB = spatial_normalize(BChannel) 97 | 98 | labDistSquare = normalizedA**2 + normalizedB**2 99 | SCMap = 1 - torch.exp(-labDistSquare / (sigmaC**2)) 100 | VSMap = SFMap * SDMap * SCMap 101 | 102 | normalizedVSMap = spatial_normalize(VSMap) 103 | return normalizedVSMap 104 | 105 | def vsi(image1,image2): 106 | 107 | constForVS = 1.27 108 | constForGM = 386 109 | constForChrom = 130 110 | alpha = 0.40 111 | lamda = 0.020 112 | sigmaF = 1.34 113 | omega0 = 0.0210 114 | sigmaD = 145 115 | sigmaC = 0.001 116 | 117 | saliencyMap1 = SDSP(image1,sigmaF,omega0,sigmaD,sigmaC) 118 | saliencyMap2 = SDSP(image2,sigmaF,omega0,sigmaD,sigmaC) 119 | 120 | L1 = (0.06 * image1[:,0,:,:] + 0.63 * image1[:,1,:,:] + 0.27 * image1[:,2,:,:]).unsqueeze(1) 121 | L2 = (0.06 * image2[:,0,:,:] + 0.63 * image2[:,1,:,:] + 0.27 * image2[:,2,:,:]).unsqueeze(1) 122 | M1 = (0.30 * image1[:,0,:,:] + 0.04 * image1[:,1,:,:] - 0.35 * image1[:,2,:,:]).unsqueeze(1) 123 | M2 = (0.30 * image2[:,0,:,:] + 0.04 * image2[:,1,:,:] - 0.35 * image2[:,2,:,:]).unsqueeze(1) 124 | N1 = (0.34 * image1[:,0,:,:] - 0.60 * image1[:,1,:,:] + 0.17 * image1[:,2,:,:]).unsqueeze(1) 125 | N2 = (0.34 * image2[:,0,:,:] - 0.60 * image2[:,1,:,:] + 0.17 * image2[:,2,:,:]).unsqueeze(1) 126 | 127 | L1, L2 = downsample(L1, L2) 128 | M1, M2 = downsample(M1, M2) 129 | N1, N2 = downsample(N1, N2) 130 | saliencyMap1, saliencyMap2 = downsample(saliencyMap1, saliencyMap2) 131 | 132 | dx = torch.Tensor([[3, 0, -3], [10, 0, -10], [3, 0, -3]]).float()/16 133 | dy = torch.Tensor([[3, 10, 3], [0, 0, 0], [-3, -10, -3]]).float()/16 134 | dx = dx.reshape(1,1,3,3).to(image1.device) 135 | dy = dy.reshape(1,1,3,3).to(image1.device) 136 | IxY1 = F.conv2d(L1, dx, stride=1, padding =1) 137 | IyY1 = F.conv2d(L1, dy, stride=1, padding =1) 138 | gradientMap1 = torch.sqrt(IxY1**2 + IyY1**2+eps) 139 | IxY2 = F.conv2d(L2, dx, stride=1, padding =1) 140 | IyY2 = F.conv2d(L2, dy, stride=1, padding =1) 141 | gradientMap2 = torch.sqrt(IxY2**2 + IyY2**2+eps) 142 | 143 | 144 | VSSimMatrix = (2 * saliencyMap1 * saliencyMap2 + constForVS) / (saliencyMap1**2 + saliencyMap2**2 + constForVS) 145 | gradientSimMatrix = (2*gradientMap1*gradientMap2 + constForGM) /(gradientMap1**2 + gradientMap2**2 + constForGM) 146 | 147 | weight = torch.max(saliencyMap1, saliencyMap2) 148 | 149 | ISimMatrix = (2 * M1 * M2 + constForChrom) / (M1**2 + M2**2 + constForChrom) 150 | QSimMatrix = (2 * N1 * N2 + constForChrom) / (N1**2 + N2**2 + constForChrom) 151 | 152 | # SimMatrixC = (torch.sign(gradientSimMatrix) * (torch.abs(gradientSimMatrix)+eps) ** alpha) * VSSimMatrix * \ 153 | # (torch.sign(ISimMatrix * QSimMatrix)*(torch.abs(ISimMatrix * QSimMatrix)+eps) ** lamda) * weight 154 | SimMatrixC = ((torch.abs(gradientSimMatrix)+eps) ** alpha) * VSSimMatrix * \ 155 | ((torch.abs(ISimMatrix * QSimMatrix)+eps) ** lamda) * weight 156 | 157 | return torch.sum(SimMatrixC,dim=[1,2,3]) / torch.sum(weight,dim=[1,2,3]) 158 | 159 | class VSI(torch.nn.Module): 160 | # Refer to https://sse.tongji.edu.cn/linzhang/IQA/VSI/VSI.htm 161 | 162 | def __init__(self, channels=3): 163 | super(VSI, self).__init__() 164 | assert channels == 3 165 | 166 | def forward(self, y, x, as_loss=True): 167 | assert x.shape == y.shape 168 | x = x * 255 169 | y = y * 255 170 | if as_loss: 171 | score = vsi(x, y) 172 | return 1 - score.mean() 173 | else: 174 | with torch.no_grad(): 175 | score = vsi(x, y) 176 | return score 177 | 178 | 179 | if __name__ == '__main__': 180 | from PIL import Image 181 | import 
argparse
182 | from utils import prepare_image
183 | 
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument('--ref', type=str, default='images/r0.png')
186 | parser.add_argument('--dist', type=str, default='images/r1.png')
187 | args = parser.parse_args()
188 | 
189 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
190 | 
191 | ref = prepare_image(Image.open(args.ref).convert("RGB")).to(device)
192 | dist = prepare_image(Image.open(args.dist).convert("RGB")).to(device)
193 | 
194 | model = VSI().to(device)
195 | score = model(dist, ref, as_loss=False)
196 | print('score: %.4f' % score.item())
197 | # score: 0.9322
198 | 
199 | 
-------------------------------------------------------------------------------- /IQA_pytorch/__init__.py: --------------------------------------------------------------------------------
1 | from .SSIM import SSIM
2 | from .MS_SSIM import MS_SSIM
3 | from .CW_SSIM import CW_SSIM
4 | from .GMSD import GMSD
5 | from .NLPD import NLPD
6 | from .FSIM import FSIM
7 | from .VSI import VSI
8 | from .VIF import VIF
9 | from .VIFs import VIFs
10 | from .MAD import MAD
11 | from .LPIPSvgg import LPIPSvgg
12 | from .DISTS import DISTS
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/CW_SSIM.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/CW_SSIM.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/DISTS.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/DISTS.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/MAD.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/MAD.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/SteerPyrComplex.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/SteerPyrComplex.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/SteerPyrSpace.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/SteerPyrSpace.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/SteerPyrUtils.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/SteerPyrUtils.cpython-37.pyc
-------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/VIF.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/VIF.cpython-37.pyc -------------------------------------------------------------------------------- /IQA_pytorch/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /IQA_pytorch/images/r0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/images/r0.png -------------------------------------------------------------------------------- /IQA_pytorch/images/r1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/images/r1.png -------------------------------------------------------------------------------- /IQA_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | def abs(x): 10 | return torch.sqrt(x[:,:,:,:,0]**2+x[:,:,:,:,1]**2+1e-12) 11 | 12 | def real(x): 13 | return x[:,:,:,:,0] 14 | 15 | def imag(x): 16 | return x[:,:,:,:,1] 17 | 18 | def roll_n(X, axis, n): 19 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) 20 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) 21 | front = X[f_idx] 22 | back = X[b_idx] 23 | return torch.cat([back, front], axis) 24 | 25 | def batch_fftshift2d(x): 26 | real, imag = torch.unbind(x, -1) 27 | for dim in range(1, len(real.size())): 28 | n_shift = real.size(dim)//2 29 | if real.size(dim) % 2 != 0: 30 | n_shift += 1 # for odd-sized images 31 | real = roll_n(real, axis=dim, n=n_shift) 32 | imag = roll_n(imag, axis=dim, n=n_shift) 33 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 34 | 35 | def batch_ifftshift2d(x): 36 | real, imag = torch.unbind(x, -1) 37 | for dim in range(len(real.size()) - 1, 0, -1): 38 | real = roll_n(real, axis=dim, n=real.size(dim)//2) 39 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) 40 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 41 | 42 | def preprocess_lab(lab): 43 | L_chan, a_chan, b_chan =torch.unbind(lab,dim=2) 44 | # L_chan: black and white with input range [0, 100] 45 | # a_chan/b_chan: color channels with input range ~[-110, 110], not exact 46 | # [0, 100] => [-1, 1], ~[-110, 110] => [-1, 1] 47 | return [L_chan / 50.0 - 1.0, a_chan / 110.0, b_chan / 110.0] 48 | 49 | def deprocess_lab(L_chan, a_chan, b_chan): 50 | #TODO This is axis=3 instead of axis=2 when deprocessing batch of images 51 | # ( we process individual images but deprocess batches) 52 | #return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3) 53 | return torch.stack([(L_chan + 1) / 2.0 * 100.0, a_chan * 110.0, b_chan * 110.0], dim=2) 54 | 55 | def rgb_to_lab(srgb): 56 | srgb = srgb/255 57 | srgb_pixels = 
torch.reshape(srgb, [-1, 3])
58 | linear_mask = (srgb_pixels <= 0.04045).type(torch.FloatTensor).to(srgb.device) # take the device from the input; utils.py defines no global `device`
59 | exponential_mask = (srgb_pixels > 0.04045).type(torch.FloatTensor).to(srgb.device)
60 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
61 | 
62 | rgb_to_xyz = torch.tensor([
63 | # X Y Z
64 | [0.412453, 0.212671, 0.019334], # R
65 | [0.357580, 0.715160, 0.119193], # G
66 | [0.180423, 0.072169, 0.950227], # B
67 | ]).type(torch.FloatTensor).to(srgb.device)
68 | 
69 | xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)
70 | 
71 | 
72 | # XYZ to Lab
73 | xyz_normalized_pixels = torch.mul(xyz_pixels, torch.tensor([1/0.950456, 1.0, 1/1.088754]).type(torch.FloatTensor).to(srgb.device))
74 | 
75 | epsilon = 6.0/29.0
76 | linear_mask = (xyz_normalized_pixels <= (epsilon**3)).type(torch.FloatTensor).to(srgb.device)
77 | exponential_mask = (xyz_normalized_pixels > (epsilon**3)).type(torch.FloatTensor).to(srgb.device)
78 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4.0/29.0) * linear_mask + ((xyz_normalized_pixels+0.000001) ** (1.0/3.0)) * exponential_mask
79 | # convert to lab
80 | fxfyfz_to_lab = torch.tensor([
81 | # l a b
82 | [ 0.0, 500.0, 0.0], # fx
83 | [116.0, -500.0, 200.0], # fy
84 | [ 0.0, 0.0, -200.0], # fz
85 | ]).type(torch.FloatTensor).to(srgb.device)
86 | lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor([-16.0, 0.0, 0.0]).type(torch.FloatTensor).to(srgb.device)
87 | #return tf.reshape(lab_pixels, tf.shape(srgb))
88 | return torch.reshape(lab_pixels, srgb.shape)
89 | 
90 | def lab_to_rgb(lab):
91 | lab_pixels = torch.reshape(lab, [-1, 3])
92 | # convert to fxfyfz
93 | lab_to_fxfyfz = torch.tensor([
94 | # fx fy fz
95 | [1/116.0, 1/116.0, 1/116.0], # l
96 | [1/500.0, 0.0, 0.0], # a
97 | [ 0.0, 0.0, -1/200.0], # b
98 | ]).type(torch.FloatTensor).to(lab.device) # likewise, derive the device from the input tensor
99 | fxfyfz_pixels = torch.mm(lab_pixels + torch.tensor([16.0, 0.0, 0.0]).type(torch.FloatTensor).to(lab.device), lab_to_fxfyfz)
100 | 
101 | # convert to xyz
102 | epsilon = 6.0/29.0
103 | linear_mask = (fxfyfz_pixels <= epsilon).type(torch.FloatTensor).to(lab.device)
104 | exponential_mask = (fxfyfz_pixels > epsilon).type(torch.FloatTensor).to(lab.device)
105 | 
106 | 
107 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29.0)) * linear_mask + ((fxfyfz_pixels+0.000001) ** 3) * exponential_mask
108 | 
109 | # denormalize for D65 white point
110 | xyz_pixels = torch.mul(xyz_pixels, torch.tensor([0.950456, 1.0, 1.088754]).type(torch.FloatTensor).to(lab.device))
111 | 
112 | 
113 | xyz_to_rgb = torch.tensor([
114 | # r g b
115 | [ 3.2404542, -0.9692660, 0.0556434], # x
116 | [-1.5371385, 1.8760108, -0.2040259], # y
117 | [-0.4985314, 0.0415560, 1.0572252], # z
118 | ]).type(torch.FloatTensor).to(lab.device)
119 | 
120 | rgb_pixels = torch.mm(xyz_pixels, xyz_to_rgb)
121 | # avoid a slightly negative number messing up the conversion
122 | #clip
123 | rgb_pixels[rgb_pixels > 1] = 1
124 | rgb_pixels[rgb_pixels < 0] = 0
125 | 
126 | linear_mask = (rgb_pixels <= 0.0031308).type(torch.FloatTensor).to(lab.device)
127 | exponential_mask = (rgb_pixels > 0.0031308).type(torch.FloatTensor).to(lab.device)
128 | srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + (((rgb_pixels+0.000001) ** (1/2.4) * 1.055) - 0.055) * exponential_mask
129 | 
130 | return torch.reshape(srgb_pixels, lab.shape)
131 | 
132 | def spatial_normalize(x):
133 | min_v = torch.min(x.view(x.shape[0],1,-1),dim=2)[0]
134 | range_v = torch.max(x.view(x.shape[0],1,-1),dim=2)[0] - min_v
135 | return (x - min_v.unsqueeze(2).unsqueeze(3)) /
(range_v.unsqueeze(2).unsqueeze(3)+1e-12) 136 | 137 | def fspecial_gauss(size, sigma, channels): 138 | # Function to mimic the 'fspecial' gaussian MATLAB function 139 | x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 140 | g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) 141 | g = torch.from_numpy(g/g.sum()).float().unsqueeze(0).unsqueeze(0) 142 | return g.repeat(channels,1,1,1) 143 | 144 | def downsample(img1, img2, maxSize = 256): 145 | _,channels,H,W = img1.shape 146 | f = int(max(1,np.round(min(H,W)/maxSize))) 147 | if f>1: 148 | aveKernel = (torch.ones(channels,1,f,f)/f**2).to(img1.device) 149 | img1 = F.conv2d(img1, aveKernel, stride=f, padding = 0, groups = channels) 150 | img2 = F.conv2d(img2, aveKernel, stride=f, padding = 0, groups = channels) 151 | return img1, img2 152 | 153 | def extract_patches_2d(img, patch_shape=[64, 64], step=[27,27], batch_first=True, keep_last_patch=False): 154 | patch_H, patch_W = patch_shape[0], patch_shape[1] 155 | if(img.size(2)256: 182 | image = transforms.functional.resize(image,256) 183 | image = transforms.ToTensor()(image) 184 | return image.unsqueeze(0).repeat(repeatNum,1,1,1) 185 | 186 | def print_network(net): 187 | num_params = 0 188 | for param in net.parameters(): 189 | num_params += param.numel() 190 | print(net) 191 | print('Total number of parameters: %d' % num_params) 192 | 193 | -------------------------------------------------------------------------------- /IQA_pytorch/weights/DISTS.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/weights/DISTS.pt -------------------------------------------------------------------------------- /IQA_pytorch/weights/LPIPSvgg.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/IQA_pytorch/weights/LPIPSvgg.pt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 dingkeyan93 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | -- IQA_pytorch
2 | -- IQA_pytorch
3 | -- __init__.py
4 | -- SSIM.py
5 | -- MS_SSIM.py
6 | -- CW_SSIM.py
7 | -- FSIM.py
8 | -- VSI.py
9 | -- VIF.py
10 | -- VIFs.py
11 | -- NLPD.py
12 | -- GMSD.py
13 | -- MAD.py
14 | -- LPIPSvgg.py
15 | -- DISTS.py
16 | -- weights
17 | --DISTS.pt
18 | --LPIPSvgg.pt
19 | -- README.md
20 | -- requirements.txt
21 | -- LICENSE
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Perceptual Optimization of Image Quality Assessment (IQA) Models
2 | 
3 | This repository re-implements existing IQA models in PyTorch, including
4 | - [SSIM](https://www.cns.nyu.edu/~lcv/ssim/), [MS-SSIM](https://ece.uwaterloo.ca/~z70wang/publications/msssim.html), [CW-SSIM](https://www.mathworks.com/matlabcentral/fileexchange/43017-complex-wavelet-structural-similarity-index-cw-ssim),
5 | - [FSIM](https://sse.tongji.edu.cn/linzhang/IQA/FSIM/FSIM.htm), [VSI](https://sse.tongji.edu.cn/linzhang/IQA/VSI/VSI.htm), [GMSD](https://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm),
6 | - [NLPD](https://www.cns.nyu.edu/~lcv/NLPyr/), [MAD](http://vision.eng.shizuoka.ac.jp/mod/url/view.php?id=54),
7 | - [VIF](https://live.ece.utexas.edu/research/Quality/VIF.htm),
8 | - [LPIPS](https://github.com/richzhang/PerceptualSimilarity), [DISTS](https://github.com/dingkeyan93/DISTS).
9 | 
10 | **Note:** The reproduced results may differ slightly from those of the original MATLAB implementations.
11 | 
12 | #### Installation:
13 | - ```pip install IQA_pytorch```
14 | 
15 | #### Requirements:
16 | - Python>=3.6
17 | - PyTorch>=1.2
18 | 
19 | #### Usage:
20 | ```python
21 | from IQA_pytorch import SSIM, GMSD, LPIPSvgg, DISTS
22 | D = SSIM(channels=3)
23 | # Compute the score of image X with respect to the reference Y
24 | # X: (N,3,H,W)
25 | # Y: (N,3,H,W)
26 | # Tensors, data range: 0~1
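27 | # For example (illustrative only), X and Y can be built from the test
28 | # images in this repo; any pair of same-sized RGB images works, and
29 | # torchvision's ToTensor already yields float tensors in [0, 1]:
30 | # from PIL import Image
31 | # from torchvision import transforms
32 | # X = transforms.ToTensor()(Image.open('images/r1.png').convert('RGB')).unsqueeze(0)
33 | # Y = transforms.ToTensor()(Image.open('images/r0.png').convert('RGB')).unsqueeze(0)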
34 | score = D(X, Y, as_loss=False)
35 | # Set 'as_loss=True' to get a differentiable value for use as a loss in optimization.
36 | loss = D(X, Y, as_loss=True)
37 | loss.backward()
38 | ```
39 | 
40 | ### DNN-based optimization examples:
41 | - Image denoising
42 | - Blind image deblurring
43 | - Single image super-resolution
44 | - Lossy image compression
45 | 
46 | ![diagram](images/diagram.svg)
47 | 
48 | For the experimental results, please see [Comparison of Image Quality Models for Optimization of Image Processing Systems
49 | ](https://arxiv.org/abs/2005.01338)
50 | 
51 | ### Citation:
52 | ```
53 | @article{ding2020optim,
54 | title={Comparison of Image Quality Models for Optimization of Image Processing Systems},
55 | author={Ding, Keyan and Ma, Kede and Wang, Shiqi and Simoncelli, Eero P.},
56 | journal = {CoRR},
57 | volume = {abs/2005.01338},
58 | year={2020},
59 | url = {https://arxiv.org/abs/2005.01338}
60 | }
61 | ```
-------------------------------------------------------------------------------- /examples/recover.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | from torch import optim
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | from torchvision import transforms
9 | import imageio
10 | 
11 | from IQA_pytorch import SSIM, MS_SSIM, CW_SSIM, GMSD, LPIPSvgg, DISTS, NLPD, FSIM, VSI, VIFs, VIF, MAD
12 | 
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | 
15 | ref_path = 'images/r0.png'
16 | pred_path = 'images/r1.png'
17 | 
18 | model = SSIM(channels=3).to(device)
19 | transform = transforms.Compose(
20 | [
21 | transforms.ToTensor(),
22 | ]
23 | )
24 | ref_img = Image.open(ref_path).convert("RGB")
25 | ref = transform(ref_img).unsqueeze(0)
26 | ref = Variable(ref.float().to(device), requires_grad=False)
27 | 
28 | pred_img = Image.open(pred_path).convert("RGB")
29 | pred = transform(pred_img).unsqueeze(0)
30 | pred = Variable(pred.float().to(device), requires_grad=True)
31 | 
32 | # pred = torch.rand_like(pred)
33 | # pred.requires_grad_(True)
34 | # pred_img = pred.squeeze().data.cpu().numpy().transpose(1, 2, 0)
35 | # pred_img = pred.squeeze().data.cpu().numpy().transpose(1, 2, 0)
36 | 
37 | model.eval()
38 | fig = plt.figure(figsize=(4,1.5),dpi=300)
39 | plt.subplot(131)
40 | plt.imshow(pred_img)
41 | plt.title('initial',fontsize=6)
42 | plt.axis('off')
43 | plt.subplot(133)
44 | plt.imshow(ref_img)
45 | plt.title('reference',fontsize=6)
46 | plt.axis('off')
47 | 
48 | lr = 0.005
49 | optimizer = torch.optim.Adam([pred], lr=lr)
50 | 
51 | for i in range(20000):
52 | dist = model(pred, ref)
53 | optimizer.zero_grad()
54 | dist.backward()
55 | # torch.nn.utils.clip_grad_norm_([pred], 1)
56 | optimizer.step()
57 | pred.data.clamp_(min=0,max=1)
58 | 
59 | # print(dist.item())
60 | # break
61 | 
62 | if i % 50 == 0:
63 | pred_img = pred.squeeze().data.cpu().numpy().transpose(1, 2, 0)
64 | plt.subplot(132)
65 | plt.imshow(np.clip(pred_img, 0, 1))
66 | # imageio.imwrite('results/temp.png',pred_img)
67 | plt.title('iter: %d, dists: %.3g' % (i, dist.item()),fontsize=6)
68 | plt.axis('off')
69 | plt.pause(1)
70 | plt.cla()
71 | 
72 | if (i+1) % 2000 == 0:
73 | lr = max(1e-4, lr*0.5)
74 | optimizer = torch.optim.Adam([pred], lr=lr)
75 | 
76 | 
-------------------------------------------------------------------------------- /images/diagram.svg: --------------------------------------------------------------------------------
1 | [SVG figure; text labels: Input, Image processing system, IQA model evaluation, Output, Feedback, Reference]
-------------------------------------------------------------------------------- /images/r0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/images/r0.png -------------------------------------------------------------------------------- /images/r1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingkeyan93/IQA-optimization/6b46c3b221b25ff277070aa4390a24c5fe25a7f3/images/r1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | with open("README.md", "r") as fh: 3 | long_description = fh.read() 4 | 5 | setup( 6 | name='IQA_pytorch', 7 | version='0.1', 8 | description='IQA models in PyTorch', 9 | long_description=long_description, 10 | long_description_content_type="text/markdown", 11 | packages=['IQA_pytorch'], 12 | data_files= [('', ['IQA_pytorch/weights/LPIPSvgg.pt','IQA_pytorch/weights/DISTS.pt'])], 13 | include_package_data=True, 14 | author='Keyan Ding', 15 | author_email='dingkeyan93@outlook.com', 16 | install_requires=["torch>=1.0"], 17 | url='https://github.com/dingkeyan93/IQA-pytorch', 18 | keywords = ['pytorch', 'similarity', 'IQA','metric','image-quality'], 19 | platforms = "python", 20 | license='MIT', 21 | ) 22 | 23 | # python setup.py sdist bdist_wheel 24 | # twine check dist/* 25 | # twine upload dist/* --------------------------------------------------------------------------------