""" © 2018, lizhengwei """
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred on index ``window_size // 2`` and uses standard
    deviation ``sigma``.
    """
    center = window_size // 2
    gauss = torch.Tensor(
        [exp(-(x - center) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]
    )
    return gauss / gauss.sum()


def create_window(window_size, sigma, channel):
    """Build a depthwise 2-D Gaussian filter bank.

    Returns a float tensor of shape ``(channel, 1, window_size, window_size)``
    (one identical kernel per channel), suitable as the ``weight`` argument
    of ``F.conv2d(..., groups=channel)``.
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()


class MS_SSIM(torch.nn.Module):
    """Multi-scale structural similarity (MS-SSIM) of two image batches.

    ``forward(img1, img2)`` returns a 0-dim tensor; identical inputs give 1.
    Use ``-ms_ssim(img1, img2)`` (or ``1 - ...``) as a loss.

    Args:
        size_average: stored on the instance; ``forward`` currently always
            averages over the whole batch regardless of this flag.
        max_val: dynamic range of the inputs (255 for 8-bit values, 1 for
            images scaled to [0, 1]); it sets the stabilising constants
            C1 and C2.
    """

    # Per-scale exponents, finest scale first.
    _WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

    def __init__(self, size_average=True, max_val=255):
        super(MS_SSIM, self).__init__()
        self.size_average = size_average
        # Kept for backward compatibility only: filtering now uses the channel
        # count of the actual input instead of this fixed value.
        self.channel = 3
        self.max_val = max_val

    def _ssim(self, img1, img2, size_average=True):
        """Compute single-scale SSIM and its contrast-structure (cs) term.

        Args:
            img1, img2: tensors of shape (N, C, H, W) with the same shape.
            size_average: if true return the two global means (0-dim
                tensors); otherwise return the full per-pixel maps.

        Returns:
            Tuple ``(ssim, cs)``.
        """
        _, channel, height, width = img1.size()
        # Shrink the window (and its sigma proportionally) for small images.
        window_size = min(height, width, 11)
        sigma = 1.5 * window_size / 11
        # Follow the input's device and dtype instead of forcing CUDA.
        window = create_window(window_size, sigma, channel).to(
            device=img1.device, dtype=img1.dtype
        )
        pad = window_size // 2

        def blur(x):
            # Depthwise Gaussian filtering: each channel is filtered on its own.
            return F.conv2d(x, window, padding=pad, groups=channel)

        mu1 = blur(img1)
        mu2 = blur(img2)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances via E[xy] - E[x]E[y].
        sigma1_sq = blur(img1 * img1) - mu1_sq
        sigma2_sq = blur(img2 * img2) - mu2_sq
        sigma12 = blur(img1 * img2) - mu1_mu2

        C1 = (0.01 * self.max_val) ** 2
        C2 = (0.03 * self.max_val) ** 2
        V1 = 2.0 * sigma12 + C2
        V2 = sigma1_sq + sigma2_sq + C2
        ssim_map = ((2 * mu1_mu2 + C1) * V1) / ((mu1_sq + mu2_sq + C1) * V2)
        mcs_map = V1 / V2
        if size_average:
            return ssim_map.mean(), mcs_map.mean()
        # Previously this path fell through and returned None.
        return ssim_map, mcs_map

    def ms_ssim(self, img1, img2, levels=5):
        """Combine SSIM/cs over ``levels`` dyadic scales into one score.

        Args:
            img1, img2: tensors of shape (N, C, H, W) with the same shape.
            levels: number of scales, between 1 and 5 (one exponent is
                available per scale).

        Returns:
            0-dim tensor: the product of ``cs_i ** w_i`` over all but the
            coarsest scale, times ``ssim ** w`` at the coarsest scale.

        Raises:
            ValueError: if ``levels`` is out of range or the images are too
                small to be halved ``levels - 1`` times.
        """
        if not 1 <= levels <= len(self._WEIGHTS):
            raise ValueError(
                "levels must be between 1 and %d, got %r" % (len(self._WEIGHTS), levels)
            )
        min_side = min(img1.size(-2), img1.size(-1))
        if min_side < 2 ** (levels - 1):
            raise ValueError(
                "images of size %dx%d are too small for %d levels"
                % (img1.size(-2), img1.size(-1), levels)
            )

        weight = torch.tensor(
            self._WEIGHTS[:levels], dtype=img1.dtype, device=img1.device
        )

        ssim_per_level = []
        mcs_per_level = []
        for level in range(levels):
            ssim_val, mcs_val = self._ssim(img1, img2)
            ssim_per_level.append(ssim_val)
            mcs_per_level.append(mcs_val)
            # Downsample for the next scale; nothing consumes a pooled image
            # after the last level, and pooling a 1x1 image would raise.
            if level < levels - 1:
                img1 = F.avg_pool2d(img1, kernel_size=2, stride=2)
                img2 = F.avg_pool2d(img2, kernel_size=2, stride=2)

        # NOTE: a negative mean cs value (possible for strongly anti-correlated
        # inputs) raised to a fractional power yields NaN, as in the original.
        value = ssim_per_level[-1] ** weight[-1]
        if levels > 1:
            mcs = torch.stack(mcs_per_level[:-1])
            value = torch.prod(mcs ** weight[:-1]) * value
        return value

    def forward(self, img1, img2):
        """Return the 5-scale MS-SSIM of ``img1`` and ``img2``."""
        return self.ms_ssim(img1, img2)
"""Minimal MS-SSIM demo: optimise a random image until it matches a reference."""
import loss
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np

npImg1 = cv2.imread("einstein.png")
if npImg1 is None:
    # cv2.imread signals a missing or unreadable file by returning None.
    raise FileNotFoundError("could not read image 'einstein.png'")

# HWC uint8 image -> 1xCxHxW float tensor scaled to [0, 1].
img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0) / 255.0
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

img1 = Variable(img1, requires_grad=False)
# img2 is the tensor being optimised.
img2 = Variable(img2, requires_grad=True)

# max_val must match the input range: 1 for images scaled to [0, 1],
# 255 for raw 8-bit values.
ms_ssim_loss = loss.MS_SSIM(max_val=1)

optimizer = optim.Adam([img2], lr=0.01)

# Start below the threshold so the loop body runs at least once.
ms_ssim_value = 0.0
while ms_ssim_value < 0.97:
    optimizer.zero_grad()
    # Maximise MS-SSIM by minimising its negative.
    ms_ssim_out = -ms_ssim_loss(img1, img2)
    ms_ssim_value = -ms_ssim_out.item()
    print(ms_ssim_value)
    ms_ssim_out.backward()
    optimizer.step()