├── README.md
├── calc_ssim.py
└── example
    ├── car_gt.jpeg
    └── car_pred.jpeg

/README.md:
--------------------------------------------------------------------------------
# Reproduce the MATLAB SSIM in PyTorch


SSIM is **not** computed consistently across implementations. For example, skimage computes SSIM for each color channel separately and averages the results, while MATLAB filters with a single 3D Gaussian kernel that also spans the channel dimension. The inconsistency is discussed here: https://github.com/scikit-image/scikit-image/issues/4985 .
Some communities, such as the low-level vision community, tend to report the SSIM computed by MATLAB in their papers. However, I could not find a Python implementation that reproduces it (the ones I tried did not match MATLAB's result to within 0.0001), so I created one.

#### usage
`python calc_ssim.py`

#### result
It should print a number of about 0.8434 (MATLAB gives 0.8433).
The remaining difference (0.0001) might be a numerical issue caused by differences between the PyTorch and MATLAB implementations.
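
#### per-channel vs. 3D kernel
To see why the two conventions give different numbers, here is a minimal NumPy sketch. It is an illustration, not the code in `calc_ssim.py` and not skimage's actual implementation. It assumes 8-bit images (so `C1 = (0.01 * 255) ** 2` and `C2 = (0.03 * 255) ** 2`), the same 11-tap Gaussian with sigma 1.5, and replicate (edge) padding; the helper names `gaussian_1d`, `gaussian_filter`, `ssim`, `ssim_per_channel` and `ssim_3d` are hypothetical. `ssim_per_channel` averages 2D SSIM over the channels, while `ssim_3d` also filters across the channel axis, as MATLAB does.

```python
import numpy as np


def gaussian_1d(size=11, sigma=1.5):
    # Assumption: same weights as cv2.getGaussianKernel(11, 1.5), normalized to sum to 1.
    x = np.arange(size) - (size - 1) / 2
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def gaussian_filter(img, axes, size=11, sigma=1.5):
    """Separable Gaussian filter over `axes`, with replicate (edge) padding."""
    k = gaussian_1d(size, sigma)
    r = size // 2
    out = img.astype(np.float64)
    for ax in axes:
        pad = [(0, 0)] * out.ndim
        pad[ax] = (r, r)
        padded = np.pad(out, pad, mode="edge")
        # A 'valid' 1D convolution along `ax` restores the original length.
        out = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), ax, padded)
    return out


def ssim(img1, img2, axes):
    """Mean SSIM, with local statistics computed by a Gaussian filter over `axes`."""
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    a = img1.astype(np.float64)
    b = img2.astype(np.float64)

    def f(x):
        return gaussian_filter(x, axes)

    mu1, mu2 = f(a), f(b)
    sigma1_sq = f(a * a) - mu1 ** 2
    sigma2_sq = f(b * b) - mu2 ** 2
    sigma12 = f(a * b) - mu1 * mu2
    ssim_map = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return float(ssim_map.mean())


def ssim_per_channel(img1, img2):
    # Average of 2D SSIM values, one per channel (last axis).
    return float(np.mean([ssim(img1[..., c], img2[..., c], axes=(0, 1))
                          for c in range(img1.shape[-1])]))


def ssim_3d(img1, img2):
    # One SSIM map from a 3D Gaussian that also spans the channel axis.
    return ssim(img1, img2, axes=(0, 1, 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
    noisy = np.clip(gt + rng.normal(0, 20, gt.shape), 0, 255).astype(np.uint8)
    # The two conventions generally report different values for the same pair.
    print(ssim_per_channel(gt, noisy), ssim_3d(gt, noisy))
```

Both variants return exactly 1 for identical images; they only diverge once the images differ, because the 3D kernel mixes statistics from neighboring channels.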
--------------------------------------------------------------------------------
/calc_ssim.py:
--------------------------------------------------------------------------------
import torch
import cv2
import numpy as np

def generate_1d_gaussian_kernel():
    # 11-tap Gaussian with sigma = 1.5, returned by OpenCV as an (11, 1) column vector.
    return cv2.getGaussianKernel(11, 1.5)

def generate_2d_gaussian_kernel():
    # 11x11 window: outer product of the 1D kernel with itself.
    kernel = generate_1d_gaussian_kernel()
    return np.outer(kernel, kernel.transpose())

def generate_3d_gaussian_kernel():
    # 11x11x11 kernel: the 2D window scaled by each 1D weight and stacked along a new first axis,
    # so entry [i, j, l] = k[i] * k[j] * k[l].
    kernel = generate_1d_gaussian_kernel()
    window = generate_2d_gaussian_kernel()
    return np.stack([window * k for k in kernel], axis=0)

class SSIM():

    def __init__(self, device='cpu'):
        self.device = device

        # Fixed (non-trainable) 3D Gaussian filter with replicate padding, used for color images.
        conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
        conv3d.weight.requires_grad = False
        conv3d.weight[0, 0, :, :, :] = torch.tensor(generate_3d_gaussian_kernel())
        self.conv3d = conv3d.to(device)

        # Fixed 2D Gaussian filter with replicate padding, used for grayscale images.
        conv2d = torch.nn.Conv2d(1, 1, (11, 11), stride=1, padding=(5, 5), bias=False, padding_mode='replicate')
        conv2d.weight.requires_grad = False
        conv2d.weight[0, 0, :, :] = torch.tensor(generate_2d_gaussian_kernel())
        self.conv2d = conv2d.to(device)


    def calc(self, img1, img2):
        assert len(img1.shape) == len(img2.shape)
        with torch.no_grad():
            img1 = torch.tensor(img1).to(self.device).float()
            img2 = torch.tensor(img2).to(self.device).float()

            # (H, W) uses the 2D filter; (H, W, C) uses the 3D filter, which also spans the channel axis.
            if len(img1.shape) == 2:
                conv = self.conv2d
            elif len(img1.shape) == 3:
                conv = self.conv3d
            else:
                raise NotImplementedError('only support 2d / 3d images.')
            return self._ssim(img1, img2, conv)


    def _ssim(self, img1, img2, conv):
        # Add batch and channel dimensions: (1, 1, H, W) or (1, 1, H, W, C).
        img1 = img1.unsqueeze(0).unsqueeze(0)
        img2 = img2.unsqueeze(0).unsqueeze(0)

        # Stabilizing constants for an 8-bit dynamic range (L = 255).
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2

        # Local (Gaussian-weighted) means.
        mu1 = conv(img1)
        mu2 = conv(img2)

        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        # Local variances and covariance: E[x^2] - E[x]^2 and E[xy] - E[x]E[y].
        sigma1_sq = conv(img1 ** 2) - mu1_sq
        sigma2_sq = conv(img2 ** 2) - mu2_sq
        sigma12 = conv(img1 * img2) - mu1_mu2

        # Per-pixel SSIM map; the reported score is its mean.
        ssim_map = ((2 * mu1_mu2 + C1) *
                    (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                           (sigma1_sq + sigma2_sq + C2))

        return float(ssim_map.mean())


if __name__ == '__main__':
    device = 'cpu'
    calculator = SSIM(device=device)

    # cv2.imread returns (H, W, 3) uint8 arrays, so the 3D kernel is used.
    tgt = cv2.imread('example/car_gt.jpeg')
    inp = cv2.imread('example/car_pred.jpeg')

    print(calculator.calc(inp, tgt))
--------------------------------------------------------------------------------
/example/car_gt.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mayorx/matlab_ssim_pytorch_implementation/b0f03706923ffb08fb72438169e2ac683ad8f47a/example/car_gt.jpeg
--------------------------------------------------------------------------------
/example/car_pred.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mayorx/matlab_ssim_pytorch_implementation/b0f03706923ffb08fb72438169e2ac683ad8f47a/example/car_pred.jpeg
--------------------------------------------------------------------------------
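
The 11x11x11 kernel built by `generate_3d_gaussian_kernel` is the outer product of three 1D Gaussians, so filtering with it under replicate padding can in principle be split into three 1D passes (about 33 instead of 1331 multiply-adds per voxel). The NumPy sketch below checks that equivalence on a small random volume. It is an illustration under stated assumptions, not part of the repository: the 1D kernel is recomputed with NumPy instead of `cv2.getGaussianKernel`, NumPy's `mode="edge"` padding stands in for `padding_mode='replicate'`, and all function names are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def gaussian_1d(size=11, sigma=1.5):
    # Assumption: same weights as cv2.getGaussianKernel(11, 1.5).
    x = np.arange(size) - (size - 1) / 2
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def kernel_3d(k):
    # Mirrors generate_3d_gaussian_kernel: the 2D window scaled by each k[i], stacked on axis 0.
    window = np.outer(k, k)
    return np.stack([window * ki for ki in k], axis=0)


def filter_3d_direct(vol, kernel):
    # Full 3D cross-correlation with replicate padding (what Conv3d computes).
    r = kernel.shape[0] // 2
    padded = np.pad(vol.astype(np.float64), r, mode="edge")
    windows = sliding_window_view(padded, kernel.shape)
    return np.einsum("xyzabc,abc->xyz", windows, kernel)


def filter_3d_separable(vol, k):
    # Three 1D passes, padding each axis just before filtering along it.
    r = len(k) // 2
    out = vol.astype(np.float64)
    for ax in range(3):
        pad = [(0, 0)] * 3
        pad[ax] = (r, r)
        padded = np.pad(out, pad, mode="edge")
        out = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), ax, padded)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vol = rng.integers(0, 256, size=(16, 16, 3)).astype(np.float64)
    k = gaussian_1d()
    print(np.allclose(filter_3d_direct(vol, kernel_3d(k)), filter_3d_separable(vol, k)))  # True
```

Because the kernel is symmetric, cross-correlation (what `Conv3d` computes) and convolution (what `np.convolve` computes) give the same result, and because replicate padding clamps each coordinate independently, padding one axis at a time matches padding all three at once.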