├── .gitignore ├── FSDS_code.py ├── FSDS_explanation.ipynb ├── FSDS_matlab_example.m ├── FrequencySpectrumDistributionSimilarity.m ├── HyRA Usage.ipynb ├── HyRA.py ├── HyRA_results ├── hr_fft.png ├── linear.png ├── linear_fft.png ├── lr_fft.png ├── non_linear.png └── non_linear_fft.png ├── Impulse_Responses.pptx ├── LPF2ISR.py ├── Low-pass filter to ISR.ipynb ├── README.md ├── cal_metrcis_and_table.py ├── example_figures ├── EDSRx4.png ├── baby.png ├── baby_x2.png ├── gt.png ├── impulse_response.png ├── lr.png └── sr.png ├── sr_results └── readme.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /FSDS_code.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import einops 4 | import numpy 5 | def frequency_align_integrate(tsr): 6 | """ 7 | :param tsr: CxHxW 8 | :return: 9 | """ 10 | # pad zeros to make H even 11 | 12 | if tsr.shape[1] % 2 == 1: 13 | tsr = torch.concat([tsr, torch.zeros(tsr.shape[0], 1, tsr.shape[2])], dim=1) 14 | 15 | tsr_rfft = torch.fft.rfft2(tsr, norm="backward") 16 | part_1 = einops.rearrange(tsr_rfft, "C (p H) W -> C p H W", p=2) 17 | part_2 = part_1[:, 1, :, :] # C H/2 W 18 | part_1 = part_1[:, 0, :, :] # C H/2 W 19 | part_2 = torch.flip(part_2, dims=[1]).contiguous() 20 | 21 | part_1 = torch.cumsum(part_1, dim=1) 22 | part_1 = torch.cumsum(part_1, dim=2) 23 | 24 | part_2 = torch.cumsum(part_2, dim=1) 25 | part_2 = torch.cumsum(part_2, dim=2) 26 | 27 | return (part_1, part_2) 28 | 29 | 30 | def _FrequencySpectrumDistributionSimilarity(pred, tar): 31 | """ 32 | 33 | :param pred: the predicted output, in shape CxHxW 34 | :param tar: the ground-truth data, in shape CxHxW 35 | :return: 36 | """ 37 | if pred.shape != tar.shape: 38 | raise ValueError("The shape of pred is expected to be the same as tar") 39 | 40 | pred = pred 
- torch.mean(pred) 41 | pred /= torch.std(pred) 42 | tar = tar - torch.mean(tar) 43 | tar /= torch.std(tar) 44 | 45 | pred_part_1, pred_part_2 = frequency_align_integrate(pred) # C H/2 W 46 | tar_part_1, tar_part_2 = frequency_align_integrate(tar) # C H/2 W 47 | 48 | part_1_error = ((pred_part_1 - tar_part_1).abs()) ** 2 49 | part_2_error = ((pred_part_2 - tar_part_2).abs()) ** 2 50 | 51 | return -10*math.log10(torch.sum(part_1_error + part_2_error) / torch.sum(tar_part_2.abs() ** 2 + tar_part_1.abs() ** 2)) 52 | 53 | 54 | def FrequencySpectrumDistributionSimilarity(pred, gt): 55 | if pred.shape != gt.shape: 56 | raise ValueError("The shape of input tensor does not match") 57 | if len(pred.shape) == 3: 58 | return _FrequencySpectrumDistributionSimilarity(pred, gt) 59 | elif len(pred.shape) == 4: 60 | index = [] 61 | for i in range(pred.shape[0]): 62 | index.append(_FrequencySpectrumDistributionSimilarity(pred[i], gt[i])) 63 | return index 64 | 65 | def matlab_interface(pred, gt): 66 | pred = numpy.array(pred) 67 | gt = numpy.array(gt) 68 | if len(pred.shape) != 3 or len(gt.shape) != 3: 69 | print("The input array should be 3D") 70 | return 0.0 71 | if pred.shape != gt.shape: 72 | print("The input array should have the same shape") 73 | return 0.0 74 | pred = torch.from_numpy(pred) 75 | gt = torch.from_numpy(gt) 76 | pred = einops.rearrange(pred, "H W C -> C H W") 77 | gt = einops.rearrange(gt, "H W C -> C H W") 78 | return FrequencySpectrumDistributionSimilarity(pred, gt) 79 | -------------------------------------------------------------------------------- /FSDS_matlab_example.m: -------------------------------------------------------------------------------- 1 | close all;clearvars; 2 | pred = imread('example_figures\EDSRx4.png'); 3 | gt = imread('example_figures\gt.png'); 4 | FrequencySpectrumDistributionSimilarity(single(pred), single(gt)) -------------------------------------------------------------------------------- /FrequencySpectrumDistributionSimilarity.m: 
-------------------------------------------------------------------------------- 1 | function scores = FrequencySpectrumDistributionSimilarity(pred, gt) 2 | pyenv('Version','C:\Users\RisingEntropy\scoop\apps\miniconda3-py311\current\envs\CV\python.exe'); 3 | module = py.importlib.import_module('FSDS_code'); 4 | scores = module.matlab_interface(single(pred), single(gt)); 5 | end 6 | -------------------------------------------------------------------------------- /HyRA.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import torch 5 | from PIL import Image 6 | from einops import einops 7 | from torchvision.transforms import ToTensor 8 | import torch.nn.functional as F 9 | import matplotlib.pyplot as plt 10 | from utils import zero_interpolate_torch 11 | import argparse 12 | 13 | 14 | def normalize(ten:torch.Tensor): 15 | return (ten - torch.min(ten)) / (torch.max(ten) - torch.min(ten)) 16 | 17 | 18 | def convertNp(img:torch.Tensor): 19 | return einops.rearrange(img, "C H W -> H W C").numpy() 20 | 21 | 22 | def get_linear(lr:torch.Tensor, impulse_response:torch.Tensor, scale): 23 | impulse_response = impulse_response.unsqueeze(dim=1) 24 | lr_pad = F.pad(input=lr, 25 | pad=(impulse_response.shape[2] // (2 * scale), impulse_response.shape[2] // (2 * scale), 26 | impulse_response.shape[3] // (2 * scale), impulse_response.shape[3] // (2 * scale)), 27 | mode="reflect") 28 | lr_inter = zero_interpolate_torch(lr_pad, scale) 29 | lr_lp = F.conv2d(input=lr_inter, weight=impulse_response, stride=1, padding="valid", groups=3) 30 | return lr_lp 31 | 32 | 33 | def get_nonlinear(linear:torch.Tensor, sr:torch.Tensor): 34 | sr = F.interpolate(sr.unsqueeze(dim=0), size=(linear.shape[1], linear.shape[2])) 35 | return (sr - linear).squeeze() 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--lr", type=str, required=True) 41 | 
parser.add_argument("--impulse_response", type=str, required=True) 42 | parser.add_argument("--sr", type=str, required=True) 43 | parser.add_argument("--scale", type=int, required=True) 44 | parser.add_argument("--save_path", type=str, required=True) 45 | args = parser.parse_args() 46 | 47 | if args.scale <= 0: 48 | raise ValueError("Scale must be positive") 49 | converter = ToTensor() 50 | 51 | lr = converter(Image.open(args.lr)) 52 | impulse_response = converter(Image.open(args.impulse_response)) 53 | sr = converter(Image.open(args.sr)) 54 | 55 | cmp = matplotlib.colormaps["Blues"] 56 | cmp = cmp.reversed() 57 | linear = get_linear(lr, impulse_response, args.scale) 58 | non_linear = get_nonlinear(linear, sr) 59 | linear_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(linear)).abs() + 1) 60 | non_linear_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(non_linear)).abs() + 1) 61 | lr = F.interpolate(lr.unsqueeze(dim=0), size=(linear.shape[1], linear.shape[2]))[0] 62 | lr_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(lr)).abs() + 1) 63 | sr_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(sr)).abs() + 1) 64 | plt.imsave(os.path.join(args.save_path, "linear.png"), convertNp(linear / torch.max(linear))) 65 | plt.imsave(os.path.join(args.save_path, "non_linear.png"), convertNp(normalize(non_linear))) 66 | plt.imsave(os.path.join(args.save_path, "linear_fft.png"), linear_fft[0], cmap=cmp) # display only one channel 67 | plt.imsave(os.path.join(args.save_path, "non_linear_fft.png"), non_linear_fft[0], cmap=cmp) 68 | plt.imsave(os.path.join(args.save_path, "lr_fft.png"), lr_fft[0], cmap=cmp) 69 | plt.imsave(os.path.join(args.save_path, "hr_fft.png"), sr_fft[0], cmap=cmp) 70 | print("done!") 71 | 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /HyRA_results/hr_fft.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/hr_fft.png -------------------------------------------------------------------------------- /HyRA_results/linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/linear.png -------------------------------------------------------------------------------- /HyRA_results/linear_fft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/linear_fft.png -------------------------------------------------------------------------------- /HyRA_results/lr_fft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/lr_fft.png -------------------------------------------------------------------------------- /HyRA_results/non_linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/non_linear.png -------------------------------------------------------------------------------- /HyRA_results/non_linear_fft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/non_linear_fft.png -------------------------------------------------------------------------------- /Impulse_Responses.pptx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/Impulse_Responses.pptx -------------------------------------------------------------------------------- /LPF2ISR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | def sinc(tensor, omega): 6 | """ 7 | The sinc function implementation. sinc(t) is defined as sin(pi*t)/(pi*t), omega is a 8 | factor to adjust the scale 9 | :param tensor: the argument t of the sinc function 10 | :param omega: scale factor 11 | :return: approx. sin(pi*omega*t)/(pi*omega*t) element-wise; the 1e-9 offset avoids 0/0, so the value at t=0 is ~1 12 | """ 13 | return torch.sin(torch.abs(math.pi * tensor * omega) + 1e-9) / (torch.abs(math.pi * tensor * omega) + 1e-9) 14 | 15 | 16 | def nearest_odd(num): # returns num if it is odd, otherwise num + 1 17 | return num + 1 if num % 2 == 0 else num 18 | 19 | 20 | def zero_interpolate_torch(img: torch.Tensor, scale: int): 21 | """ 22 | interpolate 0 by `scale` times 23 | :param img: NxCxHxW 24 | :param scale: integer upsampling factor 25 | :return: N x C x (H*scale) x (W*scale) tensor with each input pixel at the top-left of its scale x scale cell and zeros elsewhere; dim 0 is squeezed when it has size 1, so a CxHxW input yields a 3D result (NOTE(review): a 4D input with N == 1 is squeezed as well) 26 | """ 27 | if len(img.shape) != 4: # unbatched CxHxW input: add a batch dim 28 | img = img.unsqueeze(dim=0) 29 | img_ = img.reshape(-1, 1, img.shape[2], img.shape[3]) 30 | img_int = torch.concat( 31 | [img_, torch.zeros(img_.shape[0], scale * scale - 1, img_.shape[2], img_.shape[3]).to(img.device)], 32 | dim=1) 33 | return torch.nn.functional.pixel_shuffle(img_int, scale).reshape(img.shape[0], img.shape[1], img.shape[2] * scale, 34 | img.shape[3] * scale).squeeze(dim=0) 35 | 36 | 37 | def lpf_sr_single(img: torch.Tensor, scale: int, omega=3.): 38 | """ 39 | Interpolate an image using the sinc function, it's slower than the cubic or others. 40 | 41 | :param img: the image to be interpolated.
42 | :param size: the expected size 43 | :param omega: the factor to adjust the scale of the sinc function 44 | :return: the interpolated image 45 | :param backend: use torch or cuda code to apply zero-interpolate 46 | """ 47 | img_pad = F.pad(input=img, 48 | pad=(img.shape[2] // 2, img.shape[2] // 2, img.shape[3] // 2, img.shape[3] // 2), 49 | mode="reflect") 50 | target = zero_interpolate_torch(img_pad, scale) 51 | h_grid = torch.linspace(-1, 1, (img.shape[2] // 2) * scale * 2 + 1) 52 | w_grid = torch.linspace(-1, 1, (img.shape[3] // 2) * scale * 2 + 1) 53 | kernel = torch.meshgrid([h_grid, w_grid], indexing='xy') 54 | 55 | kernel = sinc(kernel[0], omega) * sinc(kernel[1], omega) 56 | kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(img.device) 57 | # kernel.require_grad = False 58 | target = F.conv2d(input=target, weight=kernel, stride=1, padding="valid") 59 | for i in range(target.shape[0]): 60 | if torch.max(img[i])>0.001: # to avoid a all 0 image 61 | target[i] = (target[i] - torch.min(target[i]))/(torch.max(target[i])-torch.min(target[i])) * (torch.max(img[i])-torch.min(img[i])) + torch.min(img[i]) 62 | return target 63 | 64 | 65 | def lpf_sr(img: torch.Tensor, scale: int, omega=3.): 66 | """ 67 | Interpolate image(s) using the sinc function, it's slower than the cubic or others. 68 | :param img: the image to be interpolated. 
69 | :param scale: integer upsampling factor 70 | :param omega: the factor to adjust the scale of the sinc function 71 | :return: the interpolated image(s): N x C x (H*scale) x (W*scale) for 4D input, C x (H*scale) x (W*scale) for 3D input 72 | """ 73 | if len(img.shape) == 4: # Batched NxCxHxW: every channel is filtered independently as its own 1-channel image 74 | origin_shape = img.shape 75 | img = img.view(-1, 1, img.shape[2], img.shape[3]) # NOTE(review): .view fails on non-contiguous tensors 76 | out = lpf_sr_single(img, scale, omega) 77 | return out.reshape(origin_shape[0], origin_shape[1], origin_shape[2] * scale, 78 | origin_shape[3] * scale) 79 | else: # assumed unbatched CxHxW 80 | origin_shape = img.shape 81 | img = img.view(-1, 1, img.shape[1], img.shape[2]) 82 | out = lpf_sr_single(img, scale, omega) 83 | return out.reshape(origin_shape[0], origin_shape[1] * scale, origin_shape[2] * scale) -------------------------------------------------------------------------------- /Low-pass filter to ISR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "efe07f220f804a5f", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "source": [ 10 | "# Low-pass filter to ISR\n", 11 | "We can use a simple low pass filter to ISR, here is the example."
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "892215eb", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "!pip install torchmetrics\n", 22 | "!pip install einops" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "dc9e595e35c92629", 29 | "metadata": { 30 | "ExecuteTime": { 31 | "end_time": "2024-02-01T09:01:33.162507300Z", 32 | "start_time": "2024-02-01T09:01:24.059457600Z" 33 | }, 34 | "collapsed": false 35 | }, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "LPF ISR performance: psnr=27.45184326171875, ssim=0.7809420824050903, fsds=27.509355337158876\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "import FSDS_code\n", 47 | "import utils\n", 48 | "from PIL import Image\n", 49 | "from torchvision.transforms import ToTensor\n", 50 | "from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure\n", 51 | "\n", 52 | "psnr = PeakSignalNoiseRatio(data_range=1).cuda()\n", 53 | "ssim = StructuralSimilarityIndexMeasure(data_range=1).cuda()\n", 54 | "\n", 55 | "lr = ToTensor()(Image.open(\"./example_figures/baby_x2.png\")).unsqueeze(0).cuda()\n", 56 | "gt = ToTensor()(Image.open(\"./example_figures/baby.png\")).unsqueeze(0).cuda()\n", 57 | "sr = utils.lpf_sr(img=lr, scale=2, omega=48.5)\n", 58 | "print(f\"LPF ISR performance: psnr={psnr(sr, gt)}, ssim={ssim(sr, gt)}, fsds={FSDS_code.FrequencySpectrumDistributionSimilarity(sr, gt)[0]}\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "2abc40f7b9990b93", 64 | "metadata": { 65 | "collapsed": false 66 | }, 67 | "source": [ 68 | "# Implementation\n", 69 | "## 1) Zero-interpolation\n", 70 | "Before we apply a low-pass filter to the LR image, we first interpolate 0 to it to achieve the target size by:\n", 71 | "```python\n", 72 | "def zero_interpolate_torch(img: torch.Tensor, scale: int):\n", 73 | " \"\"\"\n", 74 | " interpolate 0 by `scale` 
times\n", 75 | " :param img: NxCxHxW\n", 76 | " :param scale:\n", 77 | " :return:\n", 78 | " \"\"\"\n", 79 | " if len(img.shape) != 4: # batched\n", 80 | " img = img.unsqueeze(dim=0)\n", 81 | " img_ = img.reshape(-1, 1, img.shape[2], img.shape[3])\n", 82 | " img_int = torch.concat(\n", 83 | " [img_, torch.zeros(img_.shape[0], scale * scale - 1, img_.shape[2], img_.shape[3]).to(img.device)],\n", 84 | " dim=1)\n", 85 | " return torch.nn.functional.pixel_shuffle(img_int, scale).reshape(img.shape[0], img.shape[1], img.shape[2] * scale,\n", 86 | " img.shape[3] * scale).squeeze(dim=0)\n", 87 | "\n", 88 | "```\n", 89 | "## 2) Low-pass filter\n", 90 | "Then we apply a low-pass filter to the interpolated image using convolution. The full implementation is:\n", 91 | "```python\n", 92 | "def lpf_sr_single(img: torch.Tensor, scale: int, omega=3.):\n", 93 | " \"\"\"\n", 94 | " Interpolate an image using the sinc function, it's slower than the cubic or others.\n", 95 | "\n", 96 | " :param img: the image to be interpolated.\n", 97 | " :param size: the expected size\n", 98 | " :param omega: the factor to adjust the scale of the sinc function\n", 99 | " :return: the interpolated image\n", 100 | " \"\"\"\n", 101 | " img_pad = F.pad(input=img,\n", 102 | " pad=(img.shape[2] // 2, img.shape[2] // 2, img.shape[3] // 2, img.shape[3] // 2),\n", 103 | " mode=\"reflect\")\n", 104 | " target = zero_interpolate_torch(img_pad, scale) # zero interpolate to the target size\n", 105 | " h_grid = torch.linspace(-1, 1, (img.shape[2] // 2) * scale * 2 + 1)\n", 106 | " w_grid = torch.linspace(-1, 1, (img.shape[3] // 2) * scale * 2 + 1)\n", 107 | " kernel = torch.meshgrid([h_grid, w_grid], indexing='xy')\n", 108 | "\n", 109 | " kernel = sinc(kernel[0], omega) * sinc(kernel[1], omega) # generate the low-passfilter, the sinc function with parameter omega\n", 110 | " kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(img.device)\n", 111 | " # low-pass filtering, since the sinc function is symmetric, we 
can directly utilize the torch.nn.functional.conv2d\n", 112 | " target = F.conv2d(input=target, weight=kernel, stride=1, padding=\"valid\") \n", 113 | " for i in range(target.shape[0]):\n", 114 | " if torch.max(img[i])>1: # to avoid a all 0 image\n", 115 | " target[i] = (target[i] - torch.min(target[i]))/(torch.max(target[i])-torch.min(target[i])) * (torch.max(img[i])-torch.min(img[i])) + torch.min(img[i])\n", 116 | " return target\n", 117 | "```\n", 118 | "In the code above, the sinc function is defined as:\n", 119 | "```python\n", 120 | "def sinc(tensor, omega):\n", 121 | " \"\"\"\n", 122 | " The sinc function implementation. sinc(t) is defined as sin(pi*t)/(pi*t), omega is a\n", 123 | " factor to adjust the scale\n", 124 | " :param tensor: variants of sinc function\n", 125 | " :param omega: scale factor\n", 126 | " :return:\n", 127 | " \"\"\"\n", 128 | " return torch.sin(torch.abs(math.pi * tensor * omega) + 1e-9) / (torch.abs(math.pi * tensor * omega) + 1e-9)\n", 129 | "```" 130 | ] 131 | } 132 | ], 133 | "metadata": { 134 | "kernelspec": { 135 | "display_name": "Python 3 (ipykernel)", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | "language_info": { 140 | "codemirror_mode": { 141 | "name": "ipython", 142 | "version": 3 143 | }, 144 | "file_extension": ".py", 145 | "mimetype": "text/x-python", 146 | "name": "python", 147 | "nbconvert_exporter": "python", 148 | "pygments_lexer": "ipython3", 149 | "version": "3.8.16" 150 | } 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 5 154 | } 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploring the Low-Pass Filtering Behavior in Image Super-Resolution 2 | 3 | [![Static Badge](https://img.shields.io/badge/ICML_2024-Accepted-green)](https://icml.cc/virtual/2024/poster/35191) [![Static Badge](https://img.shields.io/badge/arXiv-2405.07919-brown?logo=arxiv) 4 | 
](https://arxiv.org/abs/2405.07919)![](https://img.shields.io/badge/Code_with-PyTorch-orange?logo=pytorch) 5 | 6 | Haoyu Deng, Zijing Xu, Yule Duan, Xiao Wu, Wenjie Shu, Liang-Jian Deng 7 | 8 | 9 | Corresponding author 10 | 11 | 12 | If you have any questions, feel free to raise an issue or send an email to academic@hydeng.cn. I will respond to you as soon as possible. If you find our work useful, please consider citing it: 13 | ``` 14 | @article{deng2024exploring, 15 | title={Exploring the Low-Pass Filtering Behavior in Image Super-Resolution}, 16 | author={Deng, Haoyu and Xu, Zijing and Duan, Yule and Wu, Xiao and Shu, Wenjie and Deng, Liang-Jian}, 17 | journal={arXiv preprint arXiv:2405.07919}, 18 | year={2024} 19 | } 20 | ``` 21 | 22 | ## TODOs 23 | - [x] matlab implementation of FSDS 24 | - [ ] make a video 25 | - [ ] try to use FSDS as a loss and share results 26 | 27 | ## Impulse Responses 28 | Please refer to `Impulse_Responses.pptx`. We provide two versions, with/without enhancement. We use the enhancement provided by PowerPoint for better visualization. The enhanced figures are brighter in color, while the unenhanced ones are mathematically closer to the sinc function. If you directly observe the output of networks (before clamping), you will find another feature of the sinc function: negative values near the main lobe. 29 | 30 | ## Hybrid Response Analysis (HyRA) 31 | We provide a script, `HyRA.py`, to directly analyze a network using HyRA. Usage: 32 | 33 | ```bash 34 | python HyRA.py --lr [path to low-resolution image] --sr [path to super-resolution image, namely N(I) in the paper] --scale [the scale of super resolution] --impulse_response [path to impulse response] --save_path [path to save results] 35 | ``` 36 | 37 | We also provide a tutorial for the code. Please refer to `HyRA Usage.ipynb`.
38 | 39 | ## Frequency Spectrum Distribution Similarity (FSDS) 40 | ### Python Version 41 | FSDS measures image quality from the perspective of the frequency spectrum. The complete implementation of FSDS is in `FSDS_code.py`. We provide an explanation and a tutorial for the code. Please refer to `FSDS_explanation.ipynb`. 42 | 43 | ### Matlab Version 44 | To ensure that the MATLAB version produces the **same result** as the Python version, it invokes the Python code through MATLAB's Python interface (FFI). Please edit the Python interpreter path in `FrequencySpectrumDistributionSimilarity.m`, and make sure that interpreter's environment has the required packages (e.g. `torch`, `einops`, `numpy`) installed. An example can be found in `FSDS_matlab_example.m`. 45 | ## Experimental Results 46 | We provide super-resolution results and code for Tab.1. 47 | 48 | The super-resolution results can be found in: [https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip](https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip) 49 | 50 | Code that generates Tab.1 is in `cal_metrcis_and_table.py`. Please unzip `sr_results.zip` to the `sr_results` directory and run the code.
The `sr_results` directory should look like: 51 | ``` 52 | sr_results 53 | ├─ArbSR 54 | │ ├─RCAN_x12 55 | │ ├─RCAN_x2 56 | │ ├─RCAN_x3 57 | │ ├─RCAN_x4 58 | │ ├─RCAN_x6 59 | │ └─RCAN_x8 60 | ├─Bicubic 61 | │ ├─Bicubic_x12 62 | │ ├─Bicubic_x18 63 | │ ├─Bicubic_x2 64 | │ ├─Bicubic_x3 65 | │ ├─Bicubic_x4 66 | │ └─Bicubic_x6 67 | ├─edsr_baseline 68 | │ ├─x2 69 | │ ├─x3 70 | │ └─x4 71 | ├─GRL 72 | │ ├─base 73 | │ │ ├─X2 74 | │ │ ├─X3 75 | │ │ └─X4 76 | │ ├─small 77 | │ │ ├─X2 78 | │ │ ├─X3 79 | │ │ └─X4 80 | │ └─tiny 81 | │ ├─X2 82 | │ ├─X3 83 | │ └─X4 84 | ├─HAT 85 | │ ├─HAT-S_SRx2 86 | │ ├─HAT-S_SRx3 87 | │ ├─HAT-S_SRx4 88 | │ ├─HAT_SRx2 89 | │ ├─HAT_SRx3 90 | │ └─HAT_SRx4 91 | ├─HDSRNet 92 | │ ├─X2 93 | │ ├─X3 94 | │ └─X4 95 | ├─hr 96 | ├─ITSRN 97 | │ ├─ITSRN_x12 98 | │ ├─ITSRN_x2 99 | │ ├─ITSRN_x3 100 | │ ├─ITSRN_x4 101 | │ └─ITSRN_x6 102 | ├─liif 103 | │ ├─edsr_x12 104 | │ ├─edsr_x18 105 | │ ├─edsr_x2 106 | │ ├─edsr_x3 107 | │ ├─edsr_x4 108 | │ ├─edsr_x6 109 | │ ├─rdn_x12 110 | │ ├─rdn_x18 111 | │ ├─rdn_x2 112 | │ ├─rdn_x3 113 | │ ├─rdn_x4 114 | │ └─rdn_x6 115 | ├─LTE 116 | │ ├─EDSR_baseline_x12 117 | │ ├─EDSR_baseline_x2 118 | │ ├─EDSR_baseline_x3 119 | │ ├─EDSR_baseline_x4 120 | │ ├─EDSR_baseline_x6 121 | │ ├─RDN_x12 122 | │ ├─RDN_x2 123 | │ ├─RDN_x3 124 | │ ├─RDN_x4 125 | │ ├─RDN_x6 126 | │ ├─SwinIR_x12 127 | │ ├─SwinIR_x2 128 | │ ├─SwinIR_x3 129 | │ ├─SwinIR_x4 130 | │ └─SwinIR_x6 131 | ├─OPESR 132 | │ ├─EDSR_x2 133 | │ ├─EDSR_x3 134 | │ ├─EDSR_x4 135 | │ ├─RDN_x2 136 | │ ├─RDN_x3 137 | │ └─RDN_x4 138 | ├─RDN 139 | │ ├─RDN_small_x2 140 | │ ├─RDN_small_x3 141 | │ └─RDN_small_x4 142 | ├─SRNO 143 | │ ├─EDSR_baseline_x12 144 | │ ├─EDSR_baseline_x2 145 | │ ├─EDSR_baseline_x3 146 | │ ├─EDSR_baseline_x4 147 | │ └─EDSR_baseline_x6 148 | └─SwinIR 149 | ├─swinir_classical_sr_x2 150 | ├─swinir_classical_sr_x3 151 | ├─swinir_classical_sr_x4 152 | └─swinir_classical_sr_x8 153 | ``` 154 | ## Acknowledgement 155 | We appreciate anonymous reviewers for their previous 
suggestions to help this paper better. Moreover, we would like to express our sincere gratitude to [Ruijie Zhu](https://github.com/ridgerchu) for his generous support in GPUs Without his support, it is hard for us to do experiments using full-scale DIV2K dataset. This work is supported by NSFC (12271083). -------------------------------------------------------------------------------- /cal_metrcis_and_table.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | import os 5 | import torchmetrics 6 | import numpy as np 7 | from PIL import Image 8 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, \ 9 | LearnedPerceptualImagePatchSimilarity 10 | from torchvision.transforms import ToTensor 11 | from tqdm import tqdm 12 | 13 | from FrequencySpectrumDistributionSimilarity import FrequencySpectrumDistributionSimilarity 14 | 15 | minList = ["LPIPS", "L1InFFT", "L1", "L2InFFT", "L2"] 16 | max_dict = {} 17 | color_table = ["red", "blue"] 18 | 19 | 20 | class TableRow: 21 | def __init__(self, paper, overall_scales, overall_metrics, citation=None): 22 | self.data = {} 23 | self.paper = paper 24 | self.citation = citation 25 | self.overall_scales = overall_scales 26 | self.overall_metrics = overall_metrics 27 | 28 | def addMetric(self, metric, value, scale): 29 | if metric not in self.data.keys(): 30 | self.data[metric] = {} 31 | if scale not in self.data[metric].keys(): 32 | self.data[metric][scale] = [] 33 | self.data[metric][scale].append(value) 34 | 35 | def __str__(self): 36 | out = f"{self.paper}" 37 | if self.citation is not None: 38 | out += f"{self.citation}" 39 | 40 | for metric in self.overall_metrics: 41 | if metric not in self.data.keys(): 42 | print("ERROR") 43 | return 44 | for scale in self.overall_scales: 45 | if scale in self.data[metric].keys(): 46 | mean = np.mean(self.data[metric][scale]) 47 | mark = False 48 | index = 
max_dict[metric][scale].index(mean) 49 | # for i in range(len(color_table)): 50 | # if mean == max_dict[metric][scale][i]: 51 | # if mean < 0.01: 52 | # out += f"&\\textcolor{{ {color_table[i]} }} {{ {mean:.3e}({index+1}) }} " 53 | # else: 54 | # out += f"&\\textcolor{{{color_table[i]}}}{{{mean:.3f}({index+1}) }}" 55 | # mark = True 56 | # if not mark: 57 | # if mean < 0.01: 58 | # out += f"&{mean:.3e}({index+1}) " 59 | # else: 60 | # out += f"&{mean:.3f}({index+1}) " 61 | for i in range(len(color_table)): 62 | if mean == max_dict[metric][scale][i]: 63 | if metric!="LPIPS" and metric!="SSIM": 64 | if mean < 0.01: 65 | out += f"&\\textcolor{{{color_table[i]}}}{{{mean:.2e}}}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 66 | else: 67 | out += f"&\\textcolor{{{color_table[i]}}}{{{mean:.2f}}}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 68 | mark = True 69 | else: 70 | if mean < 0.01: 71 | out += f"&\\textcolor{{{color_table[i]}}}{{{mean:.3f}}}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 72 | else: 73 | out += f"&\\textcolor{{{color_table[i]}}}{{{mean:.3f}}}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 74 | mark = True 75 | if not mark: 76 | if metric!="LPIPS" and metric!="SSIM": 77 | if mean < 0.01: 78 | out += f"&{mean:.2e}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 79 | else: 80 | out += f"&{mean:.2f}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 81 | else: 82 | if mean < 0.01: 83 | out += f"&{mean:.3f}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 84 | else: 85 | out += f"&{mean:.3f}\\textcolor{{gray}}{{\\textsuperscript{{{index + 1}}}}}" 86 | 87 | else: 88 | out += "&-" 89 | return out + "\\\\" 90 | 91 | 92 | def l1_in_fft(gt, img): 93 | diff = (torch.fft.fftn(gt) - torch.fft.fftn(img)).abs() 94 | return (diff.sum() / diff.numel()).item() 95 | 96 | 97 | def l2_in_fft(gt, img): 98 | diff = (torch.fft.fftn(gt) - torch.fft.fftn(img)).abs() 99 | return (torch.sum(diff ** 2) / 
diff.numel()).item() 100 | 101 | 102 | def psnr_in_fft(gt, img): 103 | gt_fft = torch.fft.fftn(gt) 104 | img_fft = torch.fft.fftn(img) 105 | _psnr = PeakSignalNoiseRatio(data_range=torch.max(gt_fft.abs()).item()).cuda() 106 | return _psnr(gt_fft.abs(), img_fft.abs()).item() 107 | 108 | def psnr(gt, img): 109 | _psnr = PeakSignalNoiseRatio(data_range=1).cuda() 110 | return _psnr(gt, img).item() 111 | 112 | def l1(gt, img): 113 | return torch.nn.functional.l1_loss(gt, img).item() 114 | 115 | def l2(gt, img): 116 | return torch.nn.functional.mse_loss(gt, img).item() 117 | convertor = ToTensor() 118 | psnr = PeakSignalNoiseRatio(data_range=1).cuda() 119 | ssim = StructuralSimilarityIndexMeasure(data_range=1).cuda() 120 | fsds = FrequencySpectrumDistributionSimilarity 121 | lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').cuda() 122 | l1 = torch.nn.functional.l1_loss 123 | l2 = torch.nn.functional.mse_loss 124 | row_item = { 125 | "EDSR": { 126 | "paper": "EDSR", 127 | "citation": "\citep{edsr}", 128 | "overall_scales": [2, 3, 4], 129 | "scale_path": { 130 | 2: "./sr_results/edsr_baseline/x2", 131 | 3: "./sr_results/edsr_baseline/x3", 132 | 4: "./sr_results/edsr_baseline/x4" 133 | }, 134 | }, 135 | 136 | "EDSR-LIIF": { 137 | "paper": "EDSR-LIIF", 138 | "citation": "\citep{liif}", 139 | "overall_scales": [2, 3, 4, 6, 12], 140 | "scale_path": { 141 | 2: "./sr_results/liif/edsr_x2", 142 | 3: "./sr_results/liif/edsr_x3", 143 | 4: "./sr_results/liif/edsr_x4", 144 | 6: "./sr_results/liif/edsr_x6", 145 | 12: "./sr_results/liif/edsr_x12" 146 | }, 147 | }, 148 | 149 | "EDSR-OPESR": { 150 | "paper": "EDSR-OPESR", 151 | "citation": "\citep{OPESR}", 152 | "overall_scales": [2, 3, 4], 153 | "scale_path": { 154 | 2: "./sr_results/OPESR/EDSR_x2", 155 | 3: "./sr_results/OPESR/EDSR_x3", 156 | 4: "./sr_results/OPESR/EDSR_x4" 157 | }, 158 | }, 159 | 160 | "EDSR-SRNO": { 161 | "paper": "EDSR-SRNO", 162 | "citation": "\citep{SRNO}", 163 | "overall_scales": [2, 3, 4, 6, 12], 
164 | "scale_path": { 165 | 2: "./sr_results/SRNO/EDSR_baseline_x2", 166 | 3: "./sr_results/SRNO/EDSR_baseline_x3", 167 | 4: "./sr_results/SRNO/EDSR_baseline_x4", 168 | 6: "./sr_results/SRNO/EDSR_baseline_x6", 169 | 12: "./sr_results/SRNO/EDSR_baseline_x12" 170 | }, 171 | }, 172 | 173 | "EDSR-LTE": { 174 | "paper": "EDSR-LTE", 175 | "citation": "\citep{LTE}", 176 | "overall_scales": [2, 3, 4, ], 177 | "scale_path": { 178 | 2: "./sr_results/LTE/EDSR_baseline_x2", 179 | 3: "./sr_results/LTE/EDSR_baseline_x3", 180 | 4: "./sr_results/LTE/EDSR_baseline_x4" 181 | }, 182 | }, 183 | 184 | "RDN": { 185 | "paper": "RDN", 186 | "citation": "\citep{rdn}", 187 | "overall_scales": [2, 3, 4], 188 | "scale_path": { 189 | 2: "./sr_results/RDN/RDN_small_x2", 190 | 3: "./sr_results/RDN/RDN_small_x3", 191 | 4: "./sr_results/RDN/RDN_small_x4" 192 | }, 193 | }, 194 | 195 | "RDN-LIIF": { 196 | "paper": "RDN-LIIF", 197 | "citation": "\citep{liif}", 198 | "overall_scales": [2, 3, 4, 6, 12], 199 | "scale_path": { 200 | 2: "./sr_results/liif/rdn_x2", 201 | 3: "./sr_results/liif/rdn_x3", 202 | 4: "./sr_results/liif/rdn_x4", 203 | 6: "./sr_results/liif/rdn_x6", 204 | 12: "./sr_results/liif/rdn_x12" 205 | }, 206 | 207 | }, 208 | 209 | "RDN-OPESR": { 210 | "paper": "RDN-OPESR", 211 | "citation": "\citep{OPESR}", 212 | "overall_scales": [2, 3, 4], 213 | "scale_path": { 214 | 2: "./sr_results/OPESR/RDN_x2", 215 | 3: "./sr_results/OPESR/RDN_x3", 216 | 4: "./sr_results/OPESR/RDN_x4" 217 | }, 218 | }, 219 | 220 | "RDN-LTE": { 221 | "paper": "RDN-LTE", 222 | "citation": "\citep{LTE}", 223 | "overall_scales": [2, 3, 4, 6, 12], 224 | "scale_path": { 225 | 2: "./sr_results/LTE/RDN_x2", 226 | 3: "./sr_results/LTE/RDN_x3", 227 | 4: "./sr_results/LTE/RDN_x4", 228 | 6: "./sr_results/LTE/RDN_x6", 229 | 12: "./sr_results/LTE/RDN_x12" 230 | }, 231 | }, 232 | 233 | "SwinIR-classical": { 234 | "paper": "SwinIR-classical", 235 | "citation": "\citep{swinir}", 236 | "overall_scales": [2, 3, 4], 237 | "scale_path": { 
238 | 2: "./sr_results/SwinIR/swinir_classical_sr_x2", 239 | 3: "./sr_results/SwinIR/swinir_classical_sr_x3", 240 | 4: "./sr_results/SwinIR/swinir_classical_sr_x4" 241 | }, 242 | }, 243 | 244 | "ITSRN": { 245 | "paper": "ITSRN", 246 | "citation": "\citep{ITSRN}", 247 | "overall_scales": [2, 3, 4, 6, 12], 248 | "scale_path": { 249 | 2: "./sr_results/ITSRN/ITSRN_x2", 250 | 3: "./sr_results/ITSRN/ITSRN_x3", 251 | 4: "./sr_results/ITSRN/ITSRN_x4", 252 | 6: "./sr_results/ITSRN/ITSRN_x6", 253 | 12: "./sr_results/ITSRN/ITSRN_x12" 254 | }, 255 | }, 256 | 257 | "Bicubic": { 258 | "paper": "Bicubic", 259 | "citation": "", 260 | "overall_scales": [2, 3, 4, 6, 12], 261 | "scale_path": { 262 | 2: "./sr_results/Bicubic/Bicubic_x2", 263 | 3: "./sr_results/Bicubic/Bicubic_x3", 264 | 4: "./sr_results/Bicubic/Bicubic_x4", 265 | 6: "./sr_results/Bicubic/Bicubic_x6", 266 | 12: "./sr_results/Bicubic/Bicubic_x12" 267 | }, 268 | }, 269 | 270 | "HAT-S": { 271 | "paper": "HAT-S", 272 | "citation": "\citep{hat}", 273 | "overall_scales": [2, 3, 4], 274 | "scale_path": { 275 | 2: "./sr_results/HAT/HAT-S_SRx2", 276 | 3: "./sr_results/HAT/HAT-S_SRx3", 277 | 4: "./sr_results/HAT/HAT-S_SRx4" 278 | }, 279 | }, 280 | 281 | "HAT": { 282 | "paper": "HAT", 283 | "citation": "\citep{hat}", 284 | "overall_scales": [2, 3, 4], 285 | "scale_path": { 286 | 2: "./sr_results/HAT/HAT_SRx2", 287 | 3: "./sr_results/HAT/HAT_SRx3", 288 | 4: "./sr_results/HAT/HAT_SRx4" 289 | }, 290 | }, 291 | 292 | "HDSRNet": { 293 | "paper": "HDSRNet", 294 | "citation": "\citep{hdsrnet}", 295 | "overall_scales": [2, 3, 4], 296 | "scale_path": { 297 | 2: "./sr_results/HDSRNet/X2", 298 | 3: "./sr_results/HDSRNet/X3", 299 | 4: "./sr_results/HDSRNet/X4" 300 | }, 301 | }, 302 | 303 | "GRLBase": { 304 | "paper": "GRLBase", 305 | "citation": "\citep{grl}", 306 | "overall_scales": [2, 3, 4], 307 | "scale_path": { 308 | 2: "./sr_results/GRL/base/X2", 309 | 3: "./sr_results/GRL/base/X3", 310 | 4: "./sr_results/GRL/base/X4" 311 | }, 312 | }, 
313 | 314 | "GRLSmall": { 315 | "paper": "GRLSmal", 316 | "citation": "\citep{grl}", 317 | "overall_scales": [2, 3, 4], 318 | "scale_path": { 319 | 2: "./sr_results/GRL/small/X2", 320 | 3: "./sr_results/GRL/small/X3", 321 | 4: "./sr_results/GRL/small/X4" 322 | }, 323 | }, 324 | 325 | "GRLTiny": { 326 | "paper": "GRLTiny", 327 | "citation": "\citep{grl}", 328 | "overall_scales": [2, 3, 4], 329 | "scale_path": { 330 | 2: "./sr_results/GRL/tiny/X2", 331 | 3: "./sr_results/GRL/tiny/X3", 332 | 4: "./sr_results/GRL/tiny/X4" 333 | }, 334 | } 335 | } 336 | 337 | 338 | def getGTFileName(name): 339 | return name[0:4] + ".png" 340 | 341 | 342 | total_img = 0 343 | 344 | 345 | def validate_data(dic): 346 | for key in dic.keys(): 347 | for scale in dic[key]["overall_scales"]: 348 | path = dic[key]["scale_path"][scale] 349 | for file in os.listdir(path): 350 | img = Image.open(os.path.join(path, file)) 351 | gt = Image.open(os.path.join("./sr_results/hr", getGTFileName(file))) 352 | global total_img 353 | total_img += 1 354 | if img.size != gt.size: 355 | print(f"Size mismatch for {file}") 356 | return False 357 | return True 358 | 359 | 360 | 361 | with torch.no_grad(): 362 | validate_data(row_item) 363 | print(f"total:{total_img}") 364 | print("validate ok") 365 | pbar = tqdm(total=total_img) 366 | rows = [] 367 | gts = {} 368 | for file in os.listdir("./sr_results/hr"): 369 | gts[file] = convertor(Image.open(os.path.join("./sr_results/hr", file))).cuda() 370 | for net in row_item.keys(): 371 | row_item[net]["metrics"] = {} 372 | for scale in row_item[net]["overall_scales"]: 373 | path = row_item[net]["scale_path"][scale] 374 | row_item[net]["metrics"][scale] = {} 375 | for file in os.listdir(path): 376 | img = convertor(Image.open(os.path.join(path, file))).cuda() 377 | gt = gts[getGTFileName(file)] 378 | row_item[net]["metrics"][scale][file] = {"PSNR": psnr(gt.unsqueeze(0), img.unsqueeze(0)).item(), 379 | "SSIM": ssim(gt.unsqueeze(0), img.unsqueeze(0)).item(), 380 | "LPIPS": 
lpips(gt.unsqueeze(0), img.unsqueeze(0)).item(), 381 | "FSDS": fsds(gt, img), 382 | "L1InFFT": l1_in_fft(gt, img), 383 | "L1": l1(gt, img).item(), 384 | "L2InFFT": l2_in_fft(gt, img), 385 | "L2": l2(gt, img, reduction="sum").item(), 386 | "PSNRInFFT": psnr_in_fft(gt, img)} 387 | 388 | pbar.update(1) 389 | 390 | with open("div2k_metrics_icml_rebuttal.json", "w") as f: 391 | json.dump(row_item, f, indent=4) 392 | 393 | json_text = "" 394 | with open("div2k_metrics_icml_rebuttal.json", "r") as f: 395 | json_text = f.read() 396 | dic = json.loads(json_text) 397 | row_item = dic 398 | rows = [] 399 | # all_metrics = ["PSNR", "SSIM", "LPIPS", "FSDS", "L1InFFT", "L1", "L2InFFT", "L2", "PSNRInFFT"] 400 | # all_metrics = ["FSDS", "L1InFFT", "L1", "L2InFFT", "L2", "PSNRInFFT"] 401 | # all_metrics = ["FSDS", "PSNRInFFT"] 402 | all_metrics = ["PSNR", "SSIM", "LPIPS", "FSDS"] 403 | for key in row_item.keys(): 404 | row = TableRow(row_item[key]["paper"], [2, 3, 4, 6, 12], all_metrics,None) 405 | for scale in row_item[key]["overall_scales"]: 406 | for item in row_item[key]["metrics"][str(scale)]: 407 | row.addMetric("PSNR", row_item[key]["metrics"][str(scale)][item]["PSNR"], scale) 408 | row.addMetric("SSIM", row_item[key]["metrics"][str(scale)][item]["SSIM"], scale) 409 | row.addMetric("LPIPS", row_item[key]["metrics"][str(scale)][item]["LPIPS"], scale) 410 | row.addMetric("FSDS", row_item[key]["metrics"][str(scale)][item]["FSDS"], scale) 411 | # row.addMetric("PSNRInFFT", row_item[key]["metrics"][str(scale)][item]["PSNRInFFT"], scale) 412 | # row.addMetric("PSNR", row_item[key]["metrics"][str(scale)][item]["PSNR"], scale) 413 | # row.addMetric("L1InFFT", row_item[key]["metrics"][str(scale)][item]["L1InFFT"], scale) 414 | # row.addMetric("L1", row_item[key]["metrics"][str(scale)][item]["L1"], scale) 415 | # row.addMetric("L2InFFT", row_item[key]["metrics"][str(scale)][item]["L2InFFT"], scale) 416 | # row.addMetric("L2", row_item[key]["metrics"][str(scale)][item]["L2"], scale) 
417 | # row.addMetric("PSNRInFFT", row_item[key]["metrics"][str(scale)][item]["PSNRInFFT"], scale) 418 | 419 | rows.append(row) 420 | 421 | for metrics in all_metrics: 422 | max_dict[metrics] = {} 423 | for scale in [2, 3, 4, 6, 12]: 424 | if scale not in max_dict[metrics].keys(): 425 | max_dict[metrics][scale] = [] 426 | for row in rows: 427 | if scale in row.data[metrics].keys(): 428 | max_dict[metrics][scale].append(np.mean(row.data[metrics][scale])) 429 | if metrics not in minList: 430 | max_dict[metrics][scale].sort(reverse=True) 431 | else: 432 | max_dict[metrics][scale].sort() 433 | 434 | for row in rows: 435 | print(row) -------------------------------------------------------------------------------- /example_figures/EDSRx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/EDSRx4.png -------------------------------------------------------------------------------- /example_figures/baby.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/baby.png -------------------------------------------------------------------------------- /example_figures/baby_x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/baby_x2.png -------------------------------------------------------------------------------- /example_figures/gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/gt.png -------------------------------------------------------------------------------- 
/example_figures/impulse_response.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/impulse_response.png -------------------------------------------------------------------------------- /example_figures/lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/lr.png -------------------------------------------------------------------------------- /example_figures/sr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/sr.png -------------------------------------------------------------------------------- /sr_results/readme.md: -------------------------------------------------------------------------------- 1 | Please download super-resolution results from [https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip](https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip) and unzip it. 
2 | This directory should be like: 3 | ``` 4 | ├─ArbSR 5 | │ ├─RCAN_x12 6 | │ ├─RCAN_x2 7 | │ ├─RCAN_x3 8 | │ ├─RCAN_x4 9 | │ ├─RCAN_x6 10 | │ └─RCAN_x8 11 | ├─Bicubic 12 | │ ├─Bicubic_x12 13 | │ ├─Bicubic_x18 14 | │ ├─Bicubic_x2 15 | │ ├─Bicubic_x3 16 | │ ├─Bicubic_x4 17 | │ └─Bicubic_x6 18 | ├─edsr_baseline 19 | │ ├─x2 20 | │ ├─x3 21 | │ └─x4 22 | ├─GRL 23 | │ ├─base 24 | │ │ ├─X2 25 | │ │ ├─X3 26 | │ │ └─X4 27 | │ ├─small 28 | │ │ ├─X2 29 | │ │ ├─X3 30 | │ │ └─X4 31 | │ └─tiny 32 | │ ├─X2 33 | │ ├─X3 34 | │ └─X4 35 | ├─HAT 36 | │ ├─HAT-S_SRx2 37 | │ ├─HAT-S_SRx3 38 | │ ├─HAT-S_SRx4 39 | │ ├─HAT_SRx2 40 | │ ├─HAT_SRx3 41 | │ └─HAT_SRx4 42 | ├─HDSRNet 43 | │ ├─X2 44 | │ ├─X3 45 | │ └─X4 46 | ├─hr 47 | ├─ITSRN 48 | │ ├─ITSRN_x12 49 | │ ├─ITSRN_x2 50 | │ ├─ITSRN_x3 51 | │ ├─ITSRN_x4 52 | │ └─ITSRN_x6 53 | ├─liif 54 | │ ├─edsr_x12 55 | │ ├─edsr_x18 56 | │ ├─edsr_x2 57 | │ ├─edsr_x3 58 | │ ├─edsr_x4 59 | │ ├─edsr_x6 60 | │ ├─rdn_x12 61 | │ ├─rdn_x18 62 | │ ├─rdn_x2 63 | │ ├─rdn_x3 64 | │ ├─rdn_x4 65 | │ └─rdn_x6 66 | ├─LTE 67 | │ ├─EDSR_baseline_x12 68 | │ ├─EDSR_baseline_x2 69 | │ ├─EDSR_baseline_x3 70 | │ ├─EDSR_baseline_x4 71 | │ ├─EDSR_baseline_x6 72 | │ ├─RDN_x12 73 | │ ├─RDN_x2 74 | │ ├─RDN_x3 75 | │ ├─RDN_x4 76 | │ ├─RDN_x6 77 | │ ├─SwinIR_x12 78 | │ ├─SwinIR_x2 79 | │ ├─SwinIR_x3 80 | │ ├─SwinIR_x4 81 | │ └─SwinIR_x6 82 | ├─OPESR 83 | │ ├─EDSR_x2 84 | │ ├─EDSR_x3 85 | │ ├─EDSR_x4 86 | │ ├─RDN_x2 87 | │ ├─RDN_x3 88 | │ └─RDN_x4 89 | ├─RDN 90 | │ ├─RDN_small_x2 91 | │ ├─RDN_small_x3 92 | │ └─RDN_small_x4 93 | ├─SRNO 94 | │ ├─EDSR_baseline_x12 95 | │ ├─EDSR_baseline_x2 96 | │ ├─EDSR_baseline_x3 97 | │ ├─EDSR_baseline_x4 98 | │ └─EDSR_baseline_x6 99 | └─SwinIR 100 | ├─swinir_classical_sr_x2 101 | ├─swinir_classical_sr_x3 102 | ├─swinir_classical_sr_x4 103 | └─swinir_classical_sr_x8 104 | ``` -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 
import os.path
import warnings
from datetime import datetime

try:
    # torch2np no longer needs einops (it uses Tensor.permute); the import is kept so the
    # `einops` name still resolves for code that reached it through this module, but the
    # package is no longer a hard requirement of utils.
    import einops
except ImportError:
    einops = None
import torch
import math
import torch.nn.functional as F


def sinc(tensor, omega):
    """
    Element-wise sinc: sin(pi * omega * t) / (pi * omega * t).

    1e-9 is added to |pi * omega * t| in both numerator and denominator so that
    t == 0 evaluates to ~1 instead of 0/0 (sin(|x|)/|x| equals sin(x)/x).
    :param tensor: points t at which the function is evaluated
    :param omega: scale factor applied to t
    :return: tensor with the same shape as `tensor`
    """
    arg = torch.abs(math.pi * tensor * omega) + 1e-9
    return torch.sin(arg) / arg


def nearest_odd(num):
    """Return `num` if it is odd, otherwise `num + 1`."""
    return num + 1 if num % 2 == 0 else num


def zero_interpolate_torch(img: torch.Tensor, scale: int):
    """
    Upsample by zero insertion: every input pixel is placed at the top-left corner of a
    `scale` x `scale` cell of the output and all other positions are 0.

    :param img: NxCxHxW tensor, or an unbatched CxHxW tensor
    :param scale: integer upsampling factor
    :return: tensor of shape Nx C x (H*scale) x (W*scale), with the same dtype and device
        as `img`. A leading dimension of size 1 is squeezed away, which also happens
        for batched input with N == 1.
    """
    if len(img.shape) != 4:  # unbatched: add a batch dimension
        img = img.unsqueeze(dim=0)
    height, width = img.shape[2], img.shape[3]
    flat = img.reshape(-1, 1, height, width)
    # new_zeros matches the input's dtype and device; torch.zeros(...).to(device) was always float32.
    zeros = flat.new_zeros(flat.shape[0], scale * scale - 1, height, width)
    stacked = torch.concat([flat, zeros], dim=1)
    # pixel_shuffle maps channel 0 of each scale*scale group to offset (0, 0) of its cell.
    return F.pixel_shuffle(stacked, scale).reshape(img.shape[0], img.shape[1], height * scale,
                                                   width * scale).squeeze(dim=0)


def lpf_sr_single(img: torch.Tensor, scale: int, omega=3., rgb_range = 255):
    """
    Upsample single-channel images with zero insertion followed by a separable sinc
    low-pass filter. Slower than bicubic or similar interpolators.

    :param img: Bx1xHxW tensor (lpf_sr reshapes its input to this layout)
    :param scale: integer upsampling factor
    :param omega: scale factor passed to `sinc`
    :param rgb_range: unused; kept for API compatibility
    :return: Bx1x(H*scale)x(W*scale) tensor; for B == 1 the batch dimension is dropped
        (see zero_interpolate_torch), giving 1x(H*scale)x(W*scale)
    """
    # F.pad takes (left, right, top, bottom), i.e. the last (W) dimension first, so W is
    # padded by H//2 and H by W//2. The kernel below is built with indexing='xy', giving it
    # shape (len(w_grid), len(h_grid)); the two swaps cancel, and the "valid" convolution
    # returns exactly (H*scale, W*scale).
    # NOTE(review): reflect padding requires H//2 < W and W//2 < H, so strongly elongated
    # images raise here. Changing the pad order would also change the kernel, so it is kept.
    img_pad = F.pad(input=img,
                    pad=(img.shape[2] // 2, img.shape[2] // 2, img.shape[3] // 2, img.shape[3] // 2),
                    mode="reflect")
    target = zero_interpolate_torch(img_pad, scale)
    h_grid = torch.linspace(-1, 1, (img.shape[2] // 2) * scale * 2 + 1)
    w_grid = torch.linspace(-1, 1, (img.shape[3] // 2) * scale * 2 + 1)
    kernel = torch.meshgrid([h_grid, w_grid], indexing='xy')

    # Separable 2-D sinc. Both grids span [-1, 1], so the sample spacing (and thus the
    # effective cutoff in pixels) depends on the image size as well as on omega.
    kernel = sinc(kernel[0], omega) * sinc(kernel[1], omega)
    kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(img.device)
    target = F.conv2d(input=target, weight=kernel, stride=1, padding="valid")
    # Rescale every output so that its min/max match those of the corresponding input.
    for i in range(target.shape[0]):
        if torch.max(img[i]) >= 0.01:  # skip (near) all-zero images
            t_min, t_max = torch.min(target[i]), torch.max(target[i])
            i_min, i_max = torch.min(img[i]), torch.max(img[i])
            target[i] = (target[i] - t_min) / (t_max - t_min) * (i_max - i_min) + i_min
    return target


def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """
    PSNR (dB) between `sr` and `hr` after dividing their difference by `rgb_range`.

    If `dataset.dataset.benchmark` is truthy, a border of `scale` pixels is cropped and a
    multi-channel difference is collapsed with the weights (65.738, 129.057, 25.064) / 256
    (presumably the Y/luma channel). Otherwise a border of `scale + 6` pixels is cropped.
    :return: 0 when `hr` has a single element, 100 when the MSE is below 1e-6
    """
    if hr.nelement() == 1:
        return 0

    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    else:
        shave = scale + 6

    # x[s:-s] is empty for s == 0, which made the MSE NaN; only crop a positive border.
    valid = diff[..., shave:-shave, shave:-shave] if shave > 0 else diff
    mse = valid.pow(2).mean()
    if abs(mse) < 1e-6:
        return 100
    return -10 * math.log10(mse)


def psnr(a, b, range=255):
    """PSNR of `a` against `b` with a fixed crop (scale 2, i.e. 8 border pixels).

    `range` shadows the builtin but is kept for keyword-argument compatibility.
    """
    return calc_psnr(a, b, 2, range)


def lpf_sr(img: torch.Tensor, scale: int, omega=3., rgb_range=1):
    """
    Upsample image(s) by `scale` with the sinc low-pass filter of `lpf_sr_single`,
    processing every channel independently. Slower than bicubic or similar interpolators.

    :param img: NxCxHxW or CxHxW tensor; non-contiguous tensors are accepted
    :param scale: integer upsampling factor
    :param omega: scale factor passed to `sinc`
    :param rgb_range: forwarded to `lpf_sr_single` (which does not use it)
    :return: tensor in the input layout with H and W multiplied by `scale`
    """
    origin_shape = img.shape
    if len(origin_shape) == 4:  # batched NxCxHxW
        # reshape (not view) so that non-contiguous inputs, e.g. after transpose, work.
        flat = img.reshape(-1, 1, origin_shape[2], origin_shape[3])
        out = lpf_sr_single(flat, scale, omega, rgb_range=rgb_range)
        return out.reshape(origin_shape[0], origin_shape[1], origin_shape[2] * scale,
                           origin_shape[3] * scale)
    flat = img.reshape(-1, 1, origin_shape[1], origin_shape[2])
    out = lpf_sr_single(flat, scale, omega, rgb_range=rgb_range)
    return out.reshape(origin_shape[0], origin_shape[1] * scale, origin_shape[2] * scale)


def quantize(img, rgb_range):
    """Snap values in [0, rgb_range] to 256 evenly spaced levels, clamping out-of-range values."""
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


class FileLogger:
    """Write timestamped text, metric curves and checkpoints under ExperimentLogs/<exp_name>/."""

    def __init__(self, exp_name):
        log_dir = os.path.join("ExperimentLogs", exp_name)
        os.makedirs(log_dir, exist_ok=True)
        self.exp_name = exp_name
        self.filename = os.path.join(log_dir, "log output.txt")
        self.figure_file = os.path.join(log_dir, "figure_curves.pt")
        self.figures = {}

    def print(self, text, append=True, sync_with_screen=True):
        """Write `text` with a timestamp prefix to the log file and optionally to stdout."""
        # %m is the month; the former "%Y-%M-%d" put the minute into the date field.
        line = datetime.now().strftime("%Y-%m-%d %H:%M:%S--->") + text
        with open(self.filename, "a" if append else "w") as f:
            print(line, file=f)
        if sync_with_screen:
            print(line)

    def log_figure(self, figure_name, figure: float):
        """Append `figure` to the curve `figure_name` and save all curves with torch.save."""
        self.figures.setdefault(figure_name, []).append(figure)
        torch.save(self.figures, self.figure_file)

    def save_model(self, obj, attribute: str = ""):
        """torch.save `obj` to ExperimentLogs/<exp_name>/check_point_<attribute>."""
        torch.save(obj, os.path.join("ExperimentLogs", self.exp_name, f"check_point_{attribute}"))


def torch2np(tensor):
    """Convert a CxHxW tensor to an HxWxC numpy array (detached from autograd, on CPU)."""
    return tensor.detach().permute(1, 2, 0).cpu().numpy()


def min_max_normalization(tensor):
    """Scale `tensor` linearly to [0, 1]; a constant tensor divides by zero."""
    return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor))


def mean_normalization(tensor):
    """Subtract the mean and divide by the value range (max - min)."""
    return (tensor - torch.mean(tensor)) / (torch.max(tensor) - torch.min(tensor))