├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── ssim.py └── test ├── old_version.py ├── po_hsun_su_ssim.py ├── test_ssim.py └── vainf_ssim.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Pang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mssim.pytorch 2 | 3 | $$ 4 | \begin{align} 5 | l(\mathbf{x}, \mathbf{y}) & = \frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2, K_1=0.01, \\ 6 | c(\mathbf{x}, \mathbf{y}) & = \frac{2\sigma_{x}\sigma_{y}+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2, K_2=0.03, \\ 7 | s(\mathbf{x}, \mathbf{y}) & = \frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2, \\ 8 | \text{SSIM}(\mathbf{x}, \mathbf{y}) & = [l(\mathbf{x}, \mathbf{y})]^\alpha \cdot [c(\mathbf{x}, \mathbf{y})]^\beta \cdot [s(\mathbf{x}, \mathbf{y})]^\gamma \\ 9 | & = \frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}, \\ 10 | & \alpha=\beta=\gamma=1, \\ 11 | \text{MS-SSIM}(\mathbf{x}, \mathbf{y}) & = [l(\mathbf{x}, \mathbf{y})]^{\alpha_{M}} \cdot \prod^{M}_{j=1} [c_j(\mathbf{x}, \mathbf{y})]^{\beta_j} \cdot [s_j(\mathbf{x}, \mathbf{y})]^{\gamma_j}, (M=5) \\ 12 | & \beta_1=\gamma_1=0.0448, \\ 13 | & \beta_2=\gamma_2=0.2856, \\ 14 | & \beta_3=\gamma_3=0.3001, \\ 15 | & \beta_4=\gamma_4=0.2363, \\ 16 | & \alpha_5=\beta_5=\gamma_5=0.1333. 17 | \end{align} 18 | $$ 19 | 20 | A better pytorch-based implementation for the mean structural similarity (MSSIM). 21 | 22 | Compared to this widely used implementation: , I further optimized and refactored the code. 23 | 24 | At the same time, in this implementation, I have dealt with the problem that the calculation with the fp16 mode cannot be consistent with the calculation with the fp32 mode. Typecasting is used here to ensure that the computation is done in fp32 mode. 
This might also avoid unexpected results when using it as a loss. 25 | 26 | > [!note] 27 | > 2024-12-04: SSIM for 1D, 2D and 3D data, and MS-SSIM calculation for 2D and 3D data are now supported simultaneously. 28 | 29 | | Setting | SSIM1d | SSIM2d | SSIM3d | MS-SSIM2d | MS-SSIM3d (**only pooling in the spatial domain**) | 30 | | --------------- | -------------- | --------------------- | ---------------------------- | --------------------- | -------------------------------------------------- | 31 | | data_dim | 1 | 2 (Default) | 3 | 2 | 3 | 32 | | return_msssim | `False` | `False` | `False` | `True` | `True` | 33 | | window_size | int, [int] | int, [int, int] | int, [int, int, int] | int, [int, int] | int, [int, int, int] | 34 | | padding | int, [int] | int, [int, int] | int, [int, int, int] | int, [int, int] | int, [int, int, int] | 35 | | sigma | float, [float] | float, [float, float] | float, [float, float, float] | float, [float, float] | float, [float, float, float] | 36 | | in_channels | int | int | int | int | int | 37 | | L | 1, 255 | 1, 255 | 1, 255 | 1, 255 | 1, 255 | 38 | | keep_batch_dim | ✅ | ✅ | ✅ | ✅ | ✅ | 39 | | return_log | ✅ | ✅ | ✅ | ❌ | ❌ | 40 | | ensemble_kernel | ✅ | ✅ | ✅ | ✅ | ✅ | 41 | 42 | ## Structural similarity index 43 | 44 | > When comparing images, the mean squared error (MSE)–while simple to implement–is not highly indicative of perceived similarity. Structural similarity aims to address this shortcoming by taking texture into account. 
More details can be seen at 45 | 46 | ![results](https://user-images.githubusercontent.com/26847524/175031400-92426661-4536-43c7-8f6e-5c470fb9ccb5.png) 47 | 48 | ```python 49 | import matplotlib.pyplot as plt 50 | import numpy as np 51 | import torch 52 | import torch.nn.functional as F 53 | from lartpang_ssim import SSIM 54 | from po_hsun_su_ssim import SSIM as PoHsunSuSSIM 55 | from vainf_ssim import MS_SSIM as VainFMSSSIM 56 | from vainf_ssim import SSIM as VainFSSIM 57 | from skimage import data, img_as_float 58 | 59 | img = img_as_float(data.camera()) 60 | rows, cols = img.shape 61 | 62 | noise = np.ones_like(img) * 0.3 * (img.max() - img.min()) 63 | rng = np.random.default_rng() 64 | noise[rng.random(size=noise.shape) > 0.5] *= -1 65 | 66 | img_noise = img + noise 67 | img_const = np.zeros_like(img) 68 | 69 | img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() 70 | img_noise_tensor = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).float() 71 | img_const_tensor = torch.from_numpy(img_const).unsqueeze(0).unsqueeze(0).float() 72 | 73 | fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 7)) 74 | ax = axes.ravel() 75 | 76 | mse_none = F.mse_loss(img_tensor, img_tensor, reduction="mean") 77 | mse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction="mean") 78 | mse_const = F.mse_loss(img_tensor, img_const_tensor, reduction="mean") 79 | 80 | # https://github.com/VainF/pytorch-msssim 81 | vainf_ssim_none = VainFSSIM(channel=1, data_range=1)(img_tensor, img_tensor) 82 | vainf_ssim_noise = VainFSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor) 83 | vainf_ssim_const = VainFSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor) 84 | vainf_ms_ssim_none = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_tensor) 85 | vainf_ms_ssim_noise = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor) 86 | vainf_ms_ssim_const = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor) 87 | 88 | # use the 
settings of https://github.com/VainF/pytorch-msssim 89 | ssim_none_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor) 90 | ssim_noise_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor) 91 | ssim_const_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor) 92 | ms_ssim_none_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor) 93 | ms_ssim_noise_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor) 94 | ms_ssim_const_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor) 95 | 96 | # https://github.com/Po-Hsun-Su/pytorch-ssim 97 | pohsunsu_ssim_none = PoHsunSuSSIM()(img_tensor, img_tensor) 98 | pohsunsu_ssim_noise = PoHsunSuSSIM()(img_tensor, img_noise_tensor) 99 | pohsunsu_ssim_const = PoHsunSuSSIM()(img_tensor, img_const_tensor) 100 | 101 | # use the settings of https://github.com/Po-Hsun-Su/pytorch-ssim 102 | ssim_none_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_tensor) 103 | ssim_noise_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_noise_tensor) 104 | ssim_const_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_const_tensor) 105 | 106 | 107 | ax[0].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1) 108 | ax[0].set_xlabel( 109 | f"MSE: {mse_none:.6f}\n" 110 | f"SSIM {ssim_none_0:.6f}, MS-SSIM {ms_ssim_none_0:.6f}\n" 111 | f"(VainF) SSIM: {vainf_ssim_none:.6f}, MS-SSIM {vainf_ms_ssim_none:.6f}\n" 112 | f"SSIM {ssim_none_1:.6f}\n" 113 | f"(PoHsunSu) SSIM: {pohsunsu_ssim_none:.6f}\n" 114 | ) 115 | ax[0].set_title("Original image") 116 | 117 | ax[1].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1) 118 | ax[1].set_xlabel( 119 | f"MSE: {mse_noise:.6f}\n" 120 | f"SSIM {ssim_noise_0:.6f}, MS-SSIM 
{ms_ssim_noise_0:.6f}\n" 121 | f"(VainF) SSIM: {vainf_ssim_noise:.6f}, MS-SSIM {vainf_ms_ssim_noise:.6f}\n" 122 | f"SSIM {ssim_noise_1:.6f}\n" 123 | f"(PoHsunSu) SSIM: {pohsunsu_ssim_noise:.6f}\n" 124 | ) 125 | ax[1].set_title("Image with noise") 126 | 127 | ax[2].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1) 128 | ax[2].set_xlabel( 129 | f"MSE: {mse_const:.6f}\n" 130 | f"SSIM {ssim_const_0:.6f}, MS-SSIM {ms_ssim_const_0:.6f}\n" 131 | f"(VainF) SSIM: {vainf_ssim_const:.6f}, MS-SSIM {vainf_ms_ssim_const:.6f}\n" 132 | f"SSIM {ssim_const_1:.6f}\n" 133 | f"(PoHsunSu) SSIM: {pohsunsu_ssim_const:.6f}\n" 134 | ) 135 | ax[2].set_title("Image plus constant") 136 | 137 | 138 | [ax[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) for i in range(len(axes))] 139 | 140 | plt.tight_layout() 141 | plt.savefig("results.png") 142 | ``` 143 | 144 | ## More Examples 145 | 146 | ```python 147 | # setting 4: for 4d float tensors with the data range [0, 1] and 1 channel,return the logarithmic form, and keep the batch dim 148 | ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda() 149 | 150 | # two 4d tensors 151 | x = torch.randn(3, 1, 100, 100).cuda() 152 | y = torch.randn(3, 1, 100, 100).cuda() 153 | ssim_score_0 = ssim_caller(x, y) 154 | # or in the fp16 mode (we have fixed the computation progress into the float32 mode to avoid the unexpected result) 155 | with torch.cuda.amp.autocast(enabled=True): 156 | ssim_score_1 = ssim_caller(x, y) 157 | assert torch.allclose(ssim_score_0, ssim_score_1) 158 | print(ssim_score_0.shape, ssim_score_1.shape) 159 | ``` 160 | 161 | ## As A Loss 162 | 163 | As you can see from the respective thresholds of the two cases below, it is easier to optimize towards MSSIM=1 than MSSIM=-1. 
164 | 165 | ### Optimize towards MSSIM=1 166 | 167 | ![prediction](https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png) 168 | 169 | ```python 170 | import matplotlib.pyplot as plt 171 | import torch 172 | from pytorch_ssim import SSIM 173 | from skimage import data 174 | from torch import optim 175 | 176 | original_image = data.moon() / 255 177 | target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda() 178 | predicted_image = torch.zeros_like( 179 | target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True 180 | ) 181 | initial_image = predicted_image.clone() 182 | 183 | ssim = SSIM().cuda() 184 | initial_ssim_value = ssim(predicted_image, target_image) 185 | 186 | ssim_value = initial_ssim_value 187 | optimizer = optim.Adam([predicted_image], lr=0.01) 188 | loss_curves = [] 189 | while ssim_value < 0.999: 190 | ssim_out = 1 - ssim(predicted_image, target_image) 191 | loss_curves.append(ssim_out.item()) 192 | ssim_value = 1 - ssim_out.item() 193 | print(ssim_value) 194 | ssim_out.backward() 195 | optimizer.step() 196 | optimizer.zero_grad() 197 | 198 | fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4)) 199 | ax = axes.ravel() 200 | 201 | ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1) 202 | ax[0].set_title("Original Image") 203 | 204 | ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1) 205 | ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}") 206 | ax[1].set_title("Initial Image") 207 | 208 | ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1) 209 | ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}") 210 | ax[2].set_title("Predicted Image") 211 | 212 | ax[3].plot(loss_curves) 213 | ax[3].set_title("SSIM Loss Curve") 214 | 215 | ax[4].set_title("Original Image") 216 | ax[4].hist(original_image.ravel(), bins=256) 217 | ax[4].ticklabel_format(axis="y", 
style="scientific", scilimits=(0, 0)) 218 | ax[4].set_xlabel("Pixel Intensity") 219 | 220 | ax[5].set_title("Initial Image") 221 | ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256) 222 | ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0)) 223 | ax[5].set_xlabel("Pixel Intensity") 224 | 225 | ax[6].set_title("Predicted Image") 226 | ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256) 227 | ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0)) 228 | ax[6].set_xlabel("Pixel Intensity") 229 | 230 | plt.tight_layout() 231 | plt.savefig("prediction.png") 232 | ``` 233 | 234 | ### Optimize towards MSSIM=-1 235 | 236 | ![prediction](https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png) 237 | 238 | ```python 239 | import matplotlib.pyplot as plt 240 | import torch 241 | from pytorch_ssim import SSIM 242 | from skimage import data 243 | from torch import optim 244 | 245 | original_image = data.moon() / 255 246 | target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda() 247 | predicted_image = torch.zeros_like( 248 | target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True 249 | ) 250 | initial_image = predicted_image.clone() 251 | 252 | ssim = SSIM(L=original_image.max() - original_image.min()).cuda() 253 | initial_ssim_value = ssim(predicted_image, target_image) 254 | 255 | ssim_value = initial_ssim_value 256 | optimizer = optim.Adam([predicted_image], lr=0.01) 257 | loss_curves = [] 258 | while ssim_value > -0.94: 259 | ssim_out = ssim(predicted_image, target_image) 260 | loss_curves.append(ssim_out.item()) 261 | ssim_value = ssim_out.item() 262 | print(ssim_value) 263 | ssim_out.backward() 264 | optimizer.step() 265 | optimizer.zero_grad() 266 | 267 | fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4)) 268 | ax = axes.ravel() 269 | 270 | ax[0].imshow(original_image, 
cmap=plt.cm.gray, vmin=0, vmax=1) 271 | ax[0].set_title("Original Image") 272 | 273 | ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1) 274 | ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.5f}") 275 | ax[1].set_title("Initial Image") 276 | 277 | ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1) 278 | ax[2].set_xlabel(f"SSIM: {ssim_value:.5f}") 279 | ax[2].set_title("Predicted Image") 280 | 281 | ax[3].plot(loss_curves) 282 | ax[3].set_title("SSIM Loss Curve") 283 | 284 | ax[4].set_title("Original Image") 285 | ax[4].hist(original_image.ravel(), bins=256) 286 | ax[4].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0)) 287 | ax[4].set_xlabel("Pixel Intensity") 288 | 289 | ax[5].set_title("Initial Image") 290 | ax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256) 291 | ax[5].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0)) 292 | ax[5].set_xlabel("Pixel Intensity") 293 | 294 | ax[6].set_title("Predicted Image") 295 | ax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256) 296 | ax[6].ticklabel_format(axis="y", style="scientific", scilimits=(0, 0)) 297 | ax[6].set_xlabel("Pixel Intensity") 298 | 299 | plt.tight_layout() 300 | plt.savefig("prediction.png") 301 | ``` 302 | 303 | ## Reference 304 | 305 | * 306 | * 307 | * 308 | * Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: From error visibility to structural similarity,” IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004. 309 | 310 | ## Cite 311 | 312 | If you find this library useful, please cite our bibtex: 313 | 314 | ```bibtex 315 | @online{mssim.pytorch, 316 | author="lartpang", 317 | title="{A better pytorch-based implementation for the mean structural similarity. 
Differentiable simpler SSIM and MS-SSIM.}", 318 | url="https://github.com/lartpang/mssim.pytorch", 319 | note="(Jun 21, 2022)", 320 | } 321 | ``` 322 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Exclude a variety of commonly ignored directories. 3 | exclude = [ 4 | ".bzr", 5 | ".direnv", 6 | ".eggs", 7 | ".git", 8 | ".git-rewrite", 9 | ".hg", 10 | ".ipynb_checkpoints", 11 | ".mypy_cache", 12 | ".nox", 13 | ".pants.d", 14 | ".pyenv", 15 | ".pytest_cache", 16 | ".pytype", 17 | ".ruff_cache", 18 | ".svn", 19 | ".tox", 20 | ".venv", 21 | ".vscode", 22 | "__pypackages__", 23 | "_build", 24 | "buck-out", 25 | "build", 26 | "dist", 27 | "node_modules", 28 | "site-packages", 29 | "venv", 30 | ] 31 | 32 | # Same as Black. 33 | line-length = 118 34 | indent-width = 4 35 | 36 | [tool.ruff.lint] 37 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 38 | select = ["E4", "E7", "E9", "F"] 39 | ignore = [] 40 | 41 | # Allow fix for all enabled rules (when `--fix`) is provided. 42 | fixable = ["ALL"] 43 | unfixable = [] 44 | 45 | # Allow unused variables when underscore-prefixed. 46 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 47 | 48 | [tool.ruff.format] 49 | # Like Black, use double quotes for strings. 50 | quote-style = "double" 51 | 52 | # Like Black, indent with spaces, rather than tabs. 53 | indent-style = "space" 54 | 55 | # Like Black, respect magic trailing commas. 56 | skip-magic-trailing-comma = false 57 | 58 | # Like Black, automatically detect the appropriate line ending. 
59 | line-ending = "auto" 60 | -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | FILTER = { 9 | 1: F.conv1d, 10 | 2: F.conv2d, 11 | 3: F.conv3d, 12 | } 13 | 14 | 15 | class GaussianFilter(nn.Module): 16 | def __init__(self, data_dim, window_size, in_channels, sigma, padding=None, ensemble_kernel=True): 17 | """Gaussian Filer for 1D, 2D or 3D data (3D/4D/5D tensor) 18 | 19 | Args: 20 | data_dim (int, optional): The dimension of the data. 21 | window_size (int or Tuple[int], optional): The window size of the gaussian filter. 22 | in_channels (int, optional): The number of channels of the 4d tensor. 23 | sigma (float or Tuple[float], optional): The sigma of the gaussian filter. 24 | padding (int or Tuple[int], optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0. 25 | ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True. 
26 | """ 27 | super().__init__() 28 | if data_dim not in [1, 2, 3]: 29 | raise ValueError(f"data_dim must be 1, 2 or 3, but got {data_dim}.") 30 | self.data_dim = data_dim 31 | self.filter = FILTER[self.data_dim] 32 | 33 | if isinstance(window_size, int): 34 | window_size = [window_size] * self.data_dim 35 | if not all([w % 2 == 1 for w in window_size]): 36 | raise ValueError(f"Window size must be odd, but got {window_size}.") 37 | self.window_size = window_size 38 | 39 | if padding is None: 40 | padding = [w // 2 for w in window_size] 41 | if isinstance(padding, int): 42 | padding = [padding] * self.data_dim 43 | self.padding = padding 44 | 45 | if isinstance(sigma, (float, int)): 46 | sigma = [sigma] * self.data_dim 47 | self.sigma2 = [s**2 for s in sigma] 48 | 49 | assert len(self.window_size) == len(self.padding) == len(self.sigma2) == self.data_dim 50 | kernels = [self._get_gaussian_window1d(w, s2) for w, s2 in zip(self.window_size, self.sigma2)] 51 | 52 | self.ensemble_kernel = ensemble_kernel 53 | if self.ensemble_kernel: 54 | kernels = self._get_gaussian_windowNd(kernels) 55 | kernels = kernels.reshape(1, 1, *self.window_size).repeat_interleave(repeats=in_channels, dim=0) 56 | self.register_buffer(name="gaussian_window", tensor=kernels) 57 | else: 58 | for dim_idx, kernel in enumerate(kernels, start=2): 59 | base_shape = [1, 1] + [1] * self.data_dim 60 | base_shape[dim_idx] = -1 61 | kernel = kernel.reshape(*base_shape).repeat_interleave(repeats=in_channels, dim=0) 62 | if dim_idx == 2: 63 | name = "gaussian_window" 64 | else: 65 | name = f"gaussian_window_{dim_idx}" 66 | self.register_buffer(name=name, tensor=kernel) 67 | 68 | @staticmethod 69 | def _get_gaussian_window1d(window_size, sigma2): 70 | x = torch.arange(-(window_size // 2), window_size // 2 + 1) 71 | w = torch.exp(-0.5 * x**2 / sigma2) 72 | w = w / w.sum() 73 | return w 74 | 75 | def _get_gaussian_windowNd(self, gaussian_windows_1d): 76 | for dim_idx, kernel in enumerate(gaussian_windows_1d, 
start=2): 77 | base_shape = [1, 1] + [1] * self.data_dim 78 | base_shape[dim_idx] = -1 79 | kernel = kernel.reshape(*base_shape) 80 | if dim_idx == 2: 81 | w = kernel 82 | else: 83 | w = w * kernel 84 | return w 85 | 86 | def __repr__(self): 87 | base_str = f"{self.__class__.__name__} with Kernel: {self.gaussian_window.shape}" 88 | if not self.ensemble_kernel: 89 | for dim_idx in range(3, self.data_dim + 2): 90 | kernel = self.get_buffer(f"gaussian_window_{dim_idx}") 91 | base_str += f", {kernel.shape}" 92 | return base_str 93 | 94 | def forward(self, x): 95 | if self.ensemble_kernel: 96 | # ensemble kernel: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/3add4532d3f633316cba235da1c69e90f0dfb952/pytorch_ssim/__init__.py#L11-L15 97 | x = self.filter(input=x, weight=self.gaussian_window, stride=1, padding=self.padding, groups=x.shape[1]) 98 | else: 99 | # splitted kernel: https://github.com/VainF/pytorch-msssim/blob/2398f4db0abf44bcd3301cfadc1bf6c94788d416/pytorch_msssim/ssim.py#L48 100 | for i, d in enumerate(x.shape[2:], start=2): 101 | if d >= self.window_size[i - 2]: 102 | w = self.get_buffer(target="gaussian_window" if i == 2 else f"gaussian_window_{i}") 103 | x = self.filter(input=x, weight=w, stride=1, padding=self.padding, groups=x.shape[1]) 104 | else: 105 | warnings.warn( 106 | f"Skipping Gaussian Smoothing at dimension {i} for x: {x.shape} and window size: {self.window_size}" 107 | ) 108 | return x 109 | 110 | 111 | class SSIM(nn.Module): 112 | def __init__( 113 | self, 114 | window_size=11, 115 | in_channels=1, 116 | sigma=1.5, 117 | *, 118 | K1=0.01, 119 | K2=0.03, 120 | L=1, 121 | keep_batch_dim=False, 122 | data_dim=2, 123 | return_log=False, 124 | return_msssim=False, 125 | padding=None, 126 | ensemble_kernel=True, 127 | ): 128 | """Calculate the mean SSIM (MSSIM) between two 4D tensors. 129 | 130 | Args: 131 | window_size (int or Tuple[int], optional): The window size of the gaussian filter. Defaults to 11. 
132 | in_channels (int, optional): The number of channels of the 4d tensor. Defaults to False. 133 | sigma (float or Tuple[float], optional): The sigma of the gaussian filter. Defaults to 1.5. 134 | K1 (float, optional): K1 of MSSIM. Defaults to 0.01. 135 | K2 (float, optional): K2 of MSSIM. Defaults to 0.03. 136 | L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1. 137 | keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False. 138 | data_dim (int, optional): The dimension of the data. Defaults to 2, which means a 2d image (4d tensor). 139 | return_log (bool, optional): Whether to return the logarithmic form. Defaults to False. 140 | return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score. 141 | padding (int or Tuple[int], optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0. 142 | ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True. 
143 | 144 | ``` 145 | # setting 0: for 4d float tensors with the data range [0, 1] and 1 channel 146 | ssim_caller = SSIM().cuda() 147 | # setting 1: for 4d float tensors with the data range [0, 1] and 3 channel 148 | ssim_caller = SSIM(in_channels=3).cuda() 149 | # setting 2: for 4d float tensors with the data range [0, 255] and 3 channel 150 | ssim_caller = SSIM(L=255, in_channels=3).cuda() 151 | # setting 3: for 4d float tensors with the data range [0, 255] and 3 channel, and return the logarithmic form 152 | ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda() 153 | # setting 4: for 4d float tensors with the data range [0, 1] and 1 channel,return the logarithmic form, and keep the batch dim 154 | ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda() 155 | # setting 5: for 4d float tensors with the data range [0, 1] and 1 channel, padding=0 and the splitted kernels. 156 | ssim_caller = SSIM(return_log=True, keep_batch_dim=True, padding=0, ensemble_kernel=False).cuda() 157 | 158 | # two 4d tensors 159 | x = torch.randn(3, 1, 100, 100).cuda() 160 | y = torch.randn(3, 1, 100, 100).cuda() 161 | ssim_score_0 = ssim_caller(x, y) 162 | # or in the fp16 mode (we have fixed the computation progress into the float32 mode to avoid the unexpected result) 163 | with torch.cuda.amp.autocast(enabled=True): 164 | ssim_score_1 = ssim_caller(x, y) 165 | assert torch.isclose(ssim_score_0, ssim_score_1) 166 | ``` 167 | 168 | Reference: 169 | [1] SSIM: Wang, Zhou et al. “Image quality assessment: from error visibility to structural similarity.” IEEE Transactions on Image Processing 13 (2004): 600-612. 170 | [2] MS-SSIM: Wang, Zhou et al. “Multi-scale structural similarity for image quality assessment.” (2003). 
171 | """ 172 | super().__init__() 173 | self.data_dim = data_dim 174 | self.window_size = window_size 175 | self.C1 = (K1 * L) ** 2 # equ 7 in ref1 176 | self.C2 = (K2 * L) ** 2 # equ 7 in ref1 177 | self.keep_batch_dim = keep_batch_dim 178 | self.return_log = return_log 179 | self.return_msssim = return_msssim 180 | if self.return_msssim and self.return_log: 181 | raise ValueError("return_log only support return_msssim=False") 182 | if self.return_msssim and self.data_dim < 2: 183 | raise ValueError("return_msssim only support data_dim>=2") 184 | 185 | self.gaussian_filter = GaussianFilter( 186 | data_dim=self.data_dim, 187 | window_size=window_size, 188 | in_channels=in_channels, 189 | sigma=sigma, 190 | padding=padding, 191 | ensemble_kernel=ensemble_kernel, 192 | ) 193 | 194 | @torch.cuda.amp.autocast(enabled=False) 195 | def forward(self, x, y): 196 | """Calculate the mean SSIM (MSSIM) between two 3d/4d/5d tensors. 197 | 198 | Args: 199 | x (Tensor): 3d/4d/5d tensor 200 | y (Tensor): 3d/4d/5d tensor 201 | 202 | Returns: 203 | Tensor: MSSIM or MS-SSIM 204 | """ 205 | assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same" 206 | assert x.ndim == self.data_dim + 2, f"x: {x.ndim} and y: {y.ndim} must be {self.data_dim + 2}d tensors" 207 | if x.type() != self.gaussian_filter.gaussian_window.type(): 208 | x = x.type_as(self.gaussian_filter.gaussian_window) 209 | if y.type() != self.gaussian_filter.gaussian_window.type(): 210 | y = y.type_as(self.gaussian_filter.gaussian_window) 211 | 212 | if self.return_msssim: 213 | return self.msssim(x, y) 214 | else: 215 | return self.ssim(x, y) 216 | 217 | def ssim(self, x, y): 218 | ssim, _ = self._ssim(x, y) 219 | if self.return_log: 220 | # https://github.com/xuebinqin/BASNet/blob/56393818e239fed5a81d06d2a1abfe02af33e461/pytorch_ssim/__init__.py#L81-L83 221 | ssim = ssim - ssim.min() 222 | ssim = ssim / ssim.max() 223 | ssim = -torch.log(ssim + 1e-8) 224 | 225 | if self.keep_batch_dim: 226 | return 
ssim.flatten(1).mean(-1) 227 | else: 228 | return ssim.mean() 229 | 230 | def msssim(self, x, y): 231 | ms_components = [] 232 | for i, w in enumerate((0.0448, 0.2856, 0.3001, 0.2363, 0.1333)): 233 | ssim, cs = self._ssim(x, y) 234 | 235 | if self.keep_batch_dim: 236 | ssim = ssim.flatten(1).mean(-1) 237 | cs = cs.flatten(1).mean(-1) 238 | else: 239 | ssim = ssim.mean() 240 | cs = cs.mean() 241 | 242 | if i == 4: 243 | ms_components.append(ssim**w) 244 | else: 245 | ms_components.append(cs**w) 246 | bs, *c, h, w = x.shape 247 | padding = [s % 2 for s in (h, w)] # spatial padding 248 | if len(c) > 1: 249 | # only pooling in the spatial domain 250 | x = x.reshape(bs, -1, h, w) 251 | y = y.reshape(bs, -1, h, w) 252 | x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=padding) 253 | y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=padding) 254 | if len(c) > 1: 255 | x = x.reshape(bs, *c, h // 2, w // 2) 256 | y = y.reshape(bs, *c, h // 2, w // 2) 257 | msssim = math.prod(ms_components) # equ 7 in ref2 258 | return msssim 259 | 260 | def _ssim(self, x, y): 261 | mu_x = self.gaussian_filter(x) # equ 14 262 | mu_y = self.gaussian_filter(y) # equ 14 263 | sigma2_x = self.gaussian_filter(x * x) - mu_x * mu_x # equ 15 264 | sigma2_y = self.gaussian_filter(y * y) - mu_y * mu_y # equ 15 265 | sigma_xy = self.gaussian_filter(x * y) - mu_x * mu_y # equ 16 266 | 267 | A1 = 2 * mu_x * mu_y + self.C1 268 | A2 = 2 * sigma_xy + self.C2 269 | B1 = mu_x * mu_x + mu_y * mu_y + self.C1 270 | B2 = sigma2_x + sigma2_y + self.C2 271 | 272 | # equ 12, 13 in ref1 273 | l = A1 / B1 274 | cs = A2 / B2 275 | ssim = l * cs 276 | return ssim, cs 277 | 278 | 279 | def ssim( 280 | x, 281 | y, 282 | *, 283 | window_size=11, 284 | in_channels=1, 285 | sigma=1.5, 286 | K1=0.01, 287 | K2=0.03, 288 | L=1, 289 | keep_batch_dim=False, 290 | data_dim=2, 291 | return_log=False, 292 | return_msssim=False, 293 | padding=None, 294 | ensemble_kernel=True, 295 | ): 296 | """Calculate the mean SSIM (MSSIM) 
between two 3d/4d/5d tensors. 297 | 298 | Args: 299 | x (Tensor): 3d/4d/5d tensor 300 | y (Tensor): 3d/4d/5d tensor with the same shape as x 301 | window_size (int or Tuple[int], optional): The window size of the gaussian filter. Defaults to 11. 302 | in_channels (int, optional): The number of channels of the input tensors. Defaults to 1. 303 | sigma (float or Tuple[float], optional): The sigma of the gaussian filter. Defaults to 1.5. 304 | K1 (float, optional): K1 of MSSIM. Defaults to 0.01. 305 | K2 (float, optional): K2 of MSSIM. Defaults to 0.03. 306 | L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1. 307 | keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False. 308 | data_dim (int, optional): The dimension of the data. Defaults to 2, which means a 2d image (4d tensor). 309 | return_log (bool, optional): Whether to return the logarithmic form. Defaults to False. 310 | return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score. 311 | padding (int or Tuple[int], optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0. 312 | ensemble_kernel (bool, optional): Whether to fuse the cascaded 1d kernels into a single kernel. Defaults to True. 
class GaussianFilter2D(nn.Module):
    """Depthwise 2D Gaussian smoothing for 4d tensors, with a fused or a split kernel."""

    def __init__(self, window_size=11, in_channels=1, sigma=1.5, padding=None, ensemble_kernel=True):
        """Build the Gaussian window and register it as the ``gaussian_window`` buffer.

        Args:
            window_size (int, optional): The window size of the gaussian filter. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
            padding (int, optional): The padding of the gaussian filter. Defaults to None, in
                which case ``window_size // 2`` is used. Another common setting is 0.
            ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernels into
                a 2d kernel. Defaults to True.

        Raises:
            ValueError: if ``window_size`` is not odd.
        """
        super().__init__()
        if window_size % 2 != 1:
            raise ValueError("Window size must be odd.")
        self.window_size = window_size
        self.padding = window_size // 2 if padding is None else padding
        self.sigma = sigma
        self.ensemble_kernel = ensemble_kernel

        window = self._get_gaussian_window1d()
        if self.ensemble_kernel:
            window = self._get_gaussian_window2d(window)
        # One copy of the window per channel, for the grouped convolution in ``forward``.
        self.register_buffer(name="gaussian_window", tensor=window.repeat(in_channels, 1, 1, 1))

    def _get_gaussian_window1d(self):
        """Return the normalised 1d Gaussian window with shape ``(1, 1, 1, window_size)``."""
        half = self.window_size // 2
        offsets = torch.arange(-half, half + 1)
        weights = torch.exp(-0.5 * offsets**2 / (self.sigma * self.sigma))
        return (weights / weights.sum()).reshape(1, 1, 1, self.window_size)

    def _get_gaussian_window2d(self, gaussian_window_1d):
        """Return the outer product of the 1d window with itself, i.e. the fused 2d window."""
        return torch.matmul(gaussian_window_1d.transpose(dim0=-1, dim1=-2), gaussian_window_1d)

    def forward(self, x):
        """Smooth ``x`` with the Gaussian window.

        Args:
            x (Tensor): 4d tensor whose channel count matches ``in_channels``.

        Returns:
            Tensor: the filtered tensor.
        """
        if self.ensemble_kernel:
            # ensemble kernel: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/3add4532d3f633316cba235da1c69e90f0dfb952/pytorch_ssim/__init__.py#L11-L15
            return F.conv2d(input=x, weight=self.gaussian_window, stride=1, padding=self.padding, groups=x.shape[1])

        # splitted kernel: https://github.com/VainF/pytorch-msssim/blob/2398f4db0abf44bcd3301cfadc1bf6c94788d416/pytorch_msssim/ssim.py#L48
        # The spatial sizes are read once from the incoming tensor; the same ``padding``
        # value is passed to every 1d pass.
        for axis, extent in enumerate(x.shape[2:], start=2):
            if extent < self.window_size:
                warnings.warn(
                    f"Skipping Gaussian Smoothing at dimension {axis} for x: {x.shape} and window size: {self.window_size}"
                )
                continue
            kernel = self.gaussian_window.transpose(dim0=-1, dim1=axis)
            x = F.conv2d(input=x, weight=kernel, stride=1, padding=self.padding, groups=x.shape[1])
        return x
return_log=False, 73 | return_msssim=False, 74 | padding=None, 75 | ensemble_kernel=True, 76 | ): 77 | """Calculate the mean SSIM (MSSIM) between two 4D tensors. 78 | 79 | Args: 80 | window_size (int, optional): The window size of the gaussian filter. Defaults to 11. 81 | in_channels (int, optional): The number of channels of the 4d tensor. Defaults to False. 82 | sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5. 83 | K1 (float, optional): K1 of MSSIM. Defaults to 0.01. 84 | K2 (float, optional): K2 of MSSIM. Defaults to 0.03. 85 | L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1. 86 | keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False. 87 | return_log (bool, optional): Whether to return the logarithmic form. Defaults to False. 88 | return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score. 89 | padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0. 90 | ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True. 
91 | 92 | ``` 93 | # setting 0: for 4d float tensors with the data range [0, 1] and 1 channel 94 | ssim_caller = SSIM().cuda() 95 | # setting 1: for 4d float tensors with the data range [0, 1] and 3 channel 96 | ssim_caller = SSIM(in_channels=3).cuda() 97 | # setting 2: for 4d float tensors with the data range [0, 255] and 3 channel 98 | ssim_caller = SSIM(L=255, in_channels=3).cuda() 99 | # setting 3: for 4d float tensors with the data range [0, 255] and 3 channel, and return the logarithmic form 100 | ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda() 101 | # setting 4: for 4d float tensors with the data range [0, 1] and 1 channel,return the logarithmic form, and keep the batch dim 102 | ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda() 103 | # setting 5: for 4d float tensors with the data range [0, 1] and 1 channel, padding=0 and the splitted kernels. 104 | ssim_caller = SSIM(return_log=True, keep_batch_dim=True, padding=0, ensemble_kernel=False).cuda() 105 | 106 | # two 4d tensors 107 | x = torch.randn(3, 1, 100, 100).cuda() 108 | y = torch.randn(3, 1, 100, 100).cuda() 109 | ssim_score_0 = ssim_caller(x, y) 110 | # or in the fp16 mode (we have fixed the computation progress into the float32 mode to avoid the unexpected result) 111 | with torch.cuda.amp.autocast(enabled=True): 112 | ssim_score_1 = ssim_caller(x, y) 113 | assert torch.isclose(ssim_score_0, ssim_score_1) 114 | ``` 115 | 116 | Reference: 117 | [1] SSIM: Wang, Zhou et al. “Image quality assessment: from error visibility to structural similarity.” IEEE Transactions on Image Processing 13 (2004): 600-612. 118 | [2] MS-SSIM: Wang, Zhou et al. “Multi-scale structural similarity for image quality assessment.” (2003). 
119 | """ 120 | super().__init__() 121 | self.window_size = window_size 122 | self.C1 = (K1 * L) ** 2 # equ 7 in ref1 123 | self.C2 = (K2 * L) ** 2 # equ 7 in ref1 124 | self.keep_batch_dim = keep_batch_dim 125 | self.return_log = return_log 126 | self.return_msssim = return_msssim 127 | 128 | self.gaussian_filter = GaussianFilter2D( 129 | window_size=window_size, 130 | in_channels=in_channels, 131 | sigma=sigma, 132 | padding=padding, 133 | ensemble_kernel=ensemble_kernel, 134 | ) 135 | 136 | @torch.cuda.amp.autocast(enabled=False) 137 | def forward(self, x, y): 138 | """Calculate the mean SSIM (MSSIM) between two 4d tensors. 139 | 140 | Args: 141 | x (Tensor): 4d tensor 142 | y (Tensor): 4d tensor 143 | 144 | Returns: 145 | Tensor: MSSIM or MS-SSIM 146 | """ 147 | assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same" 148 | assert x.ndim == y.ndim == 4, f"x: {x.ndim} and y: {y.ndim} must be 4" 149 | if x.type() != self.gaussian_filter.gaussian_window.type(): 150 | x = x.type_as(self.gaussian_filter.gaussian_window) 151 | if y.type() != self.gaussian_filter.gaussian_window.type(): 152 | y = y.type_as(self.gaussian_filter.gaussian_window) 153 | 154 | if self.return_msssim: 155 | return self.msssim(x, y) 156 | else: 157 | return self.ssim(x, y) 158 | 159 | def ssim(self, x, y): 160 | ssim, _ = self._ssim(x, y) 161 | if self.return_log: 162 | # https://github.com/xuebinqin/BASNet/blob/56393818e239fed5a81d06d2a1abfe02af33e461/pytorch_ssim/__init__.py#L81-L83 163 | ssim = ssim - ssim.min() 164 | ssim = ssim / ssim.max() 165 | ssim = -torch.log(ssim + 1e-8) 166 | 167 | if self.keep_batch_dim: 168 | return ssim.mean(dim=(1, 2, 3)) 169 | else: 170 | return ssim.mean() 171 | 172 | def msssim(self, x, y): 173 | ms_components = [] 174 | for i, w in enumerate((0.0448, 0.2856, 0.3001, 0.2363, 0.1333)): 175 | ssim, cs = self._ssim(x, y) 176 | 177 | if self.keep_batch_dim: 178 | ssim = ssim.mean(dim=(1, 2, 3)) 179 | cs = cs.mean(dim=(1, 2, 3)) 180 | else: 
181 | ssim = ssim.mean() 182 | cs = cs.mean() 183 | 184 | if i == 4: 185 | ms_components.append(ssim**w) 186 | else: 187 | ms_components.append(cs**w) 188 | padding = [s % 2 for s in x.shape[2:]] # spatial padding 189 | x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=padding) 190 | y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=padding) 191 | msssim = math.prod(ms_components) # equ 7 in ref2 192 | return msssim 193 | 194 | def _ssim(self, x, y): 195 | mu_x = self.gaussian_filter(x) # equ 14 196 | mu_y = self.gaussian_filter(y) # equ 14 197 | sigma2_x = self.gaussian_filter(x * x) - mu_x * mu_x # equ 15 198 | sigma2_y = self.gaussian_filter(y * y) - mu_y * mu_y # equ 15 199 | sigma_xy = self.gaussian_filter(x * y) - mu_x * mu_y # equ 16 200 | 201 | A1 = 2 * mu_x * mu_y + self.C1 202 | A2 = 2 * sigma_xy + self.C2 203 | B1 = mu_x * mu_x + mu_y * mu_y + self.C1 204 | B2 = sigma2_x + sigma2_y + self.C2 205 | 206 | # equ 12, 13 in ref1 207 | l = A1 / B1 208 | cs = A2 / B2 209 | ssim = l * cs 210 | return ssim, cs 211 | 212 | 213 | def ssim( 214 | x, 215 | y, 216 | *, 217 | window_size=11, 218 | in_channels=1, 219 | sigma=1.5, 220 | K1=0.01, 221 | K2=0.03, 222 | L=1, 223 | keep_batch_dim=False, 224 | return_log=False, 225 | return_msssim=False, 226 | padding=None, 227 | ensemble_kernel=True, 228 | ): 229 | """Calculate the mean SSIM (MSSIM) between two 4D tensors. 230 | 231 | Args: 232 | x (Tensor): 4d tensor 233 | y (Tensor): 4d tensor 234 | window_size (int, optional): The window size of the gaussian filter. Defaults to 11. 235 | in_channels (int, optional): The number of channels of the 4d tensor. Defaults to False. 236 | sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5. 237 | K1 (float, optional): K1 of MSSIM. Defaults to 0.01. 238 | K2 (float, optional): K2 of MSSIM. Defaults to 0.03. 239 | L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1. 
def gaussian(window_size, sigma):
    """Return a normalised 1d Gaussian window of length ``window_size`` as a float tensor."""
    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window with sigma 1.5.

    The same 2d window is repeated for every channel so it can be used as the weight of a
    grouped (depthwise) convolution.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # No `Variable` wrapper: tensors carry autograd state themselves.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()


def _window_like(img, window_size, channel):
    """Create a Gaussian window on the device and with the dtype of ``img``."""
    return create_window(window_size, channel).to(device=img.device, dtype=img.dtype)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two ``(N, C, H, W)`` batches with a precomputed window.

    Args:
        img1 (Tensor): first batch.
        img2 (Tensor): second batch with the same shape as ``img1``.
        window (Tensor): depthwise window, e.g. from ``create_window``.
        window_size (int): side length of ``window``; ``window_size // 2`` is used as padding.
        channel (int): number of channels, used as the convolution group count.
        size_average (bool, optional): return a scalar mean if True, otherwise one value per
            sample. Defaults to True.

    Returns:
        Tensor: the averaged SSIM.
    """
    pad = window_size // 2
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

    # Stabilising constants for a dynamic range of 1.
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    # Reduce channel, height and width one after another, leaving the batch dimension.
    return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches the Gaussian window between calls."""

    def __init__(self, window_size=11, size_average=True):
        """Store the settings and build an initial single-channel window.

        Args:
            window_size (int, optional): side length of the Gaussian window. Defaults to 11.
            size_average (bool, optional): return a scalar mean if True, otherwise one value
                per sample. Defaults to True.
        """
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between two ``(N, C, H, W)`` batches."""
        (_, channel, _, _) = img1.size()

        reusable = (
            channel == self.channel
            and self.window.device == img1.device  # also distinguishes different CUDA devices
            and self.window.dtype == img1.dtype
        )
        if reusable:
            window = self.window
        else:
            window = _window_like(img1, self.window_size, channel)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Return the SSIM between two ``(N, C, H, W)`` batches using a freshly built window.

    Args:
        img1 (Tensor): first batch.
        img2 (Tensor): second batch with the same shape as ``img1``.
        window_size (int, optional): side length of the Gaussian window. Defaults to 11.
        size_average (bool, optional): return a scalar mean if True, otherwise one value per
            sample. Defaults to True.

    Returns:
        Tensor: the averaged SSIM.
    """
    (_, channel, _, _) = img1.size()
    window = _window_like(img1, window_size, channel)
    return _ssim(img1, img2, window, window_size, channel, size_average)
class CheckSSIMTestCase(unittest.TestCase):
    """Regression tests for ``ssim.ssim`` against pinned values and reference implementations."""

    @classmethod
    def setUpClass(cls):
        """Build a clean/noisy image pair and its 1d and 3d variants, shared by all tests."""
        img = img_as_float(data.camera())
        noise = np.ones_like(img) * 0.3 * (img.max() - img.min())
        rng = np.random.default_rng(seed=20241204)
        noise[rng.random(size=noise.shape) > 0.5] *= -1
        img_noise = img + noise

        cls.x_2d = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=0).float()  # 2,1,H,W
        cls.y_2d = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=0).float()

        cls.x_1d = cls.x_2d.mean(dim=-1)  # 2,1,H
        cls.y_1d = cls.y_2d.mean(dim=-1)  # 2,1,H

        cls.x_3d = cls.x_2d.unsqueeze(dim=2).repeat_interleave(5, dim=2)  # 2,1,5,H,W
        cls.y_3d = cls.y_2d.unsqueeze(dim=2).repeat_interleave(5, dim=2)  # 2,1,5,H,W

    def test_ssim1d(self):
        our_ssim_score = ssim.ssim(
            self.x_1d, self.y_1d, return_msssim=False, L=1, padding=None, ensemble_kernel=True, data_dim=1
        ).item()
        self.assertEqual(our_ssim_score, 0.8740299940109253)

    def test_mssim1d(self):
        # MS-SSIM is rejected for 1d data.
        with self.assertRaises(ValueError):
            ssim.ssim(self.x_1d, self.y_1d, return_msssim=True, L=1, padding=None, ensemble_kernel=True, data_dim=1)

    def test_ssim3d(self):
        our_ssim_score = ssim.ssim(
            self.x_3d, self.y_3d, return_msssim=False, L=1, padding=None, ensemble_kernel=True, data_dim=3
        ).item()
        self.assertEqual(our_ssim_score, 0.4981585144996643)

    def test_mssim3d(self):
        pinned = ((False, 0.8404632806777954), (True, 0.6245790719985962))
        for ensemble_kernel, expected in pinned:
            with self.subTest(ensemble_kernel=ensemble_kernel):
                our_ssim_score = ssim.ssim(
                    self.x_3d,
                    self.y_3d,
                    return_msssim=True,
                    L=1,
                    ensemble_kernel=ensemble_kernel,
                    data_dim=3,
                    window_size=(3, 11, 11),
                ).item()
                self.assertEqual(our_ssim_score, expected)

    def test_ssim2d_with_oldversion(self):
        """The 2d path must reproduce the old 2d-only implementation exactly."""
        base = dict(L=1, keep_batch_dim=False, return_log=False, return_msssim=False, ensemble_kernel=True)
        configs = (
            dict(base, window_size=11),  # equals the defaults of the old implementation
            dict(base, window_size=5),
            dict(base, window_size=5, keep_batch_dim=True),
            dict(base, window_size=5, keep_batch_dim=True, return_log=True),
            dict(base, window_size=5, keep_batch_dim=True, return_msssim=True),
            dict(base, window_size=5, keep_batch_dim=True, return_msssim=True, ensemble_kernel=False),
        )
        for kwargs in configs:
            # subTest reports every configuration separately instead of stopping at the first failure.
            with self.subTest(**kwargs):
                old_version_ssim_score = old_version.ssim(self.x_2d, self.y_2d, **kwargs).tolist()
                our_ssim_score = ssim.ssim(self.x_2d, self.y_2d, data_dim=2, **kwargs).tolist()
                self.assertEqual(our_ssim_score, old_version_ssim_score)

    def test_ssim2d_with_pohsunsu_method(self):
        """https://github.com/Po-Hsun-Su/pytorch-ssim"""
        for window_size in (11, 5):
            with self.subTest(window_size=window_size):
                po_hsun_su_ssim_score = po_hsun_su_ssim.ssim(self.x_2d, self.y_2d, window_size=window_size).item()
                # use the settings of https://github.com/Po-Hsun-Su/pytorch-ssim
                our_ssim_score = ssim.ssim(
                    self.x_2d, self.y_2d, L=1, ensemble_kernel=True, data_dim=2, window_size=window_size
                ).item()
                # The order of the computations differs, which introduces a small numerical error.
                self.assertAlmostEqual(our_ssim_score, po_hsun_su_ssim_score)

    def test_ssim2d_with_vainf_method(self):
        """https://github.com/VainF/pytorch-msssim

        VainF's method averages over the spatial dimensions first and over the remaining
        dimensions afterwards, whereas ours takes one global mean, so the results can differ
        slightly.
        """
        for window_size in (11, 5):
            with self.subTest(window_size=window_size):
                vainf_ssim_score = vainf_ssim.ssim(self.x_2d, self.y_2d, data_range=1, win_size=window_size).item()
                our_ssim_score = ssim.ssim(
                    self.x_2d, self.y_2d, L=1, ensemble_kernel=False, padding=0, window_size=window_size
                ).item()
                self.assertAlmostEqual(our_ssim_score, vainf_ssim_score)

    def test_msssim2d_with_vainf_method(self):
        """https://github.com/VainF/pytorch-msssim

        VainF's method averages over the spatial dimensions first and over the remaining
        dimensions afterwards, whereas ours takes one global mean, so the results can differ
        slightly.
        """
        for window_size in (11, 5):
            with self.subTest(window_size=window_size):
                vainf_ssim_score = vainf_ssim.ms_ssim(self.x_2d, self.y_2d, data_range=1, win_size=window_size).item()
                our_ssim_score = ssim.ssim(
                    self.x_2d,
                    self.y_2d,
                    return_msssim=True,
                    L=1,
                    ensemble_kernel=False,
                    window_size=window_size,
                    padding=0,
                ).item()
                self.assertAlmostEqual(our_ssim_score, vainf_ssim_score)
3 | import warnings 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | 11 | def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor: 12 | r"""Create 1-D gauss kernel 13 | Args: 14 | size (int): the size of gauss kernel 15 | sigma (float): sigma of normal distribution 16 | Returns: 17 | torch.Tensor: 1D kernel (1 x 1 x size) 18 | """ 19 | coords = torch.arange(size, dtype=torch.float) 20 | coords -= size // 2 21 | 22 | g = torch.exp(-(coords**2) / (2 * sigma**2)) 23 | g /= g.sum() 24 | 25 | return g.unsqueeze(0).unsqueeze(0) 26 | 27 | 28 | def gaussian_filter(input: Tensor, win: Tensor) -> Tensor: 29 | r"""Blur input with 1-D kernel 30 | Args: 31 | input (torch.Tensor): a batch of tensors to be blurred 32 | window (torch.Tensor): 1-D gauss kernel 33 | Returns: 34 | torch.Tensor: blurred tensors 35 | """ 36 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape 37 | if len(input.shape) == 4: 38 | conv = F.conv2d 39 | elif len(input.shape) == 5: 40 | conv = F.conv3d 41 | else: 42 | raise NotImplementedError(input.shape) 43 | 44 | C = input.shape[1] 45 | out = input 46 | for i, s in enumerate(input.shape[2:]): 47 | if s >= win.shape[-1]: 48 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) 49 | else: 50 | warnings.warn( 51 | f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" 52 | ) 53 | 54 | return out 55 | 56 | 57 | def _ssim( 58 | X: Tensor, 59 | Y: Tensor, 60 | data_range: float, 61 | win: Tensor, 62 | size_average: bool = True, 63 | K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), 64 | ) -> Tuple[Tensor, Tensor]: 65 | r"""Calculate ssim index for X and Y 66 | 67 | Args: 68 | X (torch.Tensor): images 69 | Y (torch.Tensor): images 70 | data_range (float or int): value range of input images. 
(usually 1.0 or 255) 71 | win (torch.Tensor): 1-D gauss kernel 72 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 73 | 74 | Returns: 75 | Tuple[torch.Tensor, torch.Tensor]: ssim results. 76 | """ 77 | K1, K2 = K 78 | # batch, channel, [depth,] height, width = X.shape 79 | compensation = 1.0 80 | 81 | C1 = (K1 * data_range) ** 2 82 | C2 = (K2 * data_range) ** 2 83 | 84 | win = win.to(X.device, dtype=X.dtype) 85 | 86 | mu1 = gaussian_filter(X, win) 87 | mu2 = gaussian_filter(Y, win) 88 | 89 | mu1_sq = mu1.pow(2) 90 | mu2_sq = mu2.pow(2) 91 | mu1_mu2 = mu1 * mu2 92 | 93 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 94 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 95 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 96 | 97 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 98 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 99 | 100 | ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) 101 | cs = torch.flatten(cs_map, 2).mean(-1) 102 | return ssim_per_channel, cs 103 | 104 | 105 | def ssim( 106 | X: Tensor, 107 | Y: Tensor, 108 | data_range: float = 255, 109 | size_average: bool = True, 110 | win_size: int = 11, 111 | win_sigma: float = 1.5, 112 | win: Optional[Tensor] = None, 113 | K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), 114 | nonnegative_ssim: bool = False, 115 | ) -> Tensor: 116 | r"""interface of ssim 117 | Args: 118 | X (torch.Tensor): a batch of images, (N,C,H,W) 119 | Y (torch.Tensor): a batch of images, (N,C,H,W) 120 | data_range (float or int, optional): value range of input images. 
(usually 1.0 or 255) 121 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 122 | win_size: (int, optional): the size of gauss kernel 123 | win_sigma: (float, optional): sigma of normal distribution 124 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 125 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 126 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu 127 | 128 | Returns: 129 | torch.Tensor: ssim results 130 | """ 131 | if not X.shape == Y.shape: 132 | raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.") 133 | 134 | for d in range(len(X.shape) - 1, 1, -1): 135 | X = X.squeeze(dim=d) 136 | Y = Y.squeeze(dim=d) 137 | 138 | if len(X.shape) not in (4, 5): 139 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 140 | 141 | # if not X.type() == Y.type(): 142 | # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.") 143 | 144 | if win is not None: # set win_size 145 | win_size = win.shape[-1] 146 | 147 | if not (win_size % 2 == 1): 148 | raise ValueError("Window size should be odd.") 149 | 150 | if win is None: 151 | win = _fspecial_gauss_1d(win_size, win_sigma) 152 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 153 | 154 | ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K) 155 | if nonnegative_ssim: 156 | ssim_per_channel = torch.relu(ssim_per_channel) 157 | 158 | if size_average: 159 | return ssim_per_channel.mean() 160 | else: 161 | return ssim_per_channel.mean(1) 162 | 163 | 164 | def ms_ssim( 165 | X: Tensor, 166 | Y: Tensor, 167 | data_range: float = 255, 168 | size_average: bool = True, 169 | win_size: int = 11, 170 | win_sigma: float = 1.5, 
171 | win: Optional[Tensor] = None, 172 | weights: Optional[List[float]] = None, 173 | K: Union[Tuple[float, float], List[float]] = (0.01, 0.03), 174 | ) -> Tensor: 175 | r"""interface of ms-ssim 176 | Args: 177 | X (torch.Tensor): a batch of images, (N,C,[T,]H,W) 178 | Y (torch.Tensor): a batch of images, (N,C,[T,]H,W) 179 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 180 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 181 | win_size: (int, optional): the size of gauss kernel 182 | win_sigma: (float, optional): sigma of normal distribution 183 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 184 | weights (list, optional): weights for different levels 185 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. 186 | Returns: 187 | torch.Tensor: ms-ssim results 188 | """ 189 | if not X.shape == Y.shape: 190 | raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.") 191 | 192 | for d in range(len(X.shape) - 1, 1, -1): 193 | X = X.squeeze(dim=d) 194 | Y = Y.squeeze(dim=d) 195 | 196 | # if not X.type() == Y.type(): 197 | # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.") 198 | 199 | if len(X.shape) == 4: 200 | avg_pool = F.avg_pool2d 201 | elif len(X.shape) == 5: 202 | avg_pool = F.avg_pool3d 203 | else: 204 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 205 | 206 | if win is not None: # set win_size 207 | win_size = win.shape[-1] 208 | 209 | if not (win_size % 2 == 1): 210 | raise ValueError("Window size should be odd.") 211 | 212 | smaller_side = min(X.shape[-2:]) 213 | assert smaller_side > (win_size - 1) * (2**4), ( 214 | "Image size should be larger than %d due to the 4 downsamplings in 
class SSIM(torch.nn.Module):
    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
        nonnegative_ssim: bool = False,
    ) -> None:
        r"""Module form of ``ssim`` that builds its gauss window once, up front.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, ssim of all images is averaged into a scalar
            win_size (int, optional): the size of the gauss kernel
            win_sigma (float, optional): sigma of the normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): how many extra singleton axes are appended when the
                1-D kernel is tiled per channel; presumably the number of spatial dimensions
                of the inputs (2 for images, 3 for volumes) -- confirm against ``ssim``
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant
                (e.g. 0.4) if you get negative or NaN results.
            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
        """
        super().__init__()

        # Build the 1-D gauss kernel once and tile it along the leading axis,
        # one copy per channel; every other axis is repeated exactly once.
        gauss_1d = _fspecial_gauss_1d(win_size, win_sigma)
        tiling = [channel, 1, *([1] * spatial_dims)]
        self.win = gauss_1d.repeat(tiling)

        # Settings handed to ``ssim`` on every forward call. ``win_size`` is
        # only stored here; forward() passes the prebuilt window instead.
        self.win_size = win_size
        self.data_range = data_range
        self.size_average = size_average
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, X: Tensor, Y: Tensor) -> Tensor:
        """Return ``ssim(X, Y)`` evaluated with the window and options stored on this module."""
        score = ssim(
            X,
            Y,
            win=self.win,
            data_range=self.data_range,
            size_average=self.size_average,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
        )
        return score


class MS_SSIM(torch.nn.Module):
    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        weights: Optional[List[float]] = None,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    ) -> None:
        r"""Module form of ``ms_ssim`` that builds its gauss window once, up front.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, ms-ssim of all images is averaged into a scalar
            win_size (int, optional): the size of the gauss kernel
            win_sigma (float, optional): sigma of the normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): how many extra singleton axes are appended when the
                1-D kernel is tiled per channel; presumably the number of spatial dimensions
                of the inputs (2 for images, 3 for volumes) -- confirm against ``ms_ssim``
            weights (list, optional): weights for the different levels; ``None`` is passed
                through unchanged to ``ms_ssim``
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant
                (e.g. 0.4) if you get negative or NaN results.
        """
        super().__init__()

        # Build the 1-D gauss kernel once and tile it along the leading axis,
        # one copy per channel; every other axis is repeated exactly once.
        gauss_1d = _fspecial_gauss_1d(win_size, win_sigma)
        tiling = [channel, 1, *([1] * spatial_dims)]
        self.win = gauss_1d.repeat(tiling)

        # Settings handed to ``ms_ssim`` on every forward call. ``win_size`` is
        # only stored here; forward() passes the prebuilt window instead.
        self.win_size = win_size
        self.data_range = data_range
        self.size_average = size_average
        self.weights = weights
        self.K = K

    def forward(self, X: Tensor, Y: Tensor) -> Tensor:
        """Return ``ms_ssim(X, Y)`` evaluated with the window and options stored on this module."""
        score = ms_ssim(
            X,
            Y,
            win=self.win,
            weights=self.weights,
            data_range=self.data_range,
            size_average=self.size_average,
            K=self.K,
        )
        return score