├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── _test_speed ├── no1_ms_ssim_lizhengwei1992_MS_SSIM_pytorch.py ├── no2_ssim_Po_Hsun_Su_pytorch_ssim.py ├── no3_ssim_VainF_pytorch_msssim.py ├── no5_ssim_francois_rozet_piqa.py ├── test_ms_ssim_speed.py └── test_ssim_speed.py ├── make_gif.cmd ├── ms_ssim_test.gif ├── ms_ssim_test.mkv ├── ssim.py ├── ssim_test.gif ├── ssim_test.mkv └── test_img1.jpg /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .idea/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 One-sixth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 
The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ms_ssim_pytorch 2 | 3 | The code was modified from https://github.com/VainF/pytorch-msssim. 4 | Part of the code has been modified to make it faster, use less VRAM, and be compatible with pytorch jit. 5 | 6 | The dynamic channel version can be found here https://github.com/One-sixth/ms_ssim_pytorch/tree/dynamic_channel_num. 7 | It is more convenient to use but has a small performance loss. 8 | 9 | Thanks to [vegetable09](https://github.com/vegetable09) for finding and fixing a bug that caused NaN gradients in the ms_ssim backward pass. [#3](https://github.com/One-sixth/ms_ssim_pytorch/issues/3) 10 | 11 | If you are using pytorch 1.2, be careful not to create and destroy this jit module inside the training loop (other jit modules may be affected too), as this may leak memory. I have tested that pytorch 1.6 does not have this problem. [#4](https://github.com/One-sixth/ms_ssim_pytorch/issues/4) 12 | 13 | I studied the ssim.py of the library [piqa](https://github.com/francois-rozet/piqa), which made my implementation of ssim and ms-ssim a little faster than before. 14 | 15 | # Speed up. Only tested on GPU. 
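The benchmark output in this section reports two numbers per implementation, `cuda time` and `perf_counter time`. The benchmark scripts live in `_test_speed` and are not reproduced here, so the following is only a minimal sketch of how a wall-clock figure of this kind can be taken. `time_steps` and its `sync` hook are hypothetical names, not part of this repository. The point it illustrates is that GPU work is launched asynchronously, so the clock should only be read after the device has been synchronized.

```python
import time


def time_steps(step, iters=100, warmup=10, sync=None):
    """Return the wall-clock seconds taken by `iters` calls of `step`.

    `sync` is an optional zero-argument callable that blocks until all
    queued device work has finished (hypothetical hook; on a CUDA device
    `torch.cuda.synchronize` could be passed here).
    """
    # Warm-up calls are excluded from the measurement.
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(iters):
        step()
    # Wait for pending device work before reading the clock again.
    if sync is not None:
        sync()
    return time.perf_counter() - start


if __name__ == "__main__":
    # CPU-only stand-in workload so the sketch runs without a GPU.
    elapsed = time_steps(lambda: sum(i * i for i in range(1000)), iters=50)
    print(f"perf_counter time {elapsed}")
```

With PyTorch on a GPU, `step` would presumably wrap one forward and backward pass of a loss module, and `sync=torch.cuda.synchronize` would be passed so that the elapsed time covers the kernels actually finishing rather than only their launch.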
16 | losser1 is https://github.com/lizhengwei1992/MS_SSIM_pytorch/blob/master/loss.py 268fc76 17 | losser2 is https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 881d210 18 | losser3 is https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py 5caf547 19 | losser4 is https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py 1c2f14a 20 | losser5 is https://github.com/francois-rozet/piqa/blob/master/piqa/ssim.py abaf398 21 | https://github.com/francois-rozet/piqa/blob/master/piqa/utils.py abaf398 22 | 23 | In pytorch 1.7.1 24 | My test environment: i7-8750H GTX1070-8G 25 | 26 | In pytorch 1.1 1.2 27 | My test environment: i7-6700HQ GTX970M-3G 28 | 29 | ## SSIM 30 | Test output 31 | 32 | pytorch 1.7.1 33 | ``` 34 | Performance Testing SSIM 35 | 36 | testing losser2 37 | cuda time 39963.15625 38 | perf_counter time 35.9110169 39 | 40 | testing losser3 41 | cuda time 17141.841796875 42 | perf_counter time 17.124456199999997 43 | 44 | testing losser4 45 | cuda time 13205.0322265625 46 | perf_counter time 10.477991699999997 47 | 48 | testing losser5 49 | cuda time 13142.8232421875 50 | perf_counter time 11.079514100000011 51 | 52 | ``` 53 | 54 | pytorch 1.2 55 | ``` 56 | Performance Testing SSIM 57 | 58 | testing losser2 59 | cuda time 89290.7734375 60 | perf_counter time 87.1042247 61 | 62 | testing losser3 63 | cuda time 36153.64453125 64 | perf_counter time 36.09167939999999 65 | 66 | testing losser4 67 | cuda time 31085.455078125 68 | perf_counter time 29.80807200000001 69 | 70 | ``` 71 | 72 | pytorch 1.1 73 | ``` 74 | Performance Testing SSIM 75 | 76 | testing losser2 77 | cuda time 88990.0703125 78 | perf_counter time 86.80163019999999 79 | 80 | testing losser3 81 | cuda time 36119.06640625 82 | perf_counter time 36.057978399999996 83 | 84 | testing losser4 85 | cuda time 34708.8359375 86 | perf_counter time 33.916086199999995 87 | 88 | ``` 89 | 90 | ## MS-SSIM 91 | Test output 92 | 93 | pytorch 1.7.1 94 | ``` 95 | 
Performance Testing MS_SSIM 96 | 97 | testing losser1 98 | cuda time 60403.59765625 99 | perf_counter time 60.351266200000005 100 | 101 | testing losser3 102 | cuda time 26321.48828125 103 | perf_counter time 26.30165939999999 104 | 105 | testing losser4 106 | cuda time 24471.6875 107 | perf_counter time 24.45189119999999 108 | 109 | testing losser5 110 | cuda time 23153.962890625 111 | perf_counter time 23.135575399999993 112 | 113 | ``` 114 | 115 | pytorch 1.2 116 | ``` 117 | Performance Testing MS_SSIM 118 | 119 | testing losser1 120 | cuda time 134158.84375 121 | perf_counter time 134.0433756 122 | 123 | testing losser3 124 | cuda time 62143.4140625 125 | perf_counter time 62.103911400000015 126 | 127 | testing losser4 128 | cuda time 46854.25390625 129 | perf_counter time 46.81785239999999 130 | 131 | ``` 132 | 133 | pytorch 1.1 134 | ``` 135 | Performance Testing MS_SSIM 136 | 137 | testing losser1 138 | cuda time 134115.96875 139 | perf_counter time 134.0006031 140 | 141 | testing losser3 142 | cuda time 61760.56640625 143 | perf_counter time 61.71994470000001 144 | 145 | testing losser4 146 | cuda time 52888.03125 147 | perf_counter time 52.848280500000016 148 | 149 | ``` 150 | 151 | ## Test the speed yourself 152 | 1. cd ms_ssim_pytorch/_test_speed 153 | 154 | 2. python test_ssim_speed.py 155 | or 156 | 2. python test_ms_ssim_speed.py 157 | 158 | # Other things 159 | Added the parameter use_padding. 160 | When set to True, the gaussian_filter behavior is the same as https://github.com/Po-Hsun-Su/pytorch-ssim. 161 | This parameter is mainly intended for MS-SSIM, because MS-SSIM downsamples the input at each level. 162 | When the input image is smaller than 176x176, this parameter needs to be set to True for MS-SSIM to work correctly (with the default weight and level parameters). 163 | 164 | # Requirements 165 | Pytorch >= 1.1 166 | 167 | If you want to test the code with animation, you also need to install some packages. 
168 | ``` 169 | pip install imageio imageio-ffmpeg opencv-python 170 | ``` 171 | 172 | # Test code with animation 173 | The test code is included in the ssim.py file; you can run the file directly to start the test. 174 | 175 | 1. git clone https://github.com/One-sixth/ms_ssim_pytorch 176 | 2. cd ms_ssim_pytorch 177 | 3. python ssim.py 178 | 179 | # Code Example. 180 | ```python 181 | import torch 182 | import ssim 183 | 184 | 185 | im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda') 186 | img1 = im / 255 187 | img2 = img1 * 0.5 188 | 189 | losser = ssim.SSIM(data_range=1., channel=3).cuda() 190 | loss = losser(img1, img2).mean() 191 | 192 | losser2 = ssim.MS_SSIM(data_range=1., channel=3).cuda() 193 | loss2 = losser2(img1, img2).mean() 194 | 195 | print(loss.item()) 196 | print(loss2.item()) 197 | ``` 198 | 199 | # Animation 200 | The GIFs are a bit big, so loading may take some time. 201 | Alternatively, you can download the mkv video files directly to view them; they are smaller and smoother. 
202 | https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim_test.mkv 203 | https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ms_ssim_test.mkv 204 | 205 | SSIM 206 | ![ssim](https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim_test.gif) 207 | 208 | MS-SSIM 209 | ![ms-ssim](https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ms_ssim_test.gif) 210 | 211 | # References 212 | https://github.com/VainF/pytorch-msssim 213 | https://github.com/Po-Hsun-Su/pytorch-ssim 214 | https://github.com/lizhengwei1992/MS_SSIM_pytorch 215 | https://github.com/francois-rozet/piqa 216 | -------------------------------------------------------------------------------- /_test_speed/no1_ms_ssim_lizhengwei1992_MS_SSIM_pytorch.py: -------------------------------------------------------------------------------- 1 | """ © 2018, lizhengwei """ 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, sigma, channel): 13 | _1D_window = gaussian(window_size, sigma).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | class MS_SSIM(torch.nn.Module): 19 | def __init__(self, size_average = True, max_val = 255): 20 | super(MS_SSIM, self).__init__() 21 | self.size_average = size_average 22 | self.channel = 3 23 | self.max_val = max_val 24 | def _ssim(self, img1, img2, size_average = True): 25 | 26 | _, c, w, h = img1.size() 27 | window_size = min(w, h, 11) 28 | sigma = 1.5 * window_size / 11 29 | window = create_window(window_size, sigma, self.channel).cuda() 30 | mu1 = F.conv2d(img1, window, padding = window_size//2, 
groups = self.channel) 31 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = self.channel) 32 | 33 | mu1_sq = mu1.pow(2) 34 | mu2_sq = mu2.pow(2) 35 | mu1_mu2 = mu1*mu2 36 | 37 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = self.channel) - mu1_sq 38 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = self.channel) - mu2_sq 39 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = self.channel) - mu1_mu2 40 | 41 | C1 = (0.01*self.max_val)**2 42 | C2 = (0.03*self.max_val)**2 43 | V1 = 2.0 * sigma12 + C2 44 | V2 = sigma1_sq + sigma2_sq + C2 45 | ssim_map = ((2*mu1_mu2 + C1)*V1)/((mu1_sq + mu2_sq + C1)*V2) 46 | mcs_map = V1 / V2 47 | if size_average: 48 | return ssim_map.mean(), mcs_map.mean() 49 | 50 | def ms_ssim(self, img1, img2, levels=5): 51 | 52 | weight = Variable(torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).cuda()) 53 | 54 | msssim = Variable(torch.Tensor(levels,).cuda()) 55 | mcs = Variable(torch.Tensor(levels,).cuda()) 56 | for i in range(levels): 57 | ssim_map, mcs_map = self._ssim(img1, img2) 58 | msssim[i] = ssim_map 59 | mcs[i] = mcs_map 60 | filtered_im1 = F.avg_pool2d(img1, kernel_size=2, stride=2) 61 | filtered_im2 = F.avg_pool2d(img2, kernel_size=2, stride=2) 62 | img1 = filtered_im1 63 | img2 = filtered_im2 64 | 65 | value = (torch.prod(mcs[0:levels-1]**weight[0:levels-1])* 66 | (msssim[levels-1]**weight[levels-1])) 67 | return value 68 | 69 | 70 | def forward(self, img1, img2): 71 | 72 | return self.ms_ssim(img1, img2) -------------------------------------------------------------------------------- /_test_speed/no2_ssim_Po_Hsun_Su_pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - 
window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 
62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /_test_speed/no3_ssim_VainF_pytorch_msssim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 by Gongfan Fang, Zhejiang University. 2 | # All rights reserved. 3 | import warnings 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def _fspecial_gauss_1d(size, sigma): 10 | r"""Create 1-D gauss kernel 11 | Args: 12 | size (int): the size of gauss kernel 13 | sigma (float): sigma of normal distribution 14 | Returns: 15 | torch.Tensor: 1D kernel (1 x 1 x size) 16 | """ 17 | coords = torch.arange(size).to(dtype=torch.float) 18 | coords -= size // 2 19 | 20 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 21 | g /= g.sum() 22 | 23 | return g.unsqueeze(0).unsqueeze(0) 24 | 25 | 26 | def gaussian_filter(input, win): 27 | r""" Blur input with 1-D kernel 28 | Args: 29 | input (torch.Tensor): a batch of tensors to be blurred 30 | window (torch.Tensor): 1-D gauss kernel 31 | Returns: 32 | torch.Tensor: blurred tensors 33 | """ 34 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape 35 | if len(input.shape) == 4: 36 | conv = F.conv2d 37 | elif len(input.shape) == 5: 38 | conv = F.conv3d 39 | else: 40 | raise NotImplementedError(input.shape) 41 | 42 | C = input.shape[1] 43 | out = input 44 | for i, s in enumerate(input.shape[2:]): 45 | if s >= win.shape[-1]: 46 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) 47 | else: 48 | warnings.warn( 49 | f"Skipping 
Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" 50 | ) 51 | 52 | return out 53 | 54 | 55 | def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)): 56 | 57 | r""" Calculate ssim index for X and Y 58 | Args: 59 | X (torch.Tensor): images 60 | Y (torch.Tensor): images 61 | win (torch.Tensor): 1-D gauss kernel 62 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 63 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 64 | Returns: 65 | torch.Tensor: ssim results. 66 | """ 67 | K1, K2 = K 68 | # batch, channel, [depth,] height, width = X.shape 69 | compensation = 1.0 70 | 71 | C1 = (K1 * data_range) ** 2 72 | C2 = (K2 * data_range) ** 2 73 | 74 | win = win.to(X.device, dtype=X.dtype) 75 | 76 | mu1 = gaussian_filter(X, win) 77 | mu2 = gaussian_filter(Y, win) 78 | 79 | mu1_sq = mu1.pow(2) 80 | mu2_sq = mu2.pow(2) 81 | mu1_mu2 = mu1 * mu2 82 | 83 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 84 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 85 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 86 | 87 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 88 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 89 | 90 | ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) 91 | cs = torch.flatten(cs_map, 2).mean(-1) 92 | return ssim_per_channel, cs 93 | 94 | 95 | def ssim( 96 | X, 97 | Y, 98 | data_range=255, 99 | size_average=True, 100 | win_size=11, 101 | win_sigma=1.5, 102 | win=None, 103 | K=(0.01, 0.03), 104 | nonnegative_ssim=False, 105 | ): 106 | r""" interface of ssim 107 | Args: 108 | X (torch.Tensor): a batch of images, (N,C,H,W) 109 | Y (torch.Tensor): a batch of images, (N,C,H,W) 110 | data_range (float or int, optional): value range of input images. 
(usually 1.0 or 255) 111 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 112 | win_size: (int, optional): the size of gauss kernel 113 | win_sigma: (float, optional): sigma of normal distribution 114 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 115 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 116 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu 117 | Returns: 118 | torch.Tensor: ssim results 119 | """ 120 | if not X.shape == Y.shape: 121 | raise ValueError("Input images should have the same dimensions.") 122 | 123 | for d in range(len(X.shape) - 1, 1, -1): 124 | X = X.squeeze(dim=d) 125 | Y = Y.squeeze(dim=d) 126 | 127 | if len(X.shape) not in (4, 5): 128 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 129 | 130 | if not X.type() == Y.type(): 131 | raise ValueError("Input images should have the same dtype.") 132 | 133 | if win is not None: # set win_size 134 | win_size = win.shape[-1] 135 | 136 | if not (win_size % 2 == 1): 137 | raise ValueError("Window size should be odd.") 138 | 139 | if win is None: 140 | win = _fspecial_gauss_1d(win_size, win_sigma) 141 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 142 | 143 | ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K) 144 | if nonnegative_ssim: 145 | ssim_per_channel = torch.relu(ssim_per_channel) 146 | 147 | if size_average: 148 | return ssim_per_channel.mean() 149 | else: 150 | return ssim_per_channel.mean(1) 151 | 152 | 153 | def ms_ssim( 154 | X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03) 155 | ): 156 | 157 | r""" interface of ms-ssim 158 | Args: 159 | X (torch.Tensor): a batch of images, (N,C,[T,]H,W) 160 | 
Y (torch.Tensor): a batch of images, (N,C,[T,]H,W) 161 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 162 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 163 | win_size: (int, optional): the size of gauss kernel 164 | win_sigma: (float, optional): sigma of normal distribution 165 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 166 | weights (list, optional): weights for different levels 167 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 168 | Returns: 169 | torch.Tensor: ms-ssim results 170 | """ 171 | if not X.shape == Y.shape: 172 | raise ValueError("Input images should have the same dimensions.") 173 | 174 | for d in range(len(X.shape) - 1, 1, -1): 175 | X = X.squeeze(dim=d) 176 | Y = Y.squeeze(dim=d) 177 | 178 | if not X.type() == Y.type(): 179 | raise ValueError("Input images should have the same dtype.") 180 | 181 | if len(X.shape) == 4: 182 | avg_pool = F.avg_pool2d 183 | elif len(X.shape) == 5: 184 | avg_pool = F.avg_pool3d 185 | else: 186 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 187 | 188 | if win is not None: # set win_size 189 | win_size = win.shape[-1] 190 | 191 | if not (win_size % 2 == 1): 192 | raise ValueError("Window size should be odd.") 193 | 194 | smaller_side = min(X.shape[-2:]) 195 | assert smaller_side > (win_size - 1) * ( 196 | 2 ** 4 197 | ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4)) 198 | 199 | if weights is None: 200 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] 201 | weights = torch.FloatTensor(weights).to(X.device, dtype=X.dtype) 202 | 203 | if win is None: 204 | win = _fspecial_gauss_1d(win_size, win_sigma) 205 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 
206 | 207 | levels = weights.shape[0] 208 | mcs = [] 209 | for i in range(levels): 210 | ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K) 211 | 212 | if i < levels - 1: 213 | mcs.append(torch.relu(cs)) 214 | padding = [s % 2 for s in X.shape[2:]] 215 | X = avg_pool(X, kernel_size=2, padding=padding) 216 | Y = avg_pool(Y, kernel_size=2, padding=padding) 217 | 218 | ssim_per_channel = torch.relu(ssim_per_channel) # (batch, channel) 219 | mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel) 220 | ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0) 221 | 222 | if size_average: 223 | return ms_ssim_val.mean() 224 | else: 225 | return ms_ssim_val.mean(1) 226 | 227 | 228 | class SSIM(torch.nn.Module): 229 | def __init__( 230 | self, 231 | data_range=255, 232 | size_average=True, 233 | win_size=11, 234 | win_sigma=1.5, 235 | channel=3, 236 | spatial_dims=2, 237 | K=(0.01, 0.03), 238 | nonnegative_ssim=False, 239 | ): 240 | r""" class for ssim 241 | Args: 242 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 243 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 244 | win_size: (int, optional): the size of gauss kernel 245 | win_sigma: (float, optional): sigma of normal distribution 246 | channel (int, optional): input channels (default: 3) 247 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 248 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu. 
249 | """ 250 | 251 | super(SSIM, self).__init__() 252 | self.win_size = win_size 253 | self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) 254 | self.size_average = size_average 255 | self.data_range = data_range 256 | self.K = K 257 | self.nonnegative_ssim = nonnegative_ssim 258 | 259 | def forward(self, X, Y): 260 | return ssim( 261 | X, 262 | Y, 263 | data_range=self.data_range, 264 | size_average=self.size_average, 265 | win=self.win, 266 | K=self.K, 267 | nonnegative_ssim=self.nonnegative_ssim, 268 | ) 269 | 270 | 271 | class MS_SSIM(torch.nn.Module): 272 | def __init__( 273 | self, 274 | data_range=255, 275 | size_average=True, 276 | win_size=11, 277 | win_sigma=1.5, 278 | channel=3, 279 | spatial_dims=2, 280 | weights=None, 281 | K=(0.01, 0.03), 282 | ): 283 | r""" class for ms-ssim 284 | Args: 285 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 286 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 287 | win_size: (int, optional): the size of gauss kernel 288 | win_sigma: (float, optional): sigma of normal distribution 289 | channel (int, optional): input channels (default: 3) 290 | weights (list, optional): weights for different levels 291 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 
292 | """ 293 | 294 | super(MS_SSIM, self).__init__() 295 | self.win_size = win_size 296 | self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) 297 | self.size_average = size_average 298 | self.data_range = data_range 299 | self.weights = weights 300 | self.K = K 301 | 302 | def forward(self, X, Y): 303 | return ms_ssim( 304 | X, 305 | Y, 306 | data_range=self.data_range, 307 | size_average=self.size_average, 308 | win=self.win, 309 | weights=self.weights, 310 | K=self.K, 311 | ) -------------------------------------------------------------------------------- /_test_speed/no5_ssim_francois_rozet_piqa.py: -------------------------------------------------------------------------------- 1 | # https://github.com/francois-rozet/piqa/blob/master/piqa/utils.py abaf398 2 | r"""Miscellaneous tools such as modules, functionals and more. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from typing import Callable, List, Tuple, Union 10 | 11 | 12 | jit = torch.jit.script 13 | 14 | 15 | def channel_conv( 16 | x: torch.Tensor, 17 | kernel: torch.Tensor, 18 | padding: int = 0, # Union[int, Tuple[int, ...]] 19 | ) -> torch.Tensor: 20 | r"""Returns the channel-wise convolution of `x` with respect to `kernel`. 21 | Args: 22 | x: An input tensor, (N, C, *). 23 | kernel: A kernel, (C', 1, *). 24 | padding: The implicit paddings on both sides of the input dimensions. 
25 | Example: 26 | >>> x = torch.arange(25).float().view(1, 1, 5, 5) 27 | >>> x 28 | tensor([[[[ 0., 1., 2., 3., 4.], 29 | [ 5., 6., 7., 8., 9.], 30 | [10., 11., 12., 13., 14.], 31 | [15., 16., 17., 18., 19.], 32 | [20., 21., 22., 23., 24.]]]]) 33 | >>> kernel = torch.ones((1, 1, 3, 3)) 34 | >>> channel_conv(x, kernel) 35 | tensor([[[[ 54., 63., 72.], 36 | [ 99., 108., 117.], 37 | [144., 153., 162.]]]]) 38 | """ 39 | 40 | return F.conv1d(x, kernel, padding=padding, groups=x.size(1)) 41 | 42 | 43 | def channel_sep_conv( 44 | x: torch.Tensor, 45 | kernels: List[torch.Tensor], 46 | ) -> torch.Tensor: 47 | r"""Returns the channel-wise convolution of `x` with respect to the 48 | separated kernel `kernels`. 49 | Args: 50 | x: An input tensor, (N, C, *). 51 | kernels: A separated kernel, (C', 1, 1*, K, 1*). 52 | Example: 53 | >>> x = torch.arange(25).float().view(1, 1, 5, 5) 54 | >>> x 55 | tensor([[[[ 0., 1., 2., 3., 4.], 56 | [ 5., 6., 7., 8., 9.], 57 | [10., 11., 12., 13., 14.], 58 | [15., 16., 17., 18., 19.], 59 | [20., 21., 22., 23., 24.]]]]) 60 | >>> kernels = [torch.ones((1, 1, 3, 1)), torch.ones((1, 1, 1, 3))] 61 | >>> channel_sep_conv(x, kernels) 62 | tensor([[[[ 54., 63., 72.], 63 | [ 99., 108., 117.], 64 | [144., 153., 162.]]]]) 65 | """ 66 | 67 | for kernel in kernels: 68 | x = channel_conv(x, kernel) 69 | 70 | return x 71 | 72 | 73 | def unravel_index( 74 | indices: torch.LongTensor, 75 | shape: Tuple[int, ...], 76 | ) -> torch.LongTensor: 77 | r"""Converts flat indices into unraveled coordinates in a target shape. 78 | This is a `torch` implementation of `numpy.unravel_index`. 79 | Args: 80 | indices: A tensor of (flat) indices, (*, N). 81 | shape: The targeted shape, (D,). 82 | Returns: 83 | coord: The unraveled coordinates, (*, N, D). 
84 | Example: 85 | >>> unravel_index(torch.arange(9), (3, 3)) 86 | tensor([[0, 0], 87 | [0, 1], 88 | [0, 2], 89 | [1, 0], 90 | [1, 1], 91 | [1, 2], 92 | [2, 0], 93 | [2, 1], 94 | [2, 2]]) 95 | """ 96 | 97 | coord = [] 98 | 99 | for dim in reversed(shape): 100 | coord.append(indices % dim) 101 | indices = indices // dim 102 | 103 | coord = torch.stack(coord[::-1], dim=-1) 104 | 105 | return coord 106 | 107 | 108 | def gaussian_kernel( 109 | kernel_size: int, 110 | sigma: float = 1. 111 | ) -> torch.Tensor: 112 | r"""Returns the 1D Gaussian kernel of size `kernel_size`. 113 | The distribution is centered around the kernel's center 114 | and the standard deviation is `sigma`. 115 | Args: 116 | kernel_size: The size of the kernel. 117 | sigma: The standard deviation of the distribution. 118 | Wikipedia: 119 | https://en.wikipedia.org/wiki/Normal_distribution 120 | Example: 121 | >>> gaussian_kernel(5, sigma=1.5) 122 | tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) 123 | """ 124 | 125 | kernel = torch.arange(kernel_size).float() 126 | kernel -= (kernel_size - 1) / 2 127 | kernel = kernel ** 2 / (2. * sigma ** 2) 128 | kernel = torch.exp(-kernel) 129 | kernel /= kernel.sum() 130 | 131 | return kernel 132 | 133 | 134 | def haar_kernel(size: int) -> Tuple[torch.Tensor, torch.Tensor]: 135 | r"""Returns the separated Haar kernel. 136 | Args: 137 | size: The kernel (even) size. 138 | Wikipedia: 139 | https://en.wikipedia.org/wiki/Haar_wavelet 140 | Example: 141 | >>> haar_kernel(2) 142 | (tensor([0.5000, 0.5000]), tensor([ 1., -1.])) 143 | """ 144 | 145 | return ( 146 | torch.ones(size) / size, 147 | torch.tensor([1., -1.]).repeat_interleave(size // 2) 148 | ) 149 | 150 | 151 | def prewitt_kernel() -> Tuple[torch.Tensor, torch.Tensor]: 152 | r"""Returns the separated 3x3 Prewitt kernel. 
153 | Wikipedia: 154 | https://en.wikipedia.org/wiki/Prewitt_operator 155 | Example: 156 | >>> prewitt_kernel() 157 | (tensor([0.3333, 0.3333, 0.3333]), tensor([ 1., 0., -1.])) 158 | """ 159 | 160 | return torch.tensor([1., 1., 1.]) / 3, torch.tensor([1., 0., -1.]) 161 | 162 | 163 | def sobel_kernel() -> Tuple[torch.Tensor, torch.Tensor]: 164 | r"""Returns the separated 3x3 Sobel kernel. 165 | Wikipedia: 166 | https://en.wikipedia.org/wiki/Sobel_operator 167 | Example: 168 | >>> sobel_kernel() 169 | (tensor([0.2500, 0.5000, 0.2500]), tensor([ 1., 0., -1.])) 170 | """ 171 | 172 | return torch.tensor([1., 2., 1.]) / 4, torch.tensor([1., 0., -1.]) 173 | 174 | 175 | def scharr_kernel() -> Tuple[torch.Tensor, torch.Tensor]: 176 | r"""Returns the separated 3x3 Scharr kernel. 177 | Wikipedia: 178 | https://en.wikipedia.org/wiki/Scharr_operator 179 | Example: 180 | >>> scharr_kernel() 181 | (tensor([0.1875, 0.6250, 0.1875]), tensor([ 1., 0., -1.])) 182 | """ 183 | 184 | return torch.tensor([3., 10., 3.]) / 16, torch.tensor([1., 0., -1.]) 185 | 186 | 187 | def tensor_norm( 188 | x: torch.Tensor, 189 | dim: List[int], # Union[int, Tuple[int, ...]] = () 190 | keepdim: bool = False, 191 | norm: str = 'L2', 192 | ) -> torch.Tensor: 193 | r"""Returns the norm of `x`. 194 | Args: 195 | x: An input tensor. 196 | dim: The dimension(s) along which to calculate the norm. 197 | keepdim: Whether the output tensor has `dim` retained or not. 198 | norm: Specifies the norm function to apply: 199 | `'L1'` | `'L2'` | `'L2_squared'`. 
200 | Wikipedia: 201 | https://en.wikipedia.org/wiki/Norm_(mathematics) 202 | Example: 203 | >>> x = torch.arange(9).float().view(3, 3) 204 | >>> x 205 | tensor([[0., 1., 2.], 206 | [3., 4., 5.], 207 | [6., 7., 8.]]) 208 | >>> tensor_norm(x, dim=0) 209 | tensor([6.7082, 8.1240, 9.6437]) 210 | """ 211 | 212 | if norm in ['L2', 'L2_squared']: 213 | x = x ** 2 214 | else: # norm == 'L1' 215 | x = x.abs() 216 | 217 | x = x.sum(dim=dim, keepdim=keepdim) 218 | 219 | if norm == 'L2': 220 | x = x.sqrt() 221 | 222 | return x 223 | 224 | 225 | def normalize_tensor( 226 | x: torch.Tensor, 227 | dim: List[int], # Union[int, Tuple[int, ...]] = () 228 | norm: str = 'L2', 229 | epsilon: float = 1e-8, 230 | ) -> torch.Tensor: 231 | r"""Returns `x` normalized. 232 | Args: 233 | x: An input tensor. 234 | dim: The dimension(s) along which to normalize. 235 | norm: Specifies the norm function to use: 236 | `'L1'` | `'L2'` | `'L2_squared'`. 237 | epsilon: A numerical stability term. 238 | Example: 239 | >>> x = torch.arange(9).float().view(3, 3) 240 | >>> x 241 | tensor([[0., 1., 2.], 242 | [3., 4., 5.], 243 | [6., 7., 8.]]) 244 | >>> normalize_tensor(x, dim=0) 245 | tensor([[0.0000, 0.1231, 0.2074], 246 | [0.4472, 0.4924, 0.5185], 247 | [0.8944, 0.8616, 0.8296]]) 248 | """ 249 | 250 | norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm) 251 | 252 | return x / (norm + epsilon) 253 | 254 | 255 | def cpow( 256 | x: torch.cfloat, 257 | exponent: Union[int, float, torch.Tensor], 258 | ) -> torch.cfloat: 259 | r"""Returns the power of `x` with `exponent`. 260 | Args: 261 | x: A complex input tensor. 262 | exponent: The exponent value or tensor. 263 | Example: 264 | >>> x = torch.tensor([1. 
+ 0.j, 0.707 + 0.707j]) 265 | >>> cpow(x, 2) 266 | tensor([ 1.0000e+00+0.0000j, -4.3698e-08+0.9997j]) 267 | """ 268 | 269 | r = x.abs() ** exponent 270 | phi = torch.atan2(x.imag, x.real) * exponent 271 | 272 | return torch.complex(r * torch.cos(phi), r * torch.sin(phi)) 273 | 274 | 275 | class Intermediary(nn.Module): 276 | r"""Module that catches and returns the outputs of intermediate 277 | target layers of a sequential module during its forward pass. 278 | Args: 279 | layers: A sequential module. 280 | targets: A list of target layer indexes. 281 | """ 282 | 283 | def __init__(self, layers: nn.Sequential, targets: List[int]): 284 | r"""""" 285 | super().__init__() 286 | 287 | self.layers = layers 288 | self.targets = targets 289 | self.len = len(self.targets) 290 | 291 | def forward(self, input: torch.Tensor) -> List[torch.Tensor]: 292 | r"""Defines the computation performed at every call. 293 | """ 294 | 295 | output = [] 296 | j = 0 297 | 298 | for i, layer in enumerate(self.layers): 299 | input = layer(input) 300 | 301 | if i == self.targets[j]: 302 | output.append(input) 303 | j += 1 304 | 305 | if j == self.len: 306 | break 307 | 308 | return output 309 | 310 | 311 | def build_reduce(reduction: str = 'mean') -> nn.Module: 312 | r"""Returns a reducing module. 313 | Args: 314 | reduction: Specifies the reduce type: 315 | `'none'` | `'mean'` | `'sum'`. 
316 | Example: 317 | >>> red = build_reduce(reduction='sum') 318 | >>> red(torch.arange(5)) 319 | tensor(10) 320 | """ 321 | 322 | if reduction == 'mean': 323 | return _Mean() 324 | elif reduction == 'sum': 325 | return _Sum() 326 | 327 | return nn.Identity() 328 | 329 | 330 | class _Mean(nn.Module): 331 | def forward(self, input: torch.Tensor) -> torch.Tensor: 332 | return input.mean() 333 | 334 | 335 | class _Sum(nn.Module): 336 | def forward(self, input: torch.Tensor) -> torch.Tensor: 337 | return input.sum() 338 | 339 | 340 | # https://github.com/francois-rozet/piqa/blob/master/piqa/ssim.py abaf398 341 | 342 | r"""Structural Similarity (SSIM) and Multi-Scale Structural Similarity (MS-SSIM) 343 | This module implements the SSIM and MS-SSIM in PyTorch. 344 | Wikipedia: 345 | https://en.wikipedia.org/wiki/Structural_similarity 346 | Credits: 347 | Inspired by [pytorch-msssim](https://github.com/VainF/pytorch-msssim) 348 | References: 349 | [1] Multiscale structural similarity for image quality assessment 350 | (Wang et al., 2003) 351 | https://ieeexplore.ieee.org/abstract/document/1292216/ 352 | [2] Image quality assessment: From error visibility to structural similarity 353 | (Wang et al., 2004) 354 | https://ieeexplore.ieee.org/abstract/document/1284395/ 355 | """ 356 | 357 | import torch 358 | import torch.nn as nn 359 | import torch.nn.functional as F 360 | 361 | # from piqa.utils import jit, build_reduce, gaussian_kernel, channel_sep_conv 362 | 363 | from typing import Union, List, Tuple 364 | 365 | _MS_WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 366 | 367 | 368 | @jit 369 | def create_window( 370 | window_size: int, 371 | n_channels: int, 372 | device: torch.device = torch.device('cpu'), 373 | ) -> List[torch.Tensor]: 374 | r"""Returns the SSIM convolution window of size `window_size`. 375 | Args: 376 | window_size: The size of the window. 377 | n_channels: A number of channels. 
378 | device: Specifies the device of the created window. 379 | Example: 380 | >>> win = create_window(5, n_channels=3) 381 | >>> win[0].size() 382 | torch.Size([3, 1, 5, 1]) 383 | >>> win[0][0] 384 | tensor([[[0.1201], 385 | [0.2339], 386 | [0.2921], 387 | [0.2339], 388 | [0.1201]]]) 389 | """ 390 | 391 | kernel = gaussian_kernel(window_size, 1.5).to(device) 392 | kernel = kernel.repeat(n_channels, 1, 1) 393 | 394 | return [ 395 | kernel.unsqueeze(-1).contiguous(), 396 | kernel.unsqueeze(-2).contiguous() 397 | ] 398 | 399 | 400 | @jit 401 | def ssim_per_channel( 402 | x: torch.Tensor, 403 | y: torch.Tensor, 404 | window: List[torch.Tensor], 405 | value_range: float = 1., 406 | non_negative: bool = False, 407 | k1: float = 0.01, 408 | k2: float = 0.03, 409 | ) -> List[torch.Tensor]: 410 | r"""Returns the SSIM and the contrast sensitivity per channel 411 | between `x` and `y`. 412 | Args: 413 | x: An input tensor, (N, C, H, W). 414 | y: A target tensor, (N, C, H, W). 415 | window: A separated kernel, ((C, 1, K, 1), (C, 1, 1, K)). 416 | value_range: The value range of the inputs (usually 1. or 255). 417 | non_negative: Whether negative values are clipped or not. 418 | For the remaining arguments, refer to [2]. 
419 | Example: 420 | >>> x = torch.rand(5, 3, 256, 256) 421 | >>> y = torch.rand(5, 3, 256, 256) 422 | >>> window = create_window(7, 3) 423 | >>> ss, cs = ssim_per_channel(x, y, window) 424 | >>> ss.size(), cs.size() 425 | (torch.Size([5, 3]), torch.Size([5, 3])) 426 | """ 427 | 428 | c1 = (k1 * value_range) ** 2 429 | c2 = (k2 * value_range) ** 2 430 | 431 | # Mean (mu) 432 | mu_x = channel_sep_conv(x, window) 433 | mu_y = channel_sep_conv(y, window) 434 | 435 | mu_xx = mu_x ** 2 436 | mu_yy = mu_y ** 2 437 | mu_xy = mu_x * mu_y 438 | 439 | # Variance (sigma) 440 | sigma_xx = channel_sep_conv(x ** 2, window) - mu_xx 441 | sigma_yy = channel_sep_conv(y ** 2, window) - mu_yy 442 | sigma_xy = channel_sep_conv(x * y, window) - mu_xy 443 | 444 | # Contrast sensitivity 445 | cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) 446 | 447 | # Structural similarity 448 | ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs 449 | 450 | # Average 451 | ss, cs = ss.mean((-1, -2)), cs.mean((-1, -2)) 452 | 453 | if non_negative: 454 | ss, cs = torch.relu(ss), torch.relu(cs) 455 | 456 | return ss, cs 457 | 458 | 459 | def ssim( 460 | x: torch.Tensor, 461 | y: torch.Tensor, 462 | window_size: int = 11, 463 | **kwargs, 464 | ) -> torch.Tensor: 465 | r"""Returns the SSIM between `x` and `y`. 466 | Args: 467 | x: An input tensor, (N, C, H, W). 468 | y: A target tensor, (N, C, H, W). 469 | window_size: The size of the window. 470 | `**kwargs` are transmitted to `ssim_per_channel`. 
471 | Example: 472 | >>> x = torch.rand(5, 3, 256, 256) 473 | >>> y = torch.rand(5, 3, 256, 256) 474 | >>> l = ssim(x, y) 475 | >>> l.size() 476 | torch.Size([5]) 477 | """ 478 | 479 | window = create_window(window_size, x.size(1), device=x.device) 480 | 481 | return ssim_per_channel(x, y, window, **kwargs)[0].mean(-1) 482 | 483 | 484 | @jit 485 | def msssim_per_channel( 486 | x: torch.Tensor, 487 | y: torch.Tensor, 488 | window: List[torch.Tensor], 489 | weights: torch.Tensor, 490 | value_range: float = 1., 491 | k1: float = 0.01, 492 | k2: float = 0.03, 493 | ) -> torch.Tensor: 494 | """Returns the MS-SSIM per channel between `x` and `y`. 495 | Args: 496 | x: An input tensor, (N, C, H, W). 497 | y: A target tensor, (N, C, H, W). 498 | window: A separated kernel, ((C, 1, K, 1), (C, 1, 1, K)). 499 | weights: The weights of the scales, (M,). 500 | value_range: The value range of the inputs (usually 1. or 255). 501 | For the remaining arguments, refer to [2]. 502 | Example: 503 | >>> x = torch.rand(5, 3, 256, 256) 504 | >>> y = torch.rand(5, 3, 256, 256) 505 | >>> window = create_window(7, 3) 506 | >>> weights = torch.rand(5) 507 | >>> l = msssim_per_channel(x, y, window, weights) 508 | >>> l.size() 509 | torch.Size([5, 3]) 510 | """ 511 | 512 | css = [] 513 | 514 | m = weights.numel() 515 | for i in range(m): 516 | if i > 0: 517 | x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True) 518 | y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True) 519 | 520 | ss, cs = ssim_per_channel( 521 | x, y, window, 522 | value_range=value_range, 523 | non_negative=True, 524 | k1=k1, k2=k2, 525 | ) 526 | 527 | css.append(cs if i + 1 < m else ss) 528 | 529 | msss = torch.stack(css, dim=-1) 530 | msss = (msss ** weights).prod(dim=-1) 531 | 532 | return msss 533 | 534 | 535 | def msssim( 536 | x: torch.Tensor, 537 | y: torch.Tensor, 538 | window_size: int = 11, 539 | sigma: float = 1.5, 540 | weights: torch.Tensor = None, 541 | **kwargs, 542 | ) -> torch.Tensor: 543 | r"""Returns the 
MS-SSIM between `x` and `y`. 544 | Args: 545 | x: An input tensor, (N, C, H, W). 546 | y: A target tensor, (N, C, H, W). 547 | window_size: The size of the window. 548 | weights: The weights of the scales, (M,). 549 | If `None`, use the official weights instead. 550 | `**kwargs` are transmitted to `msssim_per_channel`. 551 | Example: 552 | >>> x = torch.rand(5, 3, 256, 256) 553 | >>> y = torch.rand(5, 3, 256, 256) 554 | >>> l = msssim(x, y) 555 | >>> l.size() 556 | torch.Size([5]) 557 | """ 558 | 559 | window = create_window(window_size, x.size(1), device=x.device) 560 | 561 | if weights is None: 562 | weights = _MS_WEIGHTS.to(x.device) 563 | 564 | return msssim_per_channel(x, y, window, weights, **kwargs).mean(-1) 565 | 566 | 567 | class SSIM(nn.Module): 568 | r"""Creates a criterion that measures the SSIM 569 | between an input and a target. 570 | Args: 571 | window_size: The size of the window. 572 | n_channels: The number of channels. 573 | reduction: Specifies the reduction to apply to the output: 574 | `'none'` | `'mean'` | `'sum'`. 575 | `**kwargs` are transmitted to `ssim_per_channel`. 
576 | Shape: 577 | * Input: (N, C, H, W) 578 | * Target: (N, C, H, W), same shape as the input 579 | * Output: (N,) or (1,) depending on `reduction` 580 | Example: 581 | >>> criterion = SSIM().cuda() 582 | >>> x = torch.rand(5, 3, 256, 256).cuda() 583 | >>> y = torch.rand(5, 3, 256, 256).cuda() 584 | >>> l = criterion(x, y) 585 | >>> l.size() 586 | torch.Size([]) 587 | """ 588 | 589 | def __init__( 590 | self, 591 | window_size: int = 11, 592 | n_channels: int = 3, 593 | reduction: str = 'mean', 594 | **kwargs, 595 | ): 596 | r"""""" 597 | super().__init__() 598 | 599 | window = create_window(window_size, n_channels) 600 | 601 | self.register_buffer('window0', window[0]) 602 | self.register_buffer('window1', window[1]) 603 | 604 | self.reduce = build_reduce(reduction) 605 | self.kwargs = kwargs 606 | 607 | @property 608 | def window(self) -> List[torch.Tensor]: 609 | return [self.window0, self.window1] 610 | 611 | def forward( 612 | self, 613 | input: torch.Tensor, 614 | target: torch.Tensor, 615 | ) -> torch.Tensor: 616 | r"""Defines the computation performed at every call. 617 | """ 618 | 619 | l = ssim_per_channel( 620 | input, 621 | target, 622 | window=self.window, 623 | **self.kwargs, 624 | )[0].mean(-1) 625 | 626 | return self.reduce(l) 627 | 628 | 629 | class MSSSIM(nn.Module): 630 | r"""Creates a criterion that measures the MS-SSIM 631 | between an input and a target. 632 | Args: 633 | window_size: The size of the window. 634 | n_channels: The number of channels. 635 | weights: The weights of the scales, (M,). 636 | If `None`, use the official weights instead. 637 | reduction: Specifies the reduction to apply to the output: 638 | `'none'` | `'mean'` | `'sum'`. 639 | `**kwargs` are transmitted to `msssim_per_channel`. 
640 | Shape: 641 | * Input: (N, C, H, W) 642 | * Target: (N, C, H, W), same shape as the input 643 | * Output: (N,) or (1,) depending on `reduction` 644 | Example: 645 | >>> criterion = MSSSIM().cuda() 646 | >>> x = torch.rand(5, 3, 256, 256).cuda() 647 | >>> y = torch.rand(5, 3, 256, 256).cuda() 648 | >>> l = criterion(x, y) 649 | >>> l.size() 650 | torch.Size([]) 651 | """ 652 | 653 | def __init__( 654 | self, 655 | window_size: int = 11, 656 | n_channels: int = 3, 657 | weights: torch.Tensor = None, 658 | reduction: str = 'mean', 659 | **kwargs, 660 | ): 661 | r"""""" 662 | super().__init__() 663 | 664 | window = create_window(window_size, n_channels) 665 | 666 | self.register_buffer('window0', window[0]) 667 | self.register_buffer('window1', window[1]) 668 | 669 | if weights is None: 670 | weights = _MS_WEIGHTS 671 | 672 | self.register_buffer('weights', weights) 673 | 674 | self.reduce = build_reduce(reduction) 675 | self.kwargs = kwargs 676 | 677 | @property 678 | def window(self) -> List[torch.Tensor]: 679 | return [self.window0, self.window1] 680 | 681 | def forward( 682 | self, 683 | input: torch.Tensor, 684 | target: torch.Tensor, 685 | ) -> torch.Tensor: 686 | r"""Defines the computation performed at every call. 
687 | """ 688 | 689 | l = msssim_per_channel( 690 | input, 691 | target, 692 | window=self.window, 693 | weights=self.weights, 694 | **self.kwargs, 695 | ).mean(-1) 696 | 697 | return self.reduce(l) -------------------------------------------------------------------------------- /_test_speed/test_ms_ssim_speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append('../') 4 | 5 | from no1_ms_ssim_lizhengwei1992_MS_SSIM_pytorch import MS_SSIM as MS_SSIM1 6 | from no3_ssim_VainF_pytorch_msssim import MS_SSIM as MS_SSIM3 7 | from ssim import MS_SSIM as MS_SSIM4 8 | from no5_ssim_francois_rozet_piqa import MSSSIM as MS_SSIM5 9 | 10 | 11 | def test_speed(losser): 12 | a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255. 13 | b = a * 0.5 14 | a.requires_grad = True 15 | b.requires_grad = True 16 | 17 | start_record = torch.cuda.Event(enable_timing=True) 18 | end_record = torch.cuda.Event(enable_timing=True) 19 | 20 | start_time = time.perf_counter() 21 | start_record.record() 22 | for _ in range(500): 23 | loss = losser(a, b).mean() 24 | loss.backward() 25 | end_record.record() 26 | end_time = time.perf_counter() 27 | 28 | torch.cuda.synchronize() 29 | 30 | print('cuda time', start_record.elapsed_time(end_record)) 31 | print('perf_counter time', end_time - start_time) 32 | 33 | 34 | if __name__ == '__main__': 35 | print('Performance Testing MS_SSIM') 36 | print() 37 | import time 38 | losser1 = MS_SSIM1(size_average=False, max_val=1.).cuda() 39 | losser3 = MS_SSIM3(win_size=11, win_sigma=1.5, data_range=1., size_average=False, channel=3).cuda() 40 | losser4 = MS_SSIM4(window_size=11, window_sigma=1.5, data_range=1., channel=3, use_padding=False).cuda() 41 | losser5 = MS_SSIM5(window_size=11, value_range=1., n_channels=3, reduction='none').cuda() 42 | 43 | print('testing losser1') 44 | test_speed(losser1) 45 | print() 46 | 47 | print('testing losser3') 48 | 
test_speed(losser3) 49 | print() 50 | 51 | print('testing losser4') 52 | test_speed(losser4) 53 | print() 54 | 55 | print('testing losser5') 56 | test_speed(losser5) 57 | print() 58 | -------------------------------------------------------------------------------- /_test_speed/test_ssim_speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append('../') 4 | 5 | from no2_ssim_Po_Hsun_Su_pytorch_ssim import SSIM as SSIM2 6 | from no3_ssim_VainF_pytorch_msssim import SSIM as SSIM3 7 | from ssim import SSIM as SSIM4 8 | from no5_ssim_francois_rozet_piqa import SSIM as SSIM5 9 | 10 | 11 | def test_speed(losser): 12 | a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255. 13 | b = a * 0.5 14 | a.requires_grad = True 15 | b.requires_grad = True 16 | 17 | start_record = torch.cuda.Event(enable_timing=True) 18 | end_record = torch.cuda.Event(enable_timing=True) 19 | 20 | start_time = time.perf_counter() 21 | start_record.record() 22 | for _ in range(500): 23 | loss = losser(a, b).mean() 24 | loss.backward() 25 | end_record.record() 26 | end_time = time.perf_counter() 27 | 28 | torch.cuda.synchronize() 29 | 30 | print('cuda time', start_record.elapsed_time(end_record)) 31 | print('perf_counter time', end_time - start_time) 32 | 33 | 34 | if __name__ == '__main__': 35 | print('Performance Testing SSIM') 36 | print() 37 | import time 38 | losser2 = SSIM2(window_size=11, size_average=False).cuda() 39 | losser3 = SSIM3(win_size=11, win_sigma=1.5, data_range=1., size_average=False, channel=3).cuda() 40 | losser4 = SSIM4(window_size=11, window_sigma=1.5, data_range=1., channel=3, use_padding=False).cuda() 41 | losser5 = SSIM5(window_size=11, value_range=1., n_channels=3, reduction='none').cuda() 42 | 43 | print('testing losser2') 44 | test_speed(losser2) 45 | print() 46 | 47 | print('testing losser3') 48 | test_speed(losser3) 49 | print() 50 | 51 | print('testing 
losser4') 52 | test_speed(losser4) 53 | print() 54 | 55 | print('testing losser5') 56 | test_speed(losser5) 57 | print() 58 | -------------------------------------------------------------------------------- /make_gif.cmd: -------------------------------------------------------------------------------- 1 | ffmpeg.exe -i ./ssim_test.mkv -r 2 ./ssim_test.gif 2 | ffmpeg.exe -i ./ms_ssim_test.mkv -r 2 ./ms_ssim_test.gif 3 | echo "convert complete" -------------------------------------------------------------------------------- /ms_ssim_test.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-sixth/ms_ssim_pytorch/6269c62e0dd29c91fa38e4ba73d906d0c84ca966/ms_ssim_test.gif -------------------------------------------------------------------------------- /ms_ssim_test.mkv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-sixth/ms_ssim_pytorch/6269c62e0dd29c91fa38e4ba73d906d0c84ca966/ms_ssim_test.mkv -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code modified from 3 | https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py 4 | ''' 5 | 6 | import torch 7 | import torch.jit 8 | import torch.nn.functional as F 9 | 10 | 11 | @torch.jit.script 12 | def create_window(window_size: int, sigma: float, channel: int): 13 | ''' 14 | Create 1-D gauss kernel 15 | :param window_size: the size of gauss kernel 16 | :param sigma: sigma of normal distribution 17 | :param channel: input channel 18 | :return: 1D kernel 19 | ''' 20 | coords = torch.arange(window_size, dtype=torch.float) 21 | coords -= window_size // 2 22 | 23 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 24 | g /= g.sum() 25 | 26 | g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1) 27 | return g 28 | 29 | 30 | @torch.jit.script 31 
| def _gaussian_filter(x, window_1d, use_padding: bool): 32 | ''' 33 | Blur input with 1-D kernel 34 | :param x: batch of tensors to be blurred 35 | :param window_1d: 1-D gauss kernel 36 | :param use_padding: padding image before conv 37 | :return: blurred tensors 38 | ''' 39 | C = x.shape[1] 40 | padding = 0 41 | if use_padding: 42 | window_size = window_1d.shape[3] 43 | padding = window_size // 2 44 | out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C) 45 | out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C) 46 | return out 47 | 48 | 49 | @torch.jit.script 50 | def ssim(X, Y, window, data_range: float, use_padding: bool=False): 51 | ''' 52 | Calculate ssim index for X and Y 53 | :param X: images 54 | :param Y: images 55 | :param window: 1-D gauss kernel 56 | :param data_range: value range of input images. (usually 1.0 or 255) 57 | :param use_padding: padding image before conv 58 | :return: (ssim_val, cs), each reduced to one value per image, (N,) 59 | ''' 60 | 61 | K1 = 0.01 62 | K2 = 0.03 63 | compensation = 1.0 64 | 65 | C1 = (K1 * data_range) ** 2 66 | C2 = (K2 * data_range) ** 2 67 | 68 | mu1 = _gaussian_filter(X, window, use_padding) 69 | mu2 = _gaussian_filter(Y, window, use_padding) 70 | sigma1_sq = _gaussian_filter(X * X, window, use_padding) 71 | sigma2_sq = _gaussian_filter(Y * Y, window, use_padding) 72 | sigma12 = _gaussian_filter(X * Y, window, use_padding) 73 | 74 | mu1_sq = mu1.pow(2) 75 | mu2_sq = mu2.pow(2) 76 | mu1_mu2 = mu1 * mu2 77 | 78 | sigma1_sq = compensation * (sigma1_sq - mu1_sq) 79 | sigma2_sq = compensation * (sigma2_sq - mu2_sq) 80 | sigma12 = compensation * (sigma12 - mu1_mu2) 81 | 82 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) 83 | # Clamp negative cs_map values to zero; otherwise they cause ms_ssim to output NaN. 
84 | cs_map = F.relu(cs_map) 85 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 86 | 87 | ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW 88 | cs = cs_map.mean(dim=(1, 2, 3)) 89 | 90 | return ssim_val, cs 91 | 92 | 93 | @torch.jit.script 94 | def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool=False, eps: float=1e-8): 95 | ''' 96 | interface of ms-ssim 97 | :param X: a batch of images, (N,C,H,W) 98 | :param Y: a batch of images, (N,C,H,W) 99 | :param window: 1-D gauss kernel 100 | :param data_range: value range of input images. (usually 1.0 or 255) 101 | :param weights: weights for different levels 102 | :param use_padding: padding image before conv 103 | :param eps: lower bound applied to the per-level values, used to avoid NaN gradients. 104 | :return: ms-ssim value per image, (N,) 105 | ''' 106 | weights = weights[:, None] 107 | 108 | levels = weights.shape[0] 109 | vals = [] 110 | for i in range(levels): 111 | ss, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding) 112 | 113 | if i < levels-1: 114 | vals.append(cs) 115 | X = F.avg_pool2d(X, kernel_size=2, stride=2, ceil_mode=True) 116 | Y = F.avg_pool2d(Y, kernel_size=2, stride=2, ceil_mode=True) 117 | else: 118 | vals.append(ss) 119 | 120 | vals = torch.stack(vals, dim=0) 121 | # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf. 122 | vals = vals.clamp_min(eps) 123 | # The original ms-ssim op. 124 | ms_ssim_val = torch.prod(vals[:-1] ** weights[:-1] * vals[-1:] ** weights[-1:], dim=0) 125 | # The new ms-ssim op. It is unclear which of the two is better. 126 | # ms_ssim_val = torch.prod(vals ** weights, dim=0) 127 | # In this file's image training demo, the old ms-ssim op seemed to work better, so it is kept. 
128 | return ms_ssim_val 129 | 130 | 131 | class SSIM(torch.jit.ScriptModule): 132 | __constants__ = ['data_range', 'use_padding'] 133 | 134 | def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False): 135 | ''' 136 | :param window_size: the size of gauss kernel 137 | :param window_sigma: sigma of normal distribution 138 | :param data_range: value range of input images. (usually 1.0 or 255) 139 | :param channel: input channels (default: 3) 140 | :param use_padding: padding image before conv 141 | ''' 142 | super().__init__() 143 | assert window_size % 2 == 1, 'Window size must be odd.' 144 | window = create_window(window_size, window_sigma, channel) 145 | self.register_buffer('window', window) 146 | self.data_range = data_range 147 | self.use_padding = use_padding 148 | 149 | @torch.jit.script_method 150 | def forward(self, X, Y): 151 | r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding) 152 | return r[0] 153 | 154 | 155 | class MS_SSIM(torch.jit.ScriptModule): 156 | __constants__ = ['data_range', 'use_padding', 'eps'] 157 | 158 | def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None, levels=None, eps=1e-8): 159 | ''' 160 | class for ms-ssim 161 | :param window_size: the size of gauss kernel 162 | :param window_sigma: sigma of normal distribution 163 | :param data_range: value range of input images. (usually 1.0 or 255) 164 | :param channel: input channels 165 | :param use_padding: padding image before conv 166 | :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 167 | :param levels: number of levels to use (the weights are truncated to the first `levels` entries and renormalized) 168 | :param eps: works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf. 169 | ''' 170 | super().__init__() 171 | assert window_size % 2 == 1, 'Window size must be odd.' 
172 | self.data_range = data_range 173 | self.use_padding = use_padding 174 | self.eps = eps 175 | 176 | window = create_window(window_size, window_sigma, channel) 177 | self.register_buffer('window', window) 178 | 179 | if weights is None: 180 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] 181 | weights = torch.tensor(weights, dtype=torch.float) 182 | 183 | if levels is not None: 184 | weights = weights[:levels] 185 | weights = weights / weights.sum() 186 | 187 | self.register_buffer('weights', weights) 188 | 189 | @torch.jit.script_method 190 | def forward(self, X, Y): 191 | return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights, 192 | use_padding=self.use_padding, eps=self.eps) 193 | 194 | 195 | if __name__ == '__main__': 196 | print('Simple Test') 197 | im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda') 198 | img1 = im / 255 199 | img2 = img1 * 0.5 200 | 201 | losser = SSIM(data_range=1.).cuda() 202 | loss = losser(img1, img2).mean() 203 | 204 | losser2 = MS_SSIM(data_range=1.).cuda() 205 | loss2 = losser2(img1, img2).mean() 206 | 207 | print(loss.item()) 208 | print(loss2.item()) 209 | 210 | 211 | if __name__ == '__main__': 212 | print('Training Test') 213 | import cv2 214 | import torch.optim 215 | import numpy as np 216 | import imageio 217 | import time 218 | 219 | out_test_video = False 220 | # It is best not to output a GIF directly, as it will be very large; output an mkv file first, then convert it to GIF with ffmpeg 221 | video_use_gif = False 222 | 223 | im = cv2.imread('test_img1.jpg', 1) 224 | t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255. 
225 | 226 | if out_test_video: 227 | if video_use_gif: 228 | fps = 0.5 229 | out_wh = (im.shape[1]//2, im.shape[0]//2) 230 | suffix = '.gif' 231 | else: 232 | fps = 5 233 | out_wh = (im.shape[1], im.shape[0]) 234 | suffix = '.mkv' 235 | video_last_time = time.perf_counter() 236 | video = imageio.get_writer('ssim_test'+suffix, fps=fps) 237 | 238 | # Test ssim 239 | print('Training SSIM') 240 | rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. 241 | rand_im.requires_grad = True 242 | optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) 243 | losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda() 244 | ssim_score = 0 245 | while ssim_score < 0.999: 246 | optim.zero_grad() 247 | loss = losser(rand_im, t_im) 248 | (-loss).sum().backward() 249 | ssim_score = loss.item() 250 | optim.step() 251 | r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] 252 | r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) 253 | 254 | if out_test_video: 255 | if time.perf_counter() - video_last_time > 1. 
/ fps: 256 | video_last_time = time.perf_counter() 257 | out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) 258 | out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) 259 | if isinstance(out_frame, cv2.UMat): 260 | out_frame = out_frame.get() 261 | video.append_data(out_frame) 262 | 263 | cv2.imshow('ssim', r_im) 264 | cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score) 265 | cv2.waitKey(1) 266 | 267 | if out_test_video: 268 | video.close() 269 | 270 | # Test ms_ssim 271 | if out_test_video: 272 | if video_use_gif: 273 | fps = 0.5 274 | out_wh = (im.shape[1]//2, im.shape[0]//2) 275 | suffix = '.gif' 276 | else: 277 | fps = 5 278 | out_wh = (im.shape[1], im.shape[0]) 279 | suffix = '.mkv' 280 | video_last_time = time.perf_counter() 281 | video = imageio.get_writer('ms_ssim_test'+suffix, fps=fps) 282 | 283 | print('Training MS_SSIM') 284 | rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. 285 | rand_im.requires_grad = True 286 | optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) 287 | losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda() 288 | ssim_score = 0 289 | while ssim_score < 0.999: 290 | optim.zero_grad() 291 | loss = losser(rand_im, t_im) 292 | (-loss).sum().backward() 293 | ssim_score = loss.item() 294 | optim.step() 295 | r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] 296 | r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) 297 | 298 | if out_test_video: 299 | if time.perf_counter() - video_last_time > 1. 
/ fps: 300 | video_last_time = time.perf_counter() 301 | out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) 302 | out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) 303 | if isinstance(out_frame, cv2.UMat): 304 | out_frame = out_frame.get() 305 | video.append_data(out_frame) 306 | 307 | cv2.imshow('ms_ssim', r_im) 308 | cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score) 309 | cv2.waitKey(1) 310 | 311 | if out_test_video: 312 | video.close() 313 | 314 | 315 | if __name__ == '__main__': 316 | print('Performance Testing SSIM') 317 | import time 318 | s = SSIM(data_range=1.).cuda() 319 | 320 | a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255. 321 | b = a * 0.5 322 | a.requires_grad = True 323 | b.requires_grad = True 324 | 325 | start_record = torch.cuda.Event(enable_timing=True) 326 | end_record = torch.cuda.Event(enable_timing=True) 327 | 328 | start_time = time.perf_counter() 329 | start_record.record() 330 | for _ in range(500): 331 | loss = s(a, b).mean() 332 | loss.backward() 333 | end_record.record() 334 | end_time = time.perf_counter() 335 | 336 | torch.cuda.synchronize() 337 | 338 | print('cuda time', start_record.elapsed_time(end_record)) 339 | print('perf_counter time', end_time-start_time) 340 | 341 | 342 | if __name__ == '__main__': 343 | print('Performance Testing MS_SSIM') 344 | import time 345 | s = MS_SSIM(data_range=1.).cuda() 346 | 347 | a = torch.randint(0, 255, size=(20, 3, 256, 256), dtype=torch.float32).cuda() / 255. 
348 | b = a * 0.5 349 | a.requires_grad = True 350 | b.requires_grad = True 351 | 352 | start_record = torch.cuda.Event(enable_timing=True) 353 | end_record = torch.cuda.Event(enable_timing=True) 354 | 355 | start_time = time.perf_counter() 356 | start_record.record() 357 | for _ in range(500): 358 | loss = s(a, b).mean() 359 | loss.backward() 360 | end_record.record() 361 | end_time = time.perf_counter() 362 | 363 | torch.cuda.synchronize() 364 | 365 | print('cuda time', start_record.elapsed_time(end_record)) 366 | print('perf_counter time', end_time-start_time) 367 | -------------------------------------------------------------------------------- /ssim_test.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-sixth/ms_ssim_pytorch/6269c62e0dd29c91fa38e4ba73d906d0c84ca966/ssim_test.gif -------------------------------------------------------------------------------- /ssim_test.mkv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-sixth/ms_ssim_pytorch/6269c62e0dd29c91fa38e4ba73d906d0c84ca966/ssim_test.mkv -------------------------------------------------------------------------------- /test_img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-sixth/ms_ssim_pytorch/6269c62e0dd29c91fa38e4ba73d906d0c84ca966/test_img1.jpg --------------------------------------------------------------------------------
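The two building blocks shared by the implementations above, the normalized 1-D Gaussian window and the weighted combination of per-scale values, can be sketched without PyTorch. This is only an illustration under stated assumptions: `gaussian_window` and `combine_scales` are hypothetical helper names that do not exist in any of the listed files, the window size is assumed odd, and the combination shown is the plain `prod(vals ** weights)` form (the one piqa uses and that `ssim.py` leaves commented out), not the op `ssim.py` keeps active.

```python
import math


def gaussian_window(size: int, sigma: float = 1.5) -> list:
    """Normalized 1-D Gaussian centered on the middle tap (hypothetical helper)."""
    center = size // 2  # equals (size - 1) / 2 when size is odd
    taps = [math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2)) for i in range(size)]
    total = sum(taps)
    return [t / total for t in taps]


def combine_scales(values, weights, eps: float = 1e-8) -> float:
    """Product of v_i ** w_i with each v_i clamped to at least eps (hypothetical helper)."""
    result = 1.0
    for v, w in zip(values, weights):
        # The clamp mirrors vals.clamp_min(eps) in ssim.py, which keeps the
        # gradient of a ** b finite when a would otherwise be 0.
        result *= max(v, eps) ** w
    return result


# Default scale weights used by both implementations.
weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]

window = gaussian_window(5, sigma=1.5)
print([round(t, 4) for t in window])       # [0.1201, 0.2339, 0.2921, 0.2339, 0.1201]
print(combine_scales([1.0] * 5, weights))  # 1.0 for identical images
```

The printed window agrees with the `gaussian_kernel(5, sigma=1.5)` docstring example, and for odd sizes the two centering conventions in the listed files (`(kernel_size - 1) / 2` and `window_size // 2`) give the same taps.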