├── misc ├── arial.ttf ├── mandril_color.tif └── mandril_gray.tif ├── requirements.txt ├── LICENSE ├── README.md ├── example.ipynb └── fsim.py /misc/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikhailiuk/pytorch-fsim/HEAD/misc/arial.ttf -------------------------------------------------------------------------------- /misc/mandril_color.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikhailiuk/pytorch-fsim/HEAD/misc/mandril_color.tif -------------------------------------------------------------------------------- /misc/mandril_gray.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikhailiuk/pytorch-fsim/HEAD/misc/mandril_gray.tif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | numpy==1.19.5 3 | imageio==2.9.0 4 | scipy==1.4.1 5 | matplotlib==3.3.1 6 | Pillow==7.2.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Aliaksei Mikhailiuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-fsim 2 | Differentiable implementation of the Feature Similarity Index Measure in Pytorch with CUDA support 3 | 4 | # Installation 5 | * Clone the repository 6 | * pip3 install -r requirements.txt 7 | 8 | # Basic usage 9 | 10 | ## Computing score 11 | ```python 12 | import imageio 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import matplotlib 16 | import torch as pt 17 | from torch.autograd import Variable 18 | from torch import optim 19 | from fsim import FSIM, FSIMc 20 | from PIL import Image, ImageDraw, ImageFont 21 | import os 22 | 23 | # Path to reference image 24 | img1_path ='./misc/mandril_color.tif' 25 | # Is it black and white? 
26 | bw = False 27 | # Size of the batch for training 28 | batch_size = 1 29 | 30 | # Read reference and distorted images 31 | img1 = Image.open(img1_path).convert('RGB') 32 | img1 = pt.from_numpy(np.asarray(img1)) 33 | img1 = img1.permute(2,0,1) 34 | img1 = img1.unsqueeze(0).type(pt.FloatTensor) 35 | img2 = pt.clamp(pt.rand(img1.size())*255.0,0,255.0) 36 | 37 | # Create fake batch (for testing) 38 | img1b = pt.cat(batch_size*[img1],0) 39 | img2b = pt.cat(batch_size*[img2],0) 40 | 41 | if pt.cuda.is_available(): 42 |     img1b = img1b.cuda() 43 |     img2b = img2b.cuda() 44 | 45 | # Create FSIM loss 46 | FSIM_loss = FSIMc() 47 | loss = FSIM_loss(img1b,img2b) 48 | print(loss) 49 | ``` 50 | 51 | ## Optimizing 52 | Note: recovering the reference image from Gaussian noise is challenging and requires regularization; see this [paper](https://link.springer.com/article/10.1007/s11263-020-01419-7) for details. 53 | 54 | ```python 55 | import imageio 56 | import numpy as np 57 | import matplotlib.pyplot as plt 58 | import matplotlib 59 | import torch as pt 60 | from torch.autograd import Variable 61 | from torch import optim 62 | from fsim import FSIM, FSIMc 63 | from PIL import Image, ImageDraw, ImageFont 64 | import os 65 | 66 | # Path to reference image 67 | img1_path ='./misc/mandril_color.tif' 68 | # Is it black and white? 69 | bw = False 70 | # Size of the batch for training 71 | batch_size = 1 72 | 73 | # Read reference and distorted images 74 | img1 = Image.open(img1_path).convert('RGB') 75 | img1 = pt.from_numpy(np.asarray(img1)) 76 | img1 = img1.permute(2,0,1) 77 | img1 = img1.unsqueeze(0).type(pt.FloatTensor) 78 | img2 = pt.clamp(pt.rand(img1.size())*255.0,0,255.0) 79 | 80 | # Create fake batch (for testing) 81 | img1b = pt.cat(batch_size*[img1],0) 82 | img2b = pt.cat(batch_size*[img2],0) 83 | # Convert images to variables to support gradients 84 | img1b = Variable( img1b, requires_grad = False) 85 | img2b = Variable( img2b, requires_grad = True) 86 | 87 | if pt.cuda.is_available(): 88 |     img1b = img1b.cuda() 89 |     img2b = img2b.cuda() 90 | 91 | # Create FSIM loss 92 | FSIM_loss = FSIMc() 93 | 94 | # Tie optimizer to the distorted batch 95 | optimizer = optim.Adam([img2b], lr=0.1) 96 | 97 | # Check if the gradient propagates 98 | for ii in range(0,1000): 99 |     optimizer.zero_grad() 100 | 101 |     loss = -FSIM_loss(img1b,img2b) 102 |     print(loss) 103 |     loss = pt.sum(loss) 104 |     loss.backward() 105 |     optimizer.step() 106 | ``` 107 | 108 | # References 109 | The code is a direct implementation of the MATLAB version provided by: 110 | 111 | https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm 112 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import imageio\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import matplotlib\n", 13 | "import torch as pt\n", 14 | "from torch.autograd import Variable\n", 15 | "from torch import optim\n", 16 | "from fsim import FSIM, FSIMc\n", 17 | "from PIL import Image, ImageDraw, ImageFont\n", 18 | "import os\n", 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "def read_convert_pt_image(image_path):\n", 29 | "    '''\n", 30 | "    Function to read an image
from the file specified in image_path and \n", 31 | " convert it to pytorch tensor\n", 32 | " '''\n", 33 | " image = Image.open(image_path).convert('RGB')\n", 34 | "\n", 35 | " image = pt.from_numpy(np.asarray(image))\n", 36 | " \n", 37 | " image = image.permute(2,0,1)\n", 38 | " return image\n", 39 | "\n", 40 | "def save_image_score(img_torch, img_name, score, bw = False ):\n", 41 | " '''\n", 42 | " Function to save the image to the output folder with FSIM score imprinted\n", 43 | " '''\n", 44 | " img_mask = img_torch.squeeze(0).permute(1,2,0).data.numpy()\n", 45 | " \n", 46 | " img_mask[img_mask>255.] = 255.\n", 47 | " img_mask[img_mask<0] = 0.\n", 48 | "\n", 49 | " img = Image.fromarray(np.uint8(img_mask))\n", 50 | "\n", 51 | " d = ImageDraw.Draw(img)\n", 52 | " font = ImageFont.truetype(\"./misc/arial.ttf\", 40)\n", 53 | " d.text((10,10), 'FSIM='+str(round(score*1000)/1000), font=font, fill=(255,255,255))\n", 54 | "\n", 55 | " img.save(img_name+'.png')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "scrolled": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "# The metric expects images to be in the range from 0 to 255.\n", 67 | "\n", 68 | "# Path to reference image\n", 69 | "img1_path ='./misc/mandril_color.tif'\n", 70 | "# Is it black and white?\n", 71 | "bw = False\n", 72 | "# Size of the batch for training\n", 73 | "batch_size = 1\n", 74 | "# Do we regenrate the image from noise (True), or clean up the noise from the image (False)\n", 75 | "noise = True\n", 76 | "# Save image\n", 77 | "save_image = False\n", 78 | "\n", 79 | "if save_image and not (os.path.isdir('output')):\n", 80 | " os.mkdir('output')\n", 81 | "\n", 82 | "# Read reference and distorted images\n", 83 | "img1 = read_convert_pt_image(img1_path)\n", 84 | "img1 = img1.unsqueeze(0).type(pt.FloatTensor)\n", 85 | "if noise:\n", 86 | " img2 = pt.clamp(pt.rand(img1.size())*255.0,0,255.0)\n", 87 | "else:\n", 88 | " img2 = pt.clamp(img1+200*pt.rand(img1.size()),0,255.0)\n", 89 | " \n", 90 | "\n", 91 | "# Create fake batch (for testing)\n", 92 | "img1b = pt.cat(batch_size*[img1],0)\n", 93 | "img2b = pt.cat(batch_size*[img2],0)\n", 94 | "# Convert images to variables to support gradients\n", 95 | "img1b = Variable( img1b, requires_grad = False)\n", 96 | "img2b = Variable( img2b, requires_grad = True)\n", 97 | "\n", 98 | "if pt.cuda.is_available():\n", 99 | " img1b = img1b.cuda()\n", 100 | " img2b = img2b.cuda()\n", 101 | "\n", 102 | "# Create FSIM loss\n", 103 | "FSIM_loss = FSIMc()\n", 104 | "\n", 105 | "# Tie optimizer to the distorted batch\n", 106 | "optimizer = optim.Adam([img2b], lr=0.1)\n", 107 | "\n", 108 | "# Check if the gradient propagates\n", 109 | "for ii in range(0,1000):\n", 110 | " optimizer.zero_grad()\n", 111 | "\n", 112 | " loss = -FSIM_loss(img1b,img2b) \n", 113 | " print(loss)\n", 114 | " loss = pt.sum(loss)\n", 115 | " loss.backward()\n", 116 | " optimizer.step()\n", 117 | " \n", 118 | " if ii%20 ==0 and save_image:\n", 119 | " save_image_score(img2b,'./output/optimized_image_'+str(ii),loss.item()*-1.0)\n", 120 | "\n" 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python 3", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 
| "version": "3.8.5" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 4 145 | } 146 | -------------------------------------------------------------------------------- /fsim.py: -------------------------------------------------------------------------------- 1 | import torch as pt 2 | import torch.nn as nn 3 | from torch import optim 4 | from torch.autograd import Variable 5 | import torch.nn.functional as FUN 6 | 7 | import numpy as np 8 | import math 9 | import imageio 10 | from scipy.io import loadmat 11 | 12 | ''' 13 | This code is a direct PyTorch implementation of the original FSIM code provided by 14 | Lin ZHANG, Lei Zhang, Xuanqin Mou and David Zhang in MATLAB. For the original version 15 | please see: 16 | 17 | https://www4.comp.polyu.edu.hk/~cslzhang/IQA/FSIM/FSIM.htm 18 | 19 | ''' 20 | 21 | 22 | class FSIM_base(nn.Module): 23 | 24 |     def __init__(self): 25 |         nn.Module.__init__(self) 26 |         self.cuda_computation = False 27 |         self.nscale = 4 # Number of wavelet scales 28 |         self.norient = 4 # Number of filter orientations 29 |         self.k = 2.0 # No. of standard deviations of the noise 30 |         # energy beyond the mean at which we set the 31 |         # noise threshold point, 32 |         # below which phase congruency values get 33 |         # penalized. 34 | 35 |         self.epsilon = .0001 # Used to prevent division by zero 36 |         self.pi = math.pi 37 | 38 |         minWaveLength = 6 # Wavelength of smallest scale filter 39 |         mult = 2 # Scaling factor between successive filters 40 |         sigmaOnf = 0.55 # Ratio of the standard deviation of the 41 |         # Gaussian describing the log Gabor filter's 42 |         # transfer function in the frequency domain 43 |         # to the filter center frequency. 44 |         dThetaOnSigma = 1.2 # Ratio of angular interval between filter orientations 45 |         # and the standard deviation of the angular Gaussian 46 |         # function used to construct filters in the 47 |         # freq. plane. 48 | 49 |         self.thetaSigma = self.pi/self.norient/dThetaOnSigma # Calculate the standard deviation of the 50 |         # angular Gaussian function used to 51 |         # construct filters in the freq. plane.
52 | 53 | 54 | self.fo = (1.0/(minWaveLength*pt.pow(mult,(pt.arange(0,self.nscale,dtype=pt.float64))))).unsqueeze(0) # Centre frequency of filter 55 | self.den = 2*(math.log(sigmaOnf))**2 56 | self.dx = -pt.tensor([[[[3, 0, -3], [10, 0,-10], [3,0,-3]]]])/16.0 57 | self.dy = -pt.tensor([[[[3, 10, 3], [0, 0, 0], [-3 ,-10, -3]]]])/16.0 58 | self.T1 = 0.85 59 | self.T2 = 160 60 | self.T3 = 200; 61 | self.T4 = 200; 62 | self.lambdac = 0.03 63 | 64 | def set_arrays_to_cuda(self): 65 | self.cuda_computation = True 66 | self.fo = self.fo.cuda() 67 | self.dx = self.dx.cuda() 68 | self.dy = self.dy.cuda() 69 | 70 | def forward_gradloss(self,imgr,imgd): 71 | I1,Q1,Y1 = self.process_image_channels(imgr) 72 | I2,Q2,Y2 = self.process_image_channels(imgd) 73 | 74 | 75 | #PCSimMatrix,PCm = self.calculate_phase_score(PC1,PC2) 76 | gradientMap1 = self.calculate_gradient_map(Y1) 77 | gradientMap2 = self.calculate_gradient_map(Y2) 78 | 79 | gradientSimMatrix = self.calculate_gradient_sim(gradientMap1,gradientMap2) 80 | #gradientSimMatrix= gradientSimMatrix.view(PCSimMatrix.size()) 81 | gradloss = pt.sum(pt.sum(pt.sum(gradientSimMatrix,1),1)) 82 | return gradloss 83 | 84 | def calculate_fsim(self,gradientSimMatrix,PCSimMatrix,PCm): 85 | SimMatrix = gradientSimMatrix * PCSimMatrix * PCm 86 | FSIM = pt.sum(pt.sum(SimMatrix,1),1) / pt.sum(pt.sum(PCm,1),1) 87 | return FSIM 88 | 89 | def calculate_fsimc(self, I1,Q1,I2,Q2,gradientSimMatrix,PCSimMatrix,PCm): 90 | 91 | ISimMatrix = (2*I1*I2 + self.T3) / (pt.pow(I1,2) + pt.pow(I2,2) + self.T3) 92 | QSimMatrix = (2*Q1*Q2 + self.T4) / (pt.pow(Q1,2) + pt.pow(Q2,2) + self.T4) 93 | SimMatrixC = gradientSimMatrix*PCSimMatrix*(pt.pow(pt.abs(ISimMatrix*QSimMatrix),self.lambdac))*PCm 94 | FSIMc = pt.sum(pt.sum(SimMatrixC,1),1)/pt.sum(pt.sum(PCm,1),1) 95 | 96 | return FSIMc 97 | 98 | def lowpassfilter(self, rows, cols): 99 | cutoff = .45 100 | n = 15 101 | x, y = self.create_meshgrid(cols,rows) 102 | radius = pt.sqrt(pt.pow(x,2) + pt.pow(y,2)).unsqueeze(0) 103 | f = self.ifftshift2d( 1 / (1.0 + pt.pow(pt.div(radius,cutoff),2*n)) ) 104 | return f 105 | 106 | def calculate_gradient_sim(self,gradientMap1,gradientMap2): 107 | 108 | gradientSimMatrix = (2*gradientMap1*gradientMap2 + self.T2) /(pt.pow(gradientMap1,2) + pt.pow(gradientMap2,2) + self.T2) 109 | return gradientSimMatrix 110 | 111 | def calculate_gradient_map(self,Y): 112 | IxY = FUN.conv2d(Y,self.dx, padding=1) 113 | IyY = FUN.conv2d(Y,self.dy, padding=1) 114 | gradientMap1 = pt.sqrt(pt.pow(IxY,2) + pt.pow(IyY,2)) 115 | return gradientMap1 116 | 117 | def calculate_phase_score(self,PC1,PC2): 118 | PCSimMatrix = (2 * PC1 * PC2 + self.T1) / (pt.pow(PC1,2) + pt.pow(PC2,2) + self.T1) 119 | PCm = pt.where(PC1>PC2, PC1,PC2) 120 | return PCSimMatrix,PCm 121 | 122 | def roll_1(self,x, n): 123 | return pt.cat((x[:,-n:,:,:,:], x[:,:-n,:,:,:]), dim=1) 124 | 125 | def ifftshift(self,tens,var_axis): 126 | len11 = int(tens.size()[var_axis]/2) 127 | len12 = tens.size()[var_axis]-len11 128 | return pt.cat((tens.narrow(var_axis,len11,len12),tens.narrow(var_axis,0,len11)),axis=var_axis) 129 | 130 | def ifftshift2d(self,tens): 131 | return self.ifftshift(self.ifftshift(tens,1),2) 132 | 133 | def create_meshgrid(self,cols,rows): 134 | ''' 135 | Set up X and Y matrices with ranges normalised to +/- 0.5 136 | The following code adjusts things appropriately for odd and even values 137 | of rows and columns. 
138 |         ''' 139 | 140 |         if cols%2: 141 |             xrange = pt.arange(start = -(cols-1)/2, end = (cols-1)/2+1, step = 1, requires_grad=False)/(cols-1) 142 |         else: 143 |             xrange = pt.arange(-(cols)/2, (cols)/2, step = 1, requires_grad=False)/(cols) 144 | 145 |         if rows%2: 146 |             yrange = pt.arange(-(rows-1)/2, (rows-1)/2+1, step = 1, requires_grad=False)/(rows-1) 147 |         else: 148 |             yrange = pt.arange(-(rows)/2, (rows)/2, step = 1, requires_grad=False)/(rows) 149 | 150 |         x, y = pt.meshgrid([xrange, yrange]) 151 | 152 |         if self.cuda_computation: 153 |             x, y = x.cuda(), y.cuda() 154 | 155 |         return x.T, y.T 156 | 157 |     def process_image_channels(self,img): 158 | 159 | 160 |         batch, rows, cols = img.shape[0],img.shape[2],img.shape[3] 161 | 162 |         minDimension = min(rows,cols) 163 | 164 |         Ycoef = pt.tensor([[0.299,0.587,0.114]]) 165 |         Icoef = pt.tensor([[0.596,-0.274,-0.322]]) 166 |         Qcoef = pt.tensor([[0.211,-0.523,0.312]]) 167 | 168 |         if self.cuda_computation: 169 |             Ycoef, Icoef, Qcoef = Ycoef.cuda(), Icoef.cuda(), Qcoef.cuda() 170 | 171 |         Yfilt=pt.cat(batch*[pt.cat(rows*cols*[Ycoef.unsqueeze(2)],dim=2).view(1,3,rows,cols)],0) 172 |         Ifilt=pt.cat(batch*[pt.cat(rows*cols*[Icoef.unsqueeze(2)],dim=2).view(1,3,rows,cols)],0) 173 |         Qfilt=pt.cat(batch*[pt.cat(rows*cols*[Qcoef.unsqueeze(2)],dim=2).view(1,3,rows,cols)],0) 174 | 175 |         # If images have three channels 176 |         if img.size()[1]==3: 177 |             Y = pt.sum(Yfilt*img,1).unsqueeze(1) 178 |             I = pt.sum(Ifilt*img,1).unsqueeze(1) 179 |             Q = pt.sum(Qfilt*img,1).unsqueeze(1) 180 |         else: 181 |             Y = pt.mean(img,1).unsqueeze(1) 182 |             I = pt.ones(Y.size(),dtype=pt.float64) 183 |             Q = pt.ones(Y.size(),dtype=pt.float64) 184 | 185 |         F = max(1,round(minDimension / 256)) 186 | 187 |         aveKernel = nn.AvgPool2d(kernel_size = F, stride = F, padding =0)# max(0, math.floor(F/2))) 188 |         if self.cuda_computation: 189 |             aveKernel = aveKernel.cuda() 190 | 191 |         # Make sure that the dimension of the returned image is the same as the input 192 |         I = aveKernel(I) 193 |         Q = aveKernel(Q) 194 |         Y = aveKernel(Y) 195 |         return I,Q,Y 196 | 197 | 198 |     def phasecong2(self,img): 199 |         ''' 200 |         % Filters are constructed in terms of two components. 201 |         % 1) The radial component, which controls the frequency band that the filter 202 |         % responds to 203 |         % 2) The angular component, which controls the orientation that the filter 204 |         % responds to. 205 |         % The two components are multiplied together to construct the overall filter. 206 | 207 |         % Construct the radial filter components... 208 | 209 |         % First construct a low-pass filter that is as large as possible, yet falls 210 |         % away to zero at the boundaries. All log Gabor filters are multiplied by 211 |         % this to ensure no extra frequencies at the 'corners' of the FFT are 212 |         % incorporated as this seems to upset the normalisation process when 213 |         % calculating phase congruency. 214 |         ''' 215 | 216 |         batch, rows, cols = img.shape[0],img.shape[2],img.shape[3] 217 | 218 |         imagefft = pt.rfft(img,signal_ndim=2,onesided=False) 219 | 220 |         x, y = self.create_meshgrid(cols,rows) 221 | 222 |         radius = pt.cat(batch*[pt.sqrt(pt.pow(x,2) + pt.pow(y,2)).unsqueeze(0)],0) 223 |         theta = pt.cat(batch*[pt.atan2(-y,x).unsqueeze(0)],0) 224 | 225 |         radius = self.ifftshift2d(radius) # Matrix values contain *normalised* radius from centre 226 |         theta = self.ifftshift2d(theta) # Matrix values contain polar angle. 227 |
227 | # (note -ve y is used to give +ve 228 | # anti-clockwise angles) 229 | 230 | radius[:,0,0] = 1 231 | 232 | sintheta = pt.sin(theta) 233 | costheta = pt.cos(theta) 234 | 235 | lp = self.lowpassfilter(rows,cols) # Radius .45, 'sharpness' 15 236 | lp = pt.cat(batch*[lp.unsqueeze(0)],0) 237 | 238 | term1 = pt.cat(rows*cols*[self.fo.unsqueeze(2)],dim=2).view(-1,self.nscale,rows,cols) 239 | term1 = pt.cat(batch*[term1.unsqueeze(0)],0).view(-1,self.nscale,rows,cols) 240 | 241 | term2 = pt.log(pt.cat(self.nscale*[radius.unsqueeze(1)],1)/term1) 242 | # Apply low-pass filter 243 | logGabor = pt.exp(-pt.pow(term2,2)/self.den) 244 | logGabor = logGabor*lp 245 | logGabor[:,:,0,0] = 0 # Set the value at the 0 frequency point of the filter 246 | # back to zero (undo the radius fudge). 247 | 248 | # Then construct the angular filter components... 249 | 250 | # For each point in the filter matrix calculate the angular distance from 251 | # the specified filter orientation. To overcome the angular wrap-around 252 | # problem sine difference and cosine difference values are first computed 253 | # and then the atan2 function is used to determine angular distance. 254 | angl = pt.arange(0,self.norient,dtype=pt.float64)/self.norient*self.pi 255 | 256 | if self.cuda_computation: 257 | angl = angl.cuda() 258 | ds_t1 = pt.cat(self.norient*[sintheta.unsqueeze(1)],1)*pt.cos(angl).view(-1,self.norient,1,1) 259 | ds_t2 = pt.cat(self.norient*[costheta.unsqueeze(1)],1)*pt.sin(angl).view(-1,self.norient,1,1) 260 | dc_t1 = pt.cat(self.norient*[costheta.unsqueeze(1)],1)*pt.cos(angl).view(-1,self.norient,1,1) 261 | dc_t2 = pt.cat(self.norient*[sintheta.unsqueeze(1)],1)*pt.sin(angl).view(-1,self.norient,1,1) 262 | ds = ds_t1-ds_t2 # Difference in sine. 263 | dc = dc_t1+dc_t2 # Difference in cosine. 264 | dtheta = pt.abs(pt.atan2(ds,dc)) # Absolute angular distance. 265 | spread = pt.exp(-pt.pow(dtheta,2)/(2*self.thetaSigma**2)) # Calculate the 266 | # angular filter component. 267 | 268 | logGabor_rep = pt.repeat_interleave(logGabor,self.norient,1).view(-1,self.nscale,self.norient,rows,cols) 269 | 270 | # Batch size, scale, orientation, pixels, pixels 271 | spread_rep = pt.cat(self.nscale*[spread]).view(-1,self.nscale,self.norient,rows,cols) 272 | filter_log_spread = logGabor_rep*spread_rep 273 | array_of_zeros = pt.zeros(filter_log_spread.unsqueeze(5).size(),dtype=pt.float64) 274 | if self.cuda_computation: 275 | array_of_zeros = array_of_zeros.cuda() 276 | filter_log_spread_zero = pt.cat((filter_log_spread.unsqueeze(5),array_of_zeros), dim=5) 277 | ifftFilterArray = pt.ifft(filter_log_spread_zero,signal_ndim =2).select(5,0)*math.sqrt(rows*cols) 278 | 279 | imagefft_repeat = pt.cat(self.nscale*self.norient*[imagefft],dim=1).view(-1,self.nscale,self.norient,rows,cols,2) 280 | filter_log_spread_repeat = pt.cat(2*[filter_log_spread.unsqueeze(5)],dim=5) 281 | # Convolve image with even and odd filters returning the result in EO 282 | EO = pt.ifft(filter_log_spread_repeat*imagefft_repeat,signal_ndim=2) 283 | 284 | E = EO.select(5, 0) 285 | O = EO.select(5, 1) 286 | An = pt.sqrt(pt.pow(E,2)+pt.pow(O,2)) 287 | sumAn_ThisOrient = pt.sum(An,1) 288 | sumE_ThisOrient = pt.sum(E,1) # Sum of even filter convolution results 289 | sumO_ThisOrient = pt.sum(O,1) # Sum of odd filter convolution results. 290 | 291 | # Get weighted mean filter response vector, this gives the weighted mean 292 | # phase angle. 
293 | XEnergy = pt.sqrt(pt.pow(sumE_ThisOrient,2) + pt.pow(sumO_ThisOrient,2)) + self.epsilon 294 | MeanE = sumE_ThisOrient / XEnergy 295 | MeanO = sumO_ThisOrient / XEnergy 296 | 297 | MeanO = pt.cat(self.nscale*[MeanO.unsqueeze(1)],1) 298 | MeanE = pt.cat(self.nscale*[MeanE.unsqueeze(1)],1) 299 | 300 | 301 | # Now calculate An(cos(phase_deviation) - | sin(phase_deviation)) | by 302 | # using dot and cross products between the weighted mean filter response 303 | # vector and the individual filter response vectors at each scale. This 304 | # quantity is phase congruency multiplied by An, which we call energy. 305 | Energy = pt.sum( E*MeanE+O*MeanO - pt.abs(E*MeanO-O*MeanE),1) 306 | abs_EO = pt.sqrt(pt.pow(E,2) + pt.pow(O,2)) 307 | 308 | # % Compensate for noise 309 | # We estimate the noise power from the energy squared response at the 310 | # smallest scale. If the noise is Gaussian the energy squared will have a 311 | # Chi-squared 2DOF pdf. We calculate the median energy squared response 312 | # as this is a robust statistic. From this we estimate the mean. 313 | # The estimate of noise power is obtained by dividing the mean squared 314 | # energy value by the mean squared filter value 315 | medianE2n = pt.pow(abs_EO.select(1,0),2).view(-1,self.norient,rows*cols).median(2).values 316 | 317 | EM_n = pt.sum(pt.sum(pt.pow(filter_log_spread.select(1,0),2),3),2) 318 | noisePower = -(medianE2n/math.log(0.5))/EM_n 319 | 320 | # Now estimate the total energy^2 due to noise 321 | # Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj)) 322 | EstSumAn2 = pt.sum(pt.pow(ifftFilterArray,2),1) 323 | 324 | sumEstSumAn2 = pt.sum(pt.sum(EstSumAn2,2),2) 325 | roll_t1 = ifftFilterArray*self.roll_1(ifftFilterArray,1) 326 | roll_t2 = ifftFilterArray*self.roll_1(ifftFilterArray,2) 327 | roll_t3 = ifftFilterArray*self.roll_1(ifftFilterArray,3) 328 | rolling_mult = roll_t1+roll_t2+roll_t3 329 | EstSumAiAj = pt.sum(rolling_mult,1)/2 330 | sumEstSumAiAj = pt.sum(pt.sum(EstSumAiAj,2),2) 331 | 332 | EstNoiseEnergy2 = 2*noisePower*sumEstSumAn2+4*noisePower*sumEstSumAiAj 333 | tau = pt.sqrt(EstNoiseEnergy2/2) 334 | EstNoiseEnergy = tau*math.sqrt(self.pi/2) 335 | EstNoiseEnergySigma = pt.sqrt( (2-self.pi/2)*pt.pow(tau,2)) 336 | 337 | 338 | # The estimated noise effect calculated above is only valid for the PC_1 measure. 339 | # The PC_2 measure does not lend itself readily to the same analysis. However 340 | # empirically it seems that the noise effect is overestimated roughly by a factor 341 | # of 1.7 for the filter parameters used here. 
342 | T = (EstNoiseEnergy + self.k*EstNoiseEnergySigma)/1.7 # Noise threshold 343 | 344 | T_exp = pt.cat(rows*cols*[T.unsqueeze(2)],dim=2).view(-1,self.norient,rows,cols) 345 | AnAll = pt.sum(sumAn_ThisOrient,1) 346 | array_of_zeros_energy = pt.zeros(Energy.size(),dtype=pt.float64) 347 | if self.cuda_computation: 348 | array_of_zeros_energy =array_of_zeros_energy.cuda() 349 | 350 | EnergyAll = pt.sum(pt.where((Energy - T_exp)<0.0, array_of_zeros_energy,Energy - T_exp ),1) 351 | ResultPC = EnergyAll/AnAll 352 | 353 | return ResultPC 354 | 355 | class FSIM(FSIM_base): 356 | ''' 357 | Note, the input is expected to be from 0 to 255 358 | ''' 359 | 360 | def __init__(self): 361 | super().__init__() 362 | 363 | def forward(self,imgr,imgd): 364 | if imgr.is_cuda: 365 | self.set_arrays_to_cuda() 366 | 367 | I1,Q1,Y1 = self.process_image_channels(imgr) 368 | I2,Q2,Y2 = self.process_image_channels(imgd) 369 | PC1 = self.phasecong2(Y1) 370 | PC2 = self.phasecong2(Y2) 371 | 372 | PCSimMatrix,PCm = self.calculate_phase_score(PC1,PC2) 373 | gradientMap1 = self.calculate_gradient_map(Y1) 374 | gradientMap2 = self.calculate_gradient_map(Y2) 375 | 376 | gradientSimMatrix = self.calculate_gradient_sim(gradientMap1,gradientMap2) 377 | gradientSimMatrix= gradientSimMatrix.view(PCSimMatrix.size()) 378 | FSIM = self.calculate_fsim(gradientSimMatrix,PCSimMatrix,PCm) 379 | 380 | return FSIM.mean() 381 | 382 | class FSIMc(FSIM_base, nn.Module): 383 | ''' 384 | Note, the input is expected to be from 0 to 255 385 | ''' 386 | def __init__(self): 387 | super().__init__() 388 | 389 | def forward(self,imgr,imgd): 390 | if imgr.is_cuda: 391 | self.set_arrays_to_cuda() 392 | 393 | 394 | I1,Q1,Y1 = self.process_image_channels(imgr) 395 | I2,Q2,Y2 = self.process_image_channels(imgd) 396 | PC1 = self.phasecong2(Y1) 397 | PC2 = self.phasecong2(Y2) 398 | 399 | PCSimMatrix,PCm = self.calculate_phase_score(PC1,PC2) 400 | gradientMap1 = self.calculate_gradient_map(Y1) 401 | gradientMap2 = self.calculate_gradient_map(Y2) 402 | 403 | gradientSimMatrix = self.calculate_gradient_sim(gradientMap1,gradientMap2) 404 | gradientSimMatrix= gradientSimMatrix.view(PCSimMatrix.size()) 405 | FSIMc = self.calculate_fsimc(I1.squeeze(),Q1.squeeze(),I2.squeeze(),Q2.squeeze(),gradientSimMatrix,PCSimMatrix,PCm) 406 | 407 | return FSIMc.mean() --------------------------------------------------------------------------------
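Both usage examples above (README.md and example.ipynb) exercise the colour metric FSIMc on mandril_color.tif. A minimal sketch of the luminance-only FSIM class on the bundled grayscale image could look as follows; it assumes the pinned environment from requirements.txt (torch==1.6.0), and the noise amplitude of 50.0 is an arbitrary value chosen purely for illustration, not a setting used elsewhere in the repository.

```python
import numpy as np
import torch as pt
from PIL import Image
from fsim import FSIM

# Read the bundled grayscale reference image; the metric expects values in 0..255
img1 = Image.open('./misc/mandril_gray.tif').convert('L')
img1 = pt.from_numpy(np.asarray(img1).copy())               # H x W, uint8
img1 = img1.unsqueeze(0).unsqueeze(0).type(pt.FloatTensor)  # 1 x 1 x H x W

# Distorted copy: additive noise, clamped back to the valid range
# (the 50.0 amplitude is an arbitrary illustration value)
img2 = pt.clamp(img1 + 50.0*pt.randn(img1.size()), 0, 255.0)

if pt.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

# FSIM scores the luminance channel only; FSIMc (used in the README examples)
# additionally weighs the I and Q chromatic channels
FSIM_loss = FSIM()
score = FSIM_loss(img1, img2)
print(score)
```

As with FSIMc, the returned value is the batch-mean FSIM score, so higher means more similar; when used as a training loss it is typically negated, as in the optimization example above.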