├── .gitignore
├── FSDS_code.py
├── FSDS_explanation.ipynb
├── FSDS_matlab_example.m
├── FrequencySpectrumDistributionSimilarity.m
├── HyRA Usage.ipynb
├── HyRA.py
├── HyRA_results
├── hr_fft.png
├── linear.png
├── linear_fft.png
├── lr_fft.png
├── non_linear.png
└── non_linear_fft.png
├── Impulse_Responses.pptx
├── LPF2ISR.py
├── Low-pass filter to ISR.ipynb
├── README.md
├── cal_metrcis_and_table.py
├── example_figures
├── EDSRx4.png
├── baby.png
├── baby_x2.png
├── gt.png
├── impulse_response.png
├── lr.png
└── sr.png
├── sr_results
└── readme.md
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/FSDS_code.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import einops
4 | import numpy
def frequency_align_integrate(tsr):
    """
    Split the 2-D real FFT of an image into two frequency-aligned halves and integrate each.

    ``torch.fft.rfft2`` keeps the non-negative frequencies along W. Along H, for an
    even H, rows 0..H/2-1 hold the DC and positive frequencies and rows H/2..H-1 hold
    the Nyquist and negative frequencies. The second half is flipped so both halves
    run from low to high |frequency|. Each half is then cumulatively summed along
    both frequency axes.

    :param tsr: real tensor of shape CxHxW; an odd H is padded with one zero row
    :return: tuple ``(part_1, part_2)`` of complex tensors, each of shape
        C x ceil(H/2) x (W//2 + 1)
    """
    if tsr.shape[1] % 2 == 1:
        # Pad one zero row so H is even. new_zeros keeps the input's dtype and device;
        # a plain torch.zeros would be a CPU float32 tensor and break CUDA inputs.
        tsr = torch.cat([tsr, tsr.new_zeros(tsr.shape[0], 1, tsr.shape[2])], dim=1)

    tsr_rfft = torch.fft.rfft2(tsr, norm="backward")
    # Top / bottom half along H (same as rearrange "C (p H) W -> C p H W" with p=2).
    part_1, part_2 = torch.chunk(tsr_rfft, 2, dim=1)
    part_2 = torch.flip(part_2, dims=[1])

    part_1 = torch.cumsum(torch.cumsum(part_1, dim=1), dim=2)
    part_2 = torch.cumsum(torch.cumsum(part_2, dim=1), dim=2)

    return (part_1, part_2)
28 |
29 |
def _FrequencySpectrumDistributionSimilarity(pred, tar):
    """
    Compute the FSDS score (in dB) between one predicted image and its ground truth.

    Both images are standardized (zero mean, unit standard deviation over all
    elements) and passed through ``frequency_align_integrate``. The score is
    ``-10 * log10(sum |P - T|^2 / sum |T|^2)`` over both integrated halves, so
    higher is better.

    :param pred: the predicted output, in shape CxHxW
    :param tar: the ground-truth data, in shape CxHxW
    :return: the score as a Python float; ``math.inf`` when the integrated spectra
        of ``pred`` and ``tar`` match exactly
    :raises ValueError: if ``pred`` and ``tar`` have different shapes
    """
    if pred.shape != tar.shape:
        raise ValueError("The shape of pred is expected to be the same as tar")

    # Out-of-place subtraction first, so the caller's tensors are never modified.
    pred = pred - torch.mean(pred)
    pred = pred / torch.std(pred)
    tar = tar - torch.mean(tar)
    tar = tar / torch.std(tar)

    pred_part_1, pred_part_2 = frequency_align_integrate(pred)  # C x ceil(H/2) x (W//2+1)
    tar_part_1, tar_part_2 = frequency_align_integrate(tar)  # C x ceil(H/2) x (W//2+1)

    part_1_error = ((pred_part_1 - tar_part_1).abs()) ** 2
    part_2_error = ((pred_part_2 - tar_part_2).abs()) ** 2

    error = torch.sum(part_1_error + part_2_error)
    energy = torch.sum(tar_part_2.abs() ** 2 + tar_part_1.abs() ** 2)
    ratio = (error / energy).item()
    if ratio == 0:
        # Perfect reconstruction: math.log10(0) would raise "math domain error".
        return math.inf
    return -10 * math.log10(ratio)
52 |
53 |
def FrequencySpectrumDistributionSimilarity(pred, gt):
    """
    Frequency Spectrum Distribution Similarity between prediction(s) and ground truth.

    :param pred: predicted image (CxHxW) or batch of images (NxCxHxW)
    :param gt: ground truth with the same shape as ``pred``
    :return: a float for a CxHxW input, or a list of N floats (one per image) for an
        NxCxHxW input
    :raises ValueError: if the shapes differ or the input is neither 3-D nor 4-D
    """
    if pred.shape != gt.shape:
        raise ValueError("The shape of input tensor does not match")
    if len(pred.shape) == 3:
        return _FrequencySpectrumDistributionSimilarity(pred, gt)
    if len(pred.shape) == 4:
        return [_FrequencySpectrumDistributionSimilarity(p, g) for p, g in zip(pred, gt)]
    # Any other rank would otherwise fall through and silently return None.
    raise ValueError(
        f"Expected a CxHxW or NxCxHxW tensor, got a tensor with {len(pred.shape)} dimensions"
    )
64 |
def matlab_interface(pred, gt):
    """
    Entry point used by the MATLAB wrapper ``FrequencySpectrumDistributionSimilarity.m``.

    Converts HxWxC arrays (the layout MATLAB ``imread`` produces) to CxHxW torch
    tensors and returns their FSDS score. HxW grayscale arrays are accepted too;
    they get a singleton channel axis.

    :param pred: predicted image, array-like of shape HxWxC or HxW
    :param gt: ground-truth image with the same shape as ``pred``
    :return: the FSDS score as a float, or ``0.0`` (after printing a message) when the
        inputs have an unsupported rank or mismatched shapes
    """
    pred = numpy.array(pred)
    gt = numpy.array(gt)
    # Grayscale images arrive as HxW; give them a channel axis so they become HxWx1.
    if pred.ndim == 2:
        pred = pred[:, :, numpy.newaxis]
    if gt.ndim == 2:
        gt = gt[:, :, numpy.newaxis]
    if pred.ndim != 3 or gt.ndim != 3:
        print("The input array should be 2D or 3D")
        return 0.0
    if pred.shape != gt.shape:
        print("The input array should have the same shape")
        return 0.0
    # HxWxC -> CxHxW
    pred = torch.from_numpy(pred).permute(2, 0, 1)
    gt = torch.from_numpy(gt).permute(2, 0, 1)
    return FrequencySpectrumDistributionSimilarity(pred, gt)
79 |
--------------------------------------------------------------------------------
/FSDS_matlab_example.m:
--------------------------------------------------------------------------------
% Example: FSDS between the example output EDSRx4.png and the ground truth gt.png.
% fullfile joins the path with the platform's separator, so the example runs on
% Windows, Linux and macOS alike.
close all; clearvars;
pred = imread(fullfile('example_figures', 'EDSRx4.png'));
gt = imread(fullfile('example_figures', 'gt.png'));
FrequencySpectrumDistributionSimilarity(single(pred), single(gt))
--------------------------------------------------------------------------------
/FrequencySpectrumDistributionSimilarity.m:
--------------------------------------------------------------------------------
function scores = FrequencySpectrumDistributionSimilarity(pred, gt)
% FrequencySpectrumDistributionSimilarity  FSDS score computed by the Python code.
%   scores = FrequencySpectrumDistributionSimilarity(pred, gt) converts both
%   images to single precision and calls matlab_interface from FSDS_code.py, so the
%   result is identical to the Python implementation.
%
%   Edit pythonExe below so it points at an interpreter that has torch, einops
%   and numpy installed.

pythonExe = 'C:\Users\RisingEntropy\scoop\apps\miniconda3-py311\current\envs\CV\python.exe';

% pyenv cannot switch interpreters once Python is loaded in this MATLAB session,
% so only select the interpreter while it is still unloaded.
pe = pyenv;
if pe.Status == "NotLoaded"
    pyenv('Version', pythonExe);
end

% Make FSDS_code.py (next to this file) importable regardless of the current folder.
codeDir = fileparts(mfilename('fullpath'));
if count(py.sys.path, codeDir) == 0
    insert(py.sys.path, int32(0), codeDir);
end

module = py.importlib.import_module('FSDS_code');
scores = module.matlab_interface(single(pred), single(gt));
end
6 |
--------------------------------------------------------------------------------
/HyRA.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import torch
5 | from PIL import Image
6 | from einops import einops
7 | from torchvision.transforms import ToTensor
8 | import torch.nn.functional as F
9 | import matplotlib.pyplot as plt
10 | from utils import zero_interpolate_torch
11 | import argparse
12 |
13 |
def normalize(ten: torch.Tensor):
    """
    Min-max normalize a tensor to [0, 1] over all of its elements.

    A constant tensor (max == min) would otherwise divide 0 by 0 and produce NaNs;
    it is mapped to all zeros instead.
    """
    lo = torch.min(ten)
    span = torch.max(ten) - lo
    if span == 0:
        return torch.zeros_like(ten)
    return (ten - lo) / span
16 |
17 |
def convertNp(img: torch.Tensor):
    """
    Convert a CxHxW tensor into an HxWxC numpy array (the layout matplotlib expects).

    The tensor is detached and moved to the CPU first, because ``Tensor.numpy()``
    raises for tensors that require grad or live on a GPU.
    """
    return img.detach().cpu().permute(1, 2, 0).numpy()
20 |
21 |
def get_linear(lr: torch.Tensor, impulse_response: torch.Tensor, scale):
    """
    Linear part of HyRA: zero-interpolate the LR image and filter it with the impulse response.

    :param lr: low-resolution image, CxHxW
    :param impulse_response: per-channel impulse response, C x kH x kW (same C as ``lr``)
    :param scale: integer up-scaling factor
    :return: the filtered image, CxH'xW'. Assuming ``zero_interpolate_torch`` upsamples
        by ``scale``, H' = (H + 2*(kH // (2*scale))) * scale - kH + 1, which equals
        H*scale when kH = 2*m*scale + 1 for some integer m (likewise for W').
    """
    kernel = impulse_response.unsqueeze(dim=1)  # C x 1 x kH x kW: one filter per channel
    pad_h = kernel.shape[2] // (2 * scale)
    pad_w = kernel.shape[3] // (2 * scale)
    # F.pad takes (left, right, top, bottom): the width padding comes first and must be
    # derived from the kernel width, the height padding from the kernel height.
    lr_pad = F.pad(input=lr, pad=(pad_w, pad_w, pad_h, pad_h), mode="reflect")
    lr_inter = zero_interpolate_torch(lr_pad, scale)
    # Depthwise convolution: each channel is filtered only by its own impulse response.
    lr_lp = F.conv2d(input=lr_inter, weight=kernel, stride=1, padding="valid",
                     groups=lr.shape[-3])
    return lr_lp
31 |
32 |
def get_nonlinear(linear: torch.Tensor, sr: torch.Tensor):
    """
    Non-linear part of HyRA: the network output minus its linear response.

    ``sr`` is first resized to the spatial size of ``linear`` with ``F.interpolate``
    (default nearest-neighbour mode).

    :param linear: linear response, CxHxW
    :param sr: network output N(I), CxH'xW'
    :return: ``sr - linear``, CxHxW
    """
    sr = F.interpolate(sr.unsqueeze(dim=0), size=(linear.shape[1], linear.shape[2]))
    # Remove only the batch axis added above; a bare squeeze() would also drop a
    # singleton channel (grayscale) or spatial axis.
    return (sr - linear).squeeze(dim=0)
36 |
37 |
def _load_rgb(path, converter):
    """Load an image file as a 3xHxW tensor, converting RGBA/grayscale/palette images to RGB."""
    with Image.open(path) as image:
        return converter(image.convert("RGB"))


def main():
    """
    Run Hybrid Response Analysis (HyRA) from the command line.

    Splits the network output ``--sr`` into a linear part (the LR image filtered by
    ``--impulse_response``) and a non-linear remainder. Saves both parts, plus the
    log-magnitude spectra of the linear part, the non-linear part, the resized LR
    image and the SR image, as PNGs in ``--save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=str, required=True)
    parser.add_argument("--impulse_response", type=str, required=True)
    parser.add_argument("--sr", type=str, required=True)
    parser.add_argument("--scale", type=int, required=True)
    parser.add_argument("--save_path", type=str, required=True)
    args = parser.parse_args()

    if args.scale <= 0:
        raise ValueError("Scale must be positive")
    # plt.imsave does not create missing directories.
    os.makedirs(args.save_path, exist_ok=True)
    converter = ToTensor()

    # Force RGB so the LR image and the impulse response have the same channel count:
    # get_linear filters the image channel by channel.
    lr = _load_rgb(args.lr, converter)
    impulse_response = _load_rgb(args.impulse_response, converter)
    sr = _load_rgb(args.sr, converter)

    cmp = matplotlib.colormaps["Blues"]
    cmp = cmp.reversed()
    linear = get_linear(lr, impulse_response, args.scale)
    non_linear = get_nonlinear(linear, sr)
    linear_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(linear)).abs() + 1)
    non_linear_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(non_linear)).abs() + 1)
    lr = F.interpolate(lr.unsqueeze(dim=0), size=(linear.shape[1], linear.shape[2]))[0]
    lr_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(lr)).abs() + 1)
    sr_fft = torch.log(torch.fft.fftshift(torch.fft.fft2(sr)).abs() + 1)
    plt.imsave(os.path.join(args.save_path, "linear.png"), convertNp(linear / torch.max(linear)))
    plt.imsave(os.path.join(args.save_path, "non_linear.png"), convertNp(normalize(non_linear)))
    plt.imsave(os.path.join(args.save_path, "linear_fft.png"), linear_fft[0], cmap=cmp)  # display only one channel
    plt.imsave(os.path.join(args.save_path, "non_linear_fft.png"), non_linear_fft[0], cmap=cmp)
    plt.imsave(os.path.join(args.save_path, "lr_fft.png"), lr_fft[0], cmap=cmp)
    plt.imsave(os.path.join(args.save_path, "hr_fft.png"), sr_fft[0], cmap=cmp)
    print("done!")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/HyRA_results/hr_fft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/hr_fft.png
--------------------------------------------------------------------------------
/HyRA_results/linear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/linear.png
--------------------------------------------------------------------------------
/HyRA_results/linear_fft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/linear_fft.png
--------------------------------------------------------------------------------
/HyRA_results/lr_fft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/lr_fft.png
--------------------------------------------------------------------------------
/HyRA_results/non_linear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/non_linear.png
--------------------------------------------------------------------------------
/HyRA_results/non_linear_fft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/HyRA_results/non_linear_fft.png
--------------------------------------------------------------------------------
/Impulse_Responses.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/Impulse_Responses.pptx
--------------------------------------------------------------------------------
/LPF2ISR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn.functional as F
4 |
def sinc(tensor, omega):
    """
    Element-wise scaled sinc: sin(pi * omega * t) / (pi * omega * t).

    The argument is taken in absolute value and offset by 1e-9 in both the numerator
    and the denominator, so t = 0 evaluates to ~1 instead of 0 / 0.

    :param tensor: sample positions t
    :param omega: scale factor applied to t before evaluating the sinc
    :return: tensor with the same shape as ``tensor``
    """
    shifted_arg = torch.abs(math.pi * tensor * omega) + 1e-9
    return torch.sin(shifted_arg) / shifted_arg
14 |
15 |
def nearest_odd(num):
    """Return ``num`` when it is odd, otherwise the next number ``num + 1``."""
    if num % 2 == 0:
        return num + 1
    return num
18 |
19 |
def zero_interpolate_torch(img: torch.Tensor, scale: int):
    """
    Zero-interpolate image(s) by an integer factor.

    Each input pixel is placed at the top-left corner of a ``scale x scale`` block
    and the other positions of the block are filled with zeros.

    :param img: NxCxHxW batch, or a single CxHxW image
    :param scale: integer up-sampling factor
    :return: tensor of the same rank and dtype as ``img`` with H and W multiplied by
        ``scale``; a batch with N == 1 keeps its batch dimension
    """
    unbatched = len(img.shape) != 4
    if unbatched:  # single CxHxW image: add a batch dimension
        img = img.unsqueeze(dim=0)
    n, c, h, w = img.shape
    # Each image plane becomes channel 0, followed by scale*scale - 1 zero planes;
    # pixel_shuffle maps channel 0 to offset (0, 0) of every scale x scale block.
    planes = img.reshape(n * c, 1, h, w)
    zeros = planes.new_zeros(n * c, scale * scale - 1, h, w)  # same dtype/device as img
    shuffled = F.pixel_shuffle(torch.cat([planes, zeros], dim=1), scale)
    out = shuffled.reshape(n, c, h * scale, w * scale)
    # Only remove the batch axis added above, never a real batch of size 1.
    return out.squeeze(dim=0) if unbatched else out
35 |
36 |
def lpf_sr_single(img: torch.Tensor, scale: int, omega=3.):
    """
    Super-resolve single-channel images with a sinc (ideal low-pass) filter.

    The input is reflect-padded and zero-interpolated by ``scale``. It is then
    convolved with a separable 2-D sinc kernel about as large as the upsampled
    image, which makes this slower than bicubic or other small-kernel methods.
    Every output image whose input maximum exceeds 0.001 is then linearly rescaled
    to its input's value range.

    :param img: Nx1xHxW batch of single-channel images (the kernel has one channel)
    :param scale: integer up-scaling factor
    :param omega: the factor to adjust the scale of the sinc function
    :return: the interpolated images, with H and W multiplied by ``scale``
    """
    # F.pad orders its arguments (left, right, top, bottom). The width is therefore
    # padded by H//2 and the height by W//2. The 'xy' meshgrid below produces a kernel
    # of shape (len(w_grid), len(h_grid)), which matches this padding, so the valid
    # convolution returns exactly (H*scale, W*scale).
    img_pad = F.pad(input=img,
                    pad=(img.shape[2] // 2, img.shape[2] // 2, img.shape[3] // 2, img.shape[3] // 2),
                    mode="reflect")
    target = zero_interpolate_torch(img_pad, scale)
    h_grid = torch.linspace(-1, 1, (img.shape[2] // 2) * scale * 2 + 1)
    w_grid = torch.linspace(-1, 1, (img.shape[3] // 2) * scale * 2 + 1)
    grid_x, grid_y = torch.meshgrid([h_grid, w_grid], indexing='xy')

    kernel = sinc(grid_x, omega) * sinc(grid_y, omega)
    # Match the input's device *and* dtype: conv2d rejects e.g. a float64 input with
    # a float32 weight.
    kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(device=target.device, dtype=target.dtype)
    target = F.conv2d(input=target, weight=kernel, stride=1, padding="valid")
    for i in range(target.shape[0]):
        if torch.max(img[i]) > 0.001:  # leave (near-)all-zero images untouched
            t_min = torch.min(target[i])
            span = torch.max(target[i]) - t_min
            i_min = torch.min(img[i])
            i_max = torch.max(img[i])
            if span > 0:
                target[i] = (target[i] - t_min) / span * (i_max - i_min) + i_min
            else:
                # Constant response: the min-max rescale would divide by zero (NaN).
                target[i] = i_min
    return target
63 |
64 |
def lpf_sr(img: torch.Tensor, scale: int, omega=3.):
    """
    Super-resolve image(s) with a sinc (ideal low-pass) filter, channel by channel.

    Slower than bicubic or other small-kernel interpolation.

    :param img: NxCxHxW batch, or a single CxHxW image
    :param scale: integer up-scaling factor
    :param omega: the factor to adjust the scale of the sinc function
    :return: NxCx(H*scale)x(W*scale), or Cx(H*scale)x(W*scale) for an unbatched input
    """
    origin_shape = img.shape
    # Treat every channel of every image as an independent single-channel image.
    # reshape (not view) so that non-contiguous inputs, e.g. permuted tensors, work.
    planes = img.reshape(-1, 1, origin_shape[-2], origin_shape[-1])
    out = lpf_sr_single(planes, scale, omega)
    if len(origin_shape) == 4:  # Batched
        return out.reshape(origin_shape[0], origin_shape[1], origin_shape[2] * scale,
                           origin_shape[3] * scale)
    return out.reshape(origin_shape[0], origin_shape[1] * scale, origin_shape[2] * scale)
--------------------------------------------------------------------------------
/Low-pass filter to ISR.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "efe07f220f804a5f",
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "source": [
10 | "# Low-pass filter to ISR\n",
11 | "A simple low-pass filter can already perform image super-resolution (ISR). Here is an example."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "892215eb",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "!pip install torchmetrics\n",
22 | "!pip install einops"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "id": "dc9e595e35c92629",
29 | "metadata": {
30 | "ExecuteTime": {
31 | "end_time": "2024-02-01T09:01:33.162507300Z",
32 | "start_time": "2024-02-01T09:01:24.059457600Z"
33 | },
34 | "collapsed": false
35 | },
36 | "outputs": [
37 | {
38 | "name": "stdout",
39 | "output_type": "stream",
40 | "text": [
41 | "LPF ISR performance: psnr=27.45184326171875, ssim=0.7809420824050903, fsds=27.509355337158876\n"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "import FSDS_code\n",
47 | "import utils\n",
48 | "from PIL import Image\n",
49 | "from torchvision.transforms import ToTensor\n",
50 | "from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure\n",
51 | "\n",
52 | "psnr = PeakSignalNoiseRatio(data_range=1).cuda()\n",
53 | "ssim = StructuralSimilarityIndexMeasure(data_range=1).cuda()\n",
54 | "\n",
55 | "lr = ToTensor()(Image.open(\"./example_figures/baby_x2.png\")).unsqueeze(0).cuda()\n",
56 | "gt = ToTensor()(Image.open(\"./example_figures/baby.png\")).unsqueeze(0).cuda()\n",
57 | "sr = utils.lpf_sr(img=lr, scale=2, omega=48.5)\n",
58 | "print(f\"LPF ISR performance: psnr={psnr(sr, gt)}, ssim={ssim(sr, gt)}, fsds={FSDS_code.FrequencySpectrumDistributionSimilarity(sr, gt)[0]}\")"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "id": "2abc40f7b9990b93",
64 | "metadata": {
65 | "collapsed": false
66 | },
67 | "source": [
68 | "# Implementation\n",
69 | "## 1) Zero-interpolation\n",
70 | "Before applying the low-pass filter to the LR image, we first zero-interpolate it (insert zeros between its pixels) so that it reaches the target size:\n",
71 | "```python\n",
72 | "def zero_interpolate_torch(img: torch.Tensor, scale: int):\n",
73 | " \"\"\"\n",
74 | " interpolate 0 by `scale` times\n",
75 | " :param img: NxCxHxW\n",
76 | " :param scale:\n",
77 | " :return:\n",
78 | " \"\"\"\n",
79 | " if len(img.shape) != 4: # unbatched input: add a batch dimension\n",
80 | " img = img.unsqueeze(dim=0)\n",
81 | " img_ = img.reshape(-1, 1, img.shape[2], img.shape[3])\n",
82 | " img_int = torch.concat(\n",
83 | " [img_, torch.zeros(img_.shape[0], scale * scale - 1, img_.shape[2], img_.shape[3]).to(img.device)],\n",
84 | " dim=1)\n",
85 | " return torch.nn.functional.pixel_shuffle(img_int, scale).reshape(img.shape[0], img.shape[1], img.shape[2] * scale,\n",
86 | " img.shape[3] * scale).squeeze(dim=0)\n",
87 | "\n",
88 | "```\n",
89 | "## 2) Low-pass filter\n",
90 | "Then we apply a low-pass filter to the interpolated image using convolution. The full implementation is:\n",
91 | "```python\n",
92 | "def lpf_sr_single(img: torch.Tensor, scale: int, omega=3.):\n",
93 | " \"\"\"\n",
94 | " Interpolate an image using the sinc function, it's slower than the cubic or others.\n",
95 | "\n",
96 | " :param img: the image to be interpolated.\n",
97 | " :param size: the expected size\n",
98 | " :param omega: the factor to adjust the scale of the sinc function\n",
99 | " :return: the interpolated image\n",
100 | " \"\"\"\n",
101 | " img_pad = F.pad(input=img,\n",
102 | " pad=(img.shape[2] // 2, img.shape[2] // 2, img.shape[3] // 2, img.shape[3] // 2),\n",
103 | " mode=\"reflect\")\n",
104 | " target = zero_interpolate_torch(img_pad, scale) # zero interpolate to the target size\n",
105 | " h_grid = torch.linspace(-1, 1, (img.shape[2] // 2) * scale * 2 + 1)\n",
106 | " w_grid = torch.linspace(-1, 1, (img.shape[3] // 2) * scale * 2 + 1)\n",
107 | " kernel = torch.meshgrid([h_grid, w_grid], indexing='xy')\n",
108 | "\n",
109 | " kernel = sinc(kernel[0], omega) * sinc(kernel[1], omega) # generate the low-pass filter: the sinc function with parameter omega\n",
110 | " kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(img.device)\n",
111 | " # low-pass filtering, since the sinc function is symmetric, we can directly utilize the torch.nn.functional.conv2d\n",
112 | " target = F.conv2d(input=target, weight=kernel, stride=1, padding=\"valid\") \n",
113 | " for i in range(target.shape[0]):\n",
114 | " if torch.max(img[i])>0.001: # to avoid an all-zero image\n",
115 | " target[i] = (target[i] - torch.min(target[i]))/(torch.max(target[i])-torch.min(target[i])) * (torch.max(img[i])-torch.min(img[i])) + torch.min(img[i])\n",
116 | " return target\n",
117 | "```\n",
118 | "In the code above, the sinc function is defined as:\n",
119 | "```python\n",
120 | "def sinc(tensor, omega):\n",
121 | " \"\"\"\n",
122 | " The sinc function implementation. sinc(t) is defined as sin(pi*t)/(pi*t), omega is a\n",
123 | " factor to adjust the scale\n",
124 | " :param tensor: variants of sinc function\n",
125 | " :param omega: scale factor\n",
126 | " :return:\n",
127 | " \"\"\"\n",
128 | " return torch.sin(torch.abs(math.pi * tensor * omega) + 1e-9) / (torch.abs(math.pi * tensor * omega) + 1e-9)\n",
129 | "```"
130 | ]
131 | }
132 | ],
133 | "metadata": {
134 | "kernelspec": {
135 | "display_name": "Python 3 (ipykernel)",
136 | "language": "python",
137 | "name": "python3"
138 | },
139 | "language_info": {
140 | "codemirror_mode": {
141 | "name": "ipython",
142 | "version": 3
143 | },
144 | "file_extension": ".py",
145 | "mimetype": "text/x-python",
146 | "name": "python",
147 | "nbconvert_exporter": "python",
148 | "pygments_lexer": "ipython3",
149 | "version": "3.8.16"
150 | }
151 | },
152 | "nbformat": 4,
153 | "nbformat_minor": 5
154 | }
155 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exploring the Low-Pass Filtering Behavior in Image Super-Resolution
2 |
3 | [ICML 2024 poster](https://icml.cc/virtual/2024/poster/35191)
4 | [arXiv:2405.07919](https://arxiv.org/abs/2405.07919)
5 |
6 | Haoyu Deng, Zijing Xu, Yule Duan, Xiao Wu, Wenjie Shu, Liang-Jian Deng†
7 |
8 |
9 | †Corresponding author
10 |
11 |
12 | If you have any questions, feel free to open an issue or send an email to academic@hydeng.cn. I will respond as soon as possible. If you find our work useful, please consider citing it:
13 | ```
14 | @article{deng2024exploring,
15 | title={Exploring the Low-Pass Filtering Behavior in Image Super-Resolution},
16 | author={Deng, Haoyu and Xu, Zijing and Duan, Yule and Wu, Xiao and Shu, Wenjie and Deng, Liang-Jian},
17 | journal={arXiv preprint arXiv:2405.07919},
18 | year={2024}
19 | }
20 | ```
21 |
22 | ## TODOs
23 | - [x] matlab implementation of FSDS
24 | - [ ] make a video
25 | - [ ] try to use FSDS as a loss and share results
26 |
27 | ## Impulse Responses
28 | Please refer to `Impulse_Responses.pptx`. We provide two versions, with and without enhancement. For better visualization, the enhanced version uses PowerPoint's built-in image enhancement. The enhanced figures are brighter in color, while the unenhanced ones are mathematically closer to the sinc function. If you directly observe the output of networks (before clamping), you will find another feature of the sinc function: negative values near the main lobe.
29 |
30 | ## Hybrid Response Analysis (HyRA)
31 | We provide a script to directly analyze a network using HyRA, i.e., `HyRA.py`. Here is the usage
32 |
33 | ```bash
34 | python HyRA.py --lr [path to low-resolution image] --sr [path to super-resolution image, namely N(I) in the paper] --scale [the scale of super resolution] --impulse_response [path to impulse response] --save_path [path to save results]
35 | ```
36 |
37 | We also provide a tutorial for the code. Please refer to `HyRA Usage.ipynb`.
38 |
39 | ## Frequency Spectrum Distribution Similarity (FSDS)
40 | ### Python Version
41 | FSDS describes the image quality from the perspective of frequency spectrum. The complete implementation of FSDS is in `FSDS_code.py`. We provide an explanation and a tutorial for the code. Please refer to `FSDS_explanation.ipynb`.
42 |
43 | ### Matlab Version
44 | To ensure that the MATLAB version produces the **same result** as the Python version, it calls the Python code directly through MATLAB's Python interface (`pyenv` / `py.*`). Please edit the Python interpreter path in `FrequencySpectrumDistributionSimilarity.m`, and make sure that interpreter has the packages required by `FSDS_code.py` (`torch`, `einops`, `numpy`) installed. An example can be found in `FSDS_matlab_example.m`.
45 | ## Experimental Results
46 | We provide super-resolution results and code for Tab.1.
47 |
48 | The super-resolution results can be found in: [https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip](https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip)
49 |
50 | Code that generates Tab.1 is in `cal_metrcis_and_table.py`. Please unzip `sr_results.zip` to the `sr_results` directory and run the code. The `sr_results` directory should look like:
51 | ```
52 | sr_results
53 | ├─ArbSR
54 | │ ├─RCAN_x12
55 | │ ├─RCAN_x2
56 | │ ├─RCAN_x3
57 | │ ├─RCAN_x4
58 | │ ├─RCAN_x6
59 | │ └─RCAN_x8
60 | ├─Bicubic
61 | │ ├─Bicubic_x12
62 | │ ├─Bicubic_x18
63 | │ ├─Bicubic_x2
64 | │ ├─Bicubic_x3
65 | │ ├─Bicubic_x4
66 | │ └─Bicubic_x6
67 | ├─edsr_baseline
68 | │ ├─x2
69 | │ ├─x3
70 | │ └─x4
71 | ├─GRL
72 | │ ├─base
73 | │ │ ├─X2
74 | │ │ ├─X3
75 | │ │ └─X4
76 | │ ├─small
77 | │ │ ├─X2
78 | │ │ ├─X3
79 | │ │ └─X4
80 | │ └─tiny
81 | │ ├─X2
82 | │ ├─X3
83 | │ └─X4
84 | ├─HAT
85 | │ ├─HAT-S_SRx2
86 | │ ├─HAT-S_SRx3
87 | │ ├─HAT-S_SRx4
88 | │ ├─HAT_SRx2
89 | │ ├─HAT_SRx3
90 | │ └─HAT_SRx4
91 | ├─HDSRNet
92 | │ ├─X2
93 | │ ├─X3
94 | │ └─X4
95 | ├─hr
96 | ├─ITSRN
97 | │ ├─ITSRN_x12
98 | │ ├─ITSRN_x2
99 | │ ├─ITSRN_x3
100 | │ ├─ITSRN_x4
101 | │ └─ITSRN_x6
102 | ├─liif
103 | │ ├─edsr_x12
104 | │ ├─edsr_x18
105 | │ ├─edsr_x2
106 | │ ├─edsr_x3
107 | │ ├─edsr_x4
108 | │ ├─edsr_x6
109 | │ ├─rdn_x12
110 | │ ├─rdn_x18
111 | │ ├─rdn_x2
112 | │ ├─rdn_x3
113 | │ ├─rdn_x4
114 | │ └─rdn_x6
115 | ├─LTE
116 | │ ├─EDSR_baseline_x12
117 | │ ├─EDSR_baseline_x2
118 | │ ├─EDSR_baseline_x3
119 | │ ├─EDSR_baseline_x4
120 | │ ├─EDSR_baseline_x6
121 | │ ├─RDN_x12
122 | │ ├─RDN_x2
123 | │ ├─RDN_x3
124 | │ ├─RDN_x4
125 | │ ├─RDN_x6
126 | │ ├─SwinIR_x12
127 | │ ├─SwinIR_x2
128 | │ ├─SwinIR_x3
129 | │ ├─SwinIR_x4
130 | │ └─SwinIR_x6
131 | ├─OPESR
132 | │ ├─EDSR_x2
133 | │ ├─EDSR_x3
134 | │ ├─EDSR_x4
135 | │ ├─RDN_x2
136 | │ ├─RDN_x3
137 | │ └─RDN_x4
138 | ├─RDN
139 | │ ├─RDN_small_x2
140 | │ ├─RDN_small_x3
141 | │ └─RDN_small_x4
142 | ├─SRNO
143 | │ ├─EDSR_baseline_x12
144 | │ ├─EDSR_baseline_x2
145 | │ ├─EDSR_baseline_x3
146 | │ ├─EDSR_baseline_x4
147 | │ └─EDSR_baseline_x6
148 | └─SwinIR
149 | ├─swinir_classical_sr_x2
150 | ├─swinir_classical_sr_x3
151 | ├─swinir_classical_sr_x4
152 | └─swinir_classical_sr_x8
153 | ```
154 | ## Acknowledgement
155 | We thank the anonymous reviewers for their suggestions, which helped improve this paper. Moreover, we would like to express our sincere gratitude to [Ruijie Zhu](https://github.com/ridgerchu) for his generous GPU support; without it, running experiments on the full-scale DIV2K dataset would have been very difficult. This work is supported by NSFC (12271083).
--------------------------------------------------------------------------------
/cal_metrcis_and_table.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 | import os
5 | import torchmetrics
6 | import numpy as np
7 | from PIL import Image
8 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, \
9 | LearnedPerceptualImagePatchSimilarity
10 | from torchvision.transforms import ToTensor
11 | from tqdm import tqdm
12 |
13 | from FrequencySpectrumDistributionSimilarity import FrequencySpectrumDistributionSimilarity
14 |
# Metric names collected in a list called minList; presumably lower values are better
# for these (LPIPS and L1/L2 errors) -- confirm against where the rankings are built.
minList = [
    "LPIPS",
    "L1InFFT",
    "L1",
    "L2InFFT",
    "L2",
]
# metric -> scale -> ordered list of mean values; TableRow uses the position of a
# value (1-based) as its rank and colours the first entries with color_table.
max_dict = dict()
# LaTeX colours applied to the values matching the 1st and 2nd ranking entries.
color_table = ["red", "blue"]
18 |
19 |
class TableRow:
    """
    One row of the LaTeX results table: a method evaluated at several scales.

    Per-image values are collected with :meth:`addMetric`. ``str(row)`` renders the
    row as ``&``-separated LaTeX cells terminated by a LaTeX line break. Each cell
    shows the mean value followed by its 1-based position in the module-level ranking
    ``max_dict[metric][scale]`` as a grey superscript. When the mean equals the i-th
    ranking entry, it is coloured with ``color_table[i]``. Scales without data are
    rendered as ``-``.
    """

    def __init__(self, paper, overall_scales, overall_metrics, citation=None):
        self.data = {}  # metric -> scale -> list of per-image values
        self.paper = paper
        self.citation = citation
        self.overall_scales = overall_scales
        self.overall_metrics = overall_metrics

    def addMetric(self, metric, value, scale):
        """Append one measurement of ``metric`` taken at ``scale``."""
        self.data.setdefault(metric, {}).setdefault(scale, []).append(value)

    @staticmethod
    def _format_value(metric, mean):
        """Format a mean: 3 decimals for LPIPS/SSIM; otherwise 2 decimals, or 2-digit scientific below 0.01."""
        if metric in ("LPIPS", "SSIM"):
            return f"{mean:.3f}"
        return f"{mean:.2e}" if mean < 0.01 else f"{mean:.2f}"

    def _format_cell(self, metric, scale):
        """Render the ``&``-prefixed LaTeX cell for ``metric`` at ``scale``."""
        values = self.data[metric].get(scale)
        if values is None:
            return "&-"
        mean = np.mean(values)
        ranking = max_dict[metric][scale]
        rank = f"\\textcolor{{gray}}{{\\textsuperscript{{{ranking.index(mean) + 1}}}}}"
        text = self._format_value(metric, mean)
        # Colour at most once, so tied values cannot emit duplicate cells (which would
        # add extra columns to the row); zip also tolerates rankings shorter than
        # color_table.
        for color, ranked_value in zip(color_table, ranking):
            if mean == ranked_value:
                return f"&\\textcolor{{{color}}}{{{text}}}{rank}"
        return f"&{text}{rank}"

    def __str__(self):
        out = f"{self.paper}"
        if self.citation is not None:
            out += f"{self.citation}"
        for metric in self.overall_metrics:
            if metric not in self.data:
                # __str__ must return a str; returning None surfaced only as an opaque
                # TypeError from str().
                raise ValueError(f"TableRow '{self.paper}' has no values for metric '{metric}'")
            out += "".join(self._format_cell(metric, scale) for scale in self.overall_scales)
        return out + "\\\\"
90 |
91 |
def l1_in_fft(gt, img):
    """Mean absolute difference between the n-D FFTs of ``gt`` and ``img``, as a Python float."""
    spectrum_gap = torch.fft.fftn(gt) - torch.fft.fftn(img)
    magnitude = spectrum_gap.abs()
    element_count = magnitude.numel()
    return (magnitude.sum() / element_count).item()
95 |
96 |
def l2_in_fft(gt, img):
    """Mean squared magnitude of the difference between the n-D FFTs of ``gt`` and ``img``."""
    gt_spectrum = torch.fft.fftn(gt)
    img_spectrum = torch.fft.fftn(img)
    magnitude = (gt_spectrum - img_spectrum).abs()
    squared_total = torch.sum(magnitude ** 2)
    return (squared_total / magnitude.numel()).item()
100 |
101 |
def psnr_in_fft(gt, img):
    """
    PSNR between the FFT magnitude spectra of ``gt`` and ``img``, as a Python float.

    The peak value (``data_range``) is the largest FFT magnitude of ``gt``. The metric
    is created on the inputs' device instead of always on CUDA, so CPU tensors work.
    """
    gt_mag = torch.fft.fftn(gt).abs()
    img_mag = torch.fft.fftn(img).abs()
    _psnr = PeakSignalNoiseRatio(data_range=torch.max(gt_mag).item()).to(gt_mag.device)
    return _psnr(gt_mag, img_mag).item()
107 |
def psnr(gt, img):
    """
    PSNR (data range 1) between ``gt`` and ``img``, as a Python float.

    The metric is created on the inputs' device instead of always on CUDA.
    NOTE(review): the module later rebinds ``psnr`` to a PeakSignalNoiseRatio instance,
    which shadows this function.
    """
    _psnr = PeakSignalNoiseRatio(data_range=1).to(gt.device)
    return _psnr(gt, img).item()
111 |
def l1(gt, img):
    """Mean absolute error between ``gt`` and ``img`` as a Python float."""
    loss = torch.nn.functional.l1_loss(gt, img)
    return loss.item()
114 |
def l2(gt, img):
    """Mean squared error between ``gt`` and ``img`` as a Python float."""
    loss = torch.nn.functional.mse_loss(gt, img)
    return loss.item()
# Shared converter and metric objects, created once at import time.
convertor = ToTensor()
# NOTE(review): the `psnr`, `l1` and `l2` assignments in this group rebind the helper
# functions of the same names defined above (which return Python floats). From here on,
# `psnr` is a torchmetrics module and `l1`/`l2` are the raw torch losses returning
# tensors. The metric modules are moved to CUDA unconditionally, so a GPU is required.
psnr = PeakSignalNoiseRatio(data_range=1).cuda()
ssim = StructuralSimilarityIndexMeasure(data_range=1).cuda()
# NOTE(review): this name is imported at the top of the file from a module called
# `FrequencySpectrumDistributionSimilarity`. In this repository that name belongs to the
# MATLAB wrapper (.m); the Python implementation lives in FSDS_code.py. Confirm the
# import resolves.
fsds = FrequencySpectrumDistributionSimilarity
lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').cuda()
l1 = torch.nn.functional.l1_loss
l2 = torch.nn.functional.mse_loss
124 | row_item = {
125 | "EDSR": {
126 | "paper": "EDSR",
127 | "citation": "\citep{edsr}",
128 | "overall_scales": [2, 3, 4],
129 | "scale_path": {
130 | 2: "./sr_results/edsr_baseline/x2",
131 | 3: "./sr_results/edsr_baseline/x3",
132 | 4: "./sr_results/edsr_baseline/x4"
133 | },
134 | },
135 |
136 | "EDSR-LIIF": {
137 | "paper": "EDSR-LIIF",
138 | "citation": "\citep{liif}",
139 | "overall_scales": [2, 3, 4, 6, 12],
140 | "scale_path": {
141 | 2: "./sr_results/liif/edsr_x2",
142 | 3: "./sr_results/liif/edsr_x3",
143 | 4: "./sr_results/liif/edsr_x4",
144 | 6: "./sr_results/liif/edsr_x6",
145 | 12: "./sr_results/liif/edsr_x12"
146 | },
147 | },
148 |
149 | "EDSR-OPESR": {
150 | "paper": "EDSR-OPESR",
151 | "citation": "\citep{OPESR}",
152 | "overall_scales": [2, 3, 4],
153 | "scale_path": {
154 | 2: "./sr_results/OPESR/EDSR_x2",
155 | 3: "./sr_results/OPESR/EDSR_x3",
156 | 4: "./sr_results/OPESR/EDSR_x4"
157 | },
158 | },
159 |
160 | "EDSR-SRNO": {
161 | "paper": "EDSR-SRNO",
162 | "citation": "\citep{SRNO}",
163 | "overall_scales": [2, 3, 4, 6, 12],
164 | "scale_path": {
165 | 2: "./sr_results/SRNO/EDSR_baseline_x2",
166 | 3: "./sr_results/SRNO/EDSR_baseline_x3",
167 | 4: "./sr_results/SRNO/EDSR_baseline_x4",
168 | 6: "./sr_results/SRNO/EDSR_baseline_x6",
169 | 12: "./sr_results/SRNO/EDSR_baseline_x12"
170 | },
171 | },
172 |
173 | "EDSR-LTE": {
174 | "paper": "EDSR-LTE",
175 | "citation": "\citep{LTE}",
176 | "overall_scales": [2, 3, 4, ],
177 | "scale_path": {
178 | 2: "./sr_results/LTE/EDSR_baseline_x2",
179 | 3: "./sr_results/LTE/EDSR_baseline_x3",
180 | 4: "./sr_results/LTE/EDSR_baseline_x4"
181 | },
182 | },
183 |
184 | "RDN": {
185 | "paper": "RDN",
186 | "citation": "\citep{rdn}",
187 | "overall_scales": [2, 3, 4],
188 | "scale_path": {
189 | 2: "./sr_results/RDN/RDN_small_x2",
190 | 3: "./sr_results/RDN/RDN_small_x3",
191 | 4: "./sr_results/RDN/RDN_small_x4"
192 | },
193 | },
194 |
195 | "RDN-LIIF": {
196 | "paper": "RDN-LIIF",
197 | "citation": "\citep{liif}",
198 | "overall_scales": [2, 3, 4, 6, 12],
199 | "scale_path": {
200 | 2: "./sr_results/liif/rdn_x2",
201 | 3: "./sr_results/liif/rdn_x3",
202 | 4: "./sr_results/liif/rdn_x4",
203 | 6: "./sr_results/liif/rdn_x6",
204 | 12: "./sr_results/liif/rdn_x12"
205 | },
206 |
207 | },
208 |
209 | "RDN-OPESR": {
210 | "paper": "RDN-OPESR",
211 | "citation": "\citep{OPESR}",
212 | "overall_scales": [2, 3, 4],
213 | "scale_path": {
214 | 2: "./sr_results/OPESR/RDN_x2",
215 | 3: "./sr_results/OPESR/RDN_x3",
216 | 4: "./sr_results/OPESR/RDN_x4"
217 | },
218 | },
219 |
220 | "RDN-LTE": {
221 | "paper": "RDN-LTE",
222 | "citation": "\citep{LTE}",
223 | "overall_scales": [2, 3, 4, 6, 12],
224 | "scale_path": {
225 | 2: "./sr_results/LTE/RDN_x2",
226 | 3: "./sr_results/LTE/RDN_x3",
227 | 4: "./sr_results/LTE/RDN_x4",
228 | 6: "./sr_results/LTE/RDN_x6",
229 | 12: "./sr_results/LTE/RDN_x12"
230 | },
231 | },
232 |
233 | "SwinIR-classical": {
234 | "paper": "SwinIR-classical",
235 | "citation": "\citep{swinir}",
236 | "overall_scales": [2, 3, 4],
237 | "scale_path": {
238 | 2: "./sr_results/SwinIR/swinir_classical_sr_x2",
239 | 3: "./sr_results/SwinIR/swinir_classical_sr_x3",
240 | 4: "./sr_results/SwinIR/swinir_classical_sr_x4"
241 | },
242 | },
243 |
244 | "ITSRN": {
245 | "paper": "ITSRN",
246 | "citation": "\citep{ITSRN}",
247 | "overall_scales": [2, 3, 4, 6, 12],
248 | "scale_path": {
249 | 2: "./sr_results/ITSRN/ITSRN_x2",
250 | 3: "./sr_results/ITSRN/ITSRN_x3",
251 | 4: "./sr_results/ITSRN/ITSRN_x4",
252 | 6: "./sr_results/ITSRN/ITSRN_x6",
253 | 12: "./sr_results/ITSRN/ITSRN_x12"
254 | },
255 | },
256 |
257 | "Bicubic": {
258 | "paper": "Bicubic",
259 | "citation": "",
260 | "overall_scales": [2, 3, 4, 6, 12],
261 | "scale_path": {
262 | 2: "./sr_results/Bicubic/Bicubic_x2",
263 | 3: "./sr_results/Bicubic/Bicubic_x3",
264 | 4: "./sr_results/Bicubic/Bicubic_x4",
265 | 6: "./sr_results/Bicubic/Bicubic_x6",
266 | 12: "./sr_results/Bicubic/Bicubic_x12"
267 | },
268 | },
269 |
270 | "HAT-S": {
271 | "paper": "HAT-S",
272 | "citation": "\citep{hat}",
273 | "overall_scales": [2, 3, 4],
274 | "scale_path": {
275 | 2: "./sr_results/HAT/HAT-S_SRx2",
276 | 3: "./sr_results/HAT/HAT-S_SRx3",
277 | 4: "./sr_results/HAT/HAT-S_SRx4"
278 | },
279 | },
280 |
281 | "HAT": {
282 | "paper": "HAT",
283 | "citation": "\citep{hat}",
284 | "overall_scales": [2, 3, 4],
285 | "scale_path": {
286 | 2: "./sr_results/HAT/HAT_SRx2",
287 | 3: "./sr_results/HAT/HAT_SRx3",
288 | 4: "./sr_results/HAT/HAT_SRx4"
289 | },
290 | },
291 |
292 | "HDSRNet": {
293 | "paper": "HDSRNet",
294 | "citation": "\citep{hdsrnet}",
295 | "overall_scales": [2, 3, 4],
296 | "scale_path": {
297 | 2: "./sr_results/HDSRNet/X2",
298 | 3: "./sr_results/HDSRNet/X3",
299 | 4: "./sr_results/HDSRNet/X4"
300 | },
301 | },
302 |
303 | "GRLBase": {
304 | "paper": "GRLBase",
305 | "citation": "\citep{grl}",
306 | "overall_scales": [2, 3, 4],
307 | "scale_path": {
308 | 2: "./sr_results/GRL/base/X2",
309 | 3: "./sr_results/GRL/base/X3",
310 | 4: "./sr_results/GRL/base/X4"
311 | },
312 | },
313 |
314 | "GRLSmall": {
315 | "paper": "GRLSmal",
316 | "citation": "\citep{grl}",
317 | "overall_scales": [2, 3, 4],
318 | "scale_path": {
319 | 2: "./sr_results/GRL/small/X2",
320 | 3: "./sr_results/GRL/small/X3",
321 | 4: "./sr_results/GRL/small/X4"
322 | },
323 | },
324 |
325 | "GRLTiny": {
326 | "paper": "GRLTiny",
327 | "citation": "\citep{grl}",
328 | "overall_scales": [2, 3, 4],
329 | "scale_path": {
330 | 2: "./sr_results/GRL/tiny/X2",
331 | 3: "./sr_results/GRL/tiny/X3",
332 | 4: "./sr_results/GRL/tiny/X4"
333 | },
334 | }
335 | }
336 |
337 |
def getGTFileName(name):
    """
    Map an SR result file name to the name of its ground-truth image in ./sr_results/hr.

    Only the first four characters of ``name`` are kept and ".png" is appended,
    e.g. "0801x2.png" -> "0801.png" (assumes SR files are prefixed with a 4-character
    image index — confirm against the downloaded sr_results).

    :param name: file name of a super-resolved image.
    :return: file name of the matching HR image.
    """
    stem = name[:4]
    return f"{stem}.png"
340 |
341 |
# Number of SR images visited by validate_data(); used to size the progress bar.
total_img = 0


def validate_data(dic):
    """
    Check that every SR result listed in ``dic`` has the same pixel size as its
    ground-truth image in ./sr_results/hr, counting visited images into ``total_img``.

    :param dic: mapping of method name -> {"overall_scales": [...], "scale_path": {scale: directory}}
    :return: True when every SR image matches its GT size; False at the first mismatch
             (the offending file name is printed).
    """
    global total_img
    for key in dic.keys():
        for scale in dic[key]["overall_scales"]:
            path = dic[key]["scale_path"][scale]
            for file in os.listdir(path):
                # Image.open is lazy and keeps the file open; the context managers release
                # both handles right after the size comparison instead of at GC time.
                with Image.open(os.path.join(path, file)) as img, \
                        Image.open(os.path.join("./sr_results/hr", getGTFileName(file))) as gt:
                    total_img += 1
                    if img.size != gt.size:
                        print(f"Size mismatch for {file}")
                        return False
    return True
357 | return True
358 |
359 |
360 |
with torch.no_grad():
    # Stop before computing anything if an SR result does not match its GT size;
    # previously the result of validate_data() was ignored and "validate ok" always printed.
    if not validate_data(row_item):
        raise RuntimeError("SR results and HR ground truth differ in size; see the mismatch printed above")
    print(f"total:{total_img}")
    print("validate ok")
    pbar = tqdm(total=total_img)
    # Load every ground-truth image onto the GPU once; each one is reused by every method and scale.
    gts = {}
    for file in os.listdir("./sr_results/hr"):
        with Image.open(os.path.join("./sr_results/hr", file)) as hr_image:
            gts[file] = convertor(hr_image).cuda()
    for net, entry in row_item.items():
        entry["metrics"] = {}
        for scale in entry["overall_scales"]:
            path = entry["scale_path"][scale]
            entry["metrics"][scale] = {}
            for file in os.listdir(path):
                # Release the file handle as soon as the image has been converted to a tensor.
                with Image.open(os.path.join(path, file)) as sr_image:
                    img = convertor(sr_image).cuda()
                gt = gts[getGTFileName(file)]
                # PSNR/SSIM/LPIPS are called with a leading batch dimension,
                # the remaining metrics with the unbatched tensors.
                gt_batch, img_batch = gt.unsqueeze(0), img.unsqueeze(0)
                entry["metrics"][scale][file] = {
                    "PSNR": psnr(gt_batch, img_batch).item(),
                    "SSIM": ssim(gt_batch, img_batch).item(),
                    "LPIPS": lpips(gt_batch, img_batch).item(),
                    "FSDS": fsds(gt, img),
                    "L1InFFT": l1_in_fft(gt, img),
                    "L1": l1(gt, img).item(),
                    "L2InFFT": l2_in_fft(gt, img),
                    "L2": l2(gt, img, reduction="sum").item(),
                    "PSNRInFFT": psnr_in_fft(gt, img),
                }
                pbar.update(1)
    pbar.close()

# Persist the per-image metrics; json.dump turns the integer scale keys into strings.
with open("div2k_metrics_icml_rebuttal.json", "w") as f:
    json.dump(row_item, f, indent=4)
392 |
# Reload the metrics from disk so the table can be rebuilt without recomputing them.
with open("div2k_metrics_icml_rebuttal.json", "r") as f:
    row_item = json.load(f)

# Metrics to tabulate. Each per-image entry in the JSON also stores
# "L1InFFT", "L1", "L2InFFT", "L2" and "PSNRInFFT", so any of those can be listed here.
all_metrics = ["PSNR", "SSIM", "LPIPS", "FSDS"]
# Table columns; a method without results for a scale simply has no data for it.
all_scales = [2, 3, 4, 6, 12]

rows = []
for entry in row_item.values():
    row = TableRow(entry["paper"], list(all_scales), all_metrics, None)
    for scale in entry["overall_scales"]:
        # JSON object keys are strings, so the integer scales were stringified on dump.
        for per_image in entry["metrics"][str(scale)].values():
            # Driving addMetric from all_metrics keeps the table in sync with the list above.
            for metric in all_metrics:
                row.addMetric(metric, per_image[metric], scale)
    rows.append(row)

# For every metric/scale, collect each method's mean score sorted best-first:
# ascending for metrics in minList (lower is better), descending otherwise.
for metric in all_metrics:
    max_dict[metric] = {}
    for scale in all_scales:
        means = [np.mean(row.data[metric][scale]) for row in rows if scale in row.data[metric]]
        means.sort(reverse=metric not in minList)
        max_dict[metric][scale] = means

for row in rows:
    print(row)
--------------------------------------------------------------------------------
/example_figures/EDSRx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/EDSRx4.png
--------------------------------------------------------------------------------
/example_figures/baby.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/baby.png
--------------------------------------------------------------------------------
/example_figures/baby_x2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/baby_x2.png
--------------------------------------------------------------------------------
/example_figures/gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/gt.png
--------------------------------------------------------------------------------
/example_figures/impulse_response.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/impulse_response.png
--------------------------------------------------------------------------------
/example_figures/lr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/lr.png
--------------------------------------------------------------------------------
/example_figures/sr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RisingEntropy/LPFInISR/00a475308dff1765acbd3c67f17ac6e83da60d60/example_figures/sr.png
--------------------------------------------------------------------------------
/sr_results/readme.md:
--------------------------------------------------------------------------------
1 | Please download the super-resolution results from [https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip](https://huggingface.co/RisingEntropy/NNsAreLPFing/blob/main/sr_results.zip) and unzip it.
2 | After unzipping, this directory should look like this:
3 | ```
4 | ├─ArbSR
5 | │ ├─RCAN_x12
6 | │ ├─RCAN_x2
7 | │ ├─RCAN_x3
8 | │ ├─RCAN_x4
9 | │ ├─RCAN_x6
10 | │ └─RCAN_x8
11 | ├─Bicubic
12 | │ ├─Bicubic_x12
13 | │ ├─Bicubic_x18
14 | │ ├─Bicubic_x2
15 | │ ├─Bicubic_x3
16 | │ ├─Bicubic_x4
17 | │ └─Bicubic_x6
18 | ├─edsr_baseline
19 | │ ├─x2
20 | │ ├─x3
21 | │ └─x4
22 | ├─GRL
23 | │ ├─base
24 | │ │ ├─X2
25 | │ │ ├─X3
26 | │ │ └─X4
27 | │ ├─small
28 | │ │ ├─X2
29 | │ │ ├─X3
30 | │ │ └─X4
31 | │ └─tiny
32 | │ ├─X2
33 | │ ├─X3
34 | │ └─X4
35 | ├─HAT
36 | │ ├─HAT-S_SRx2
37 | │ ├─HAT-S_SRx3
38 | │ ├─HAT-S_SRx4
39 | │ ├─HAT_SRx2
40 | │ ├─HAT_SRx3
41 | │ └─HAT_SRx4
42 | ├─HDSRNet
43 | │ ├─X2
44 | │ ├─X3
45 | │ └─X4
46 | ├─hr
47 | ├─ITSRN
48 | │ ├─ITSRN_x12
49 | │ ├─ITSRN_x2
50 | │ ├─ITSRN_x3
51 | │ ├─ITSRN_x4
52 | │ └─ITSRN_x6
53 | ├─liif
54 | │ ├─edsr_x12
55 | │ ├─edsr_x18
56 | │ ├─edsr_x2
57 | │ ├─edsr_x3
58 | │ ├─edsr_x4
59 | │ ├─edsr_x6
60 | │ ├─rdn_x12
61 | │ ├─rdn_x18
62 | │ ├─rdn_x2
63 | │ ├─rdn_x3
64 | │ ├─rdn_x4
65 | │ └─rdn_x6
66 | ├─LTE
67 | │ ├─EDSR_baseline_x12
68 | │ ├─EDSR_baseline_x2
69 | │ ├─EDSR_baseline_x3
70 | │ ├─EDSR_baseline_x4
71 | │ ├─EDSR_baseline_x6
72 | │ ├─RDN_x12
73 | │ ├─RDN_x2
74 | │ ├─RDN_x3
75 | │ ├─RDN_x4
76 | │ ├─RDN_x6
77 | │ ├─SwinIR_x12
78 | │ ├─SwinIR_x2
79 | │ ├─SwinIR_x3
80 | │ ├─SwinIR_x4
81 | │ └─SwinIR_x6
82 | ├─OPESR
83 | │ ├─EDSR_x2
84 | │ ├─EDSR_x3
85 | │ ├─EDSR_x4
86 | │ ├─RDN_x2
87 | │ ├─RDN_x3
88 | │ └─RDN_x4
89 | ├─RDN
90 | │ ├─RDN_small_x2
91 | │ ├─RDN_small_x3
92 | │ └─RDN_small_x4
93 | ├─SRNO
94 | │ ├─EDSR_baseline_x12
95 | │ ├─EDSR_baseline_x2
96 | │ ├─EDSR_baseline_x3
97 | │ ├─EDSR_baseline_x4
98 | │ └─EDSR_baseline_x6
99 | └─SwinIR
100 | ├─swinir_classical_sr_x2
101 | ├─swinir_classical_sr_x3
102 | ├─swinir_classical_sr_x4
103 | └─swinir_classical_sr_x8
104 | ```
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import warnings
3 | from datetime import datetime
4 |
5 | import einops
6 | import torch
7 | import math
8 | import torch.nn.functional as F
9 |
10 | def sinc(tensor, omega):
11 | """
12 | The sinc function implementation. sinc(t) is defined as sin(pi*t)/(pi*t), omega is a
13 | factor to adjust the scale
14 | :param tensor: variants of sinc function
15 | :param omega: scale factor
16 | :return:
17 | """
18 | return torch.sin(torch.abs(math.pi * tensor * omega) + 1e-9) / (torch.abs(math.pi * tensor * omega) + 1e-9)
19 |
20 |
def nearest_odd(num):
    """
    Return ``num + 1`` when ``num`` is divisible by 2, otherwise ``num`` unchanged.

    Even values are always rounded up (e.g. 4 -> 5, 0 -> 1).
    """
    if num % 2 == 0:
        return num + 1
    return num
23 |
24 |
def zero_interpolate_torch(img: torch.Tensor, scale: int):
    """
    Upsample by zero insertion: each input pixel is placed at the top-left corner of a
    ``scale x scale`` block whose other entries are 0.

    :param img: NxCxHxW tensor; any input that is not 4D is assumed to be a single
                CxHxW image and gets a batch dimension added.
    :param scale: integer upsampling factor.
    :return: Nx C x (H*scale) x (W*scale) tensor. A leading dimension of size 1 is squeezed
             away, so a CxHxW input, or a 4D input with N == 1, yields a 3D result.
    """
    if img.dim() != 4:  # unbatched input: add a batch dimension
        img = img.unsqueeze(dim=0)
    n, c, h, w = img.shape
    # Treat every channel of every image as its own single-channel plane.
    planes = img.reshape(-1, 1, h, w)
    # pixel_shuffle sends channel k of each scale*scale group to offset (k // scale, k % scale)
    # inside the output block, so pixels in channel 0 and zeros elsewhere land each pixel
    # in the block's top-left corner. The zeros are created directly on the input's device
    # instead of being built on the CPU and copied over on every call.
    zeros = torch.zeros(planes.shape[0], scale * scale - 1, h, w, device=img.device)
    stacked = torch.concat([planes, zeros], dim=1)
    upsampled = torch.nn.functional.pixel_shuffle(stacked, scale)
    return upsampled.reshape(n, c, h * scale, w * scale).squeeze(dim=0)
40 |
41 |
def lpf_sr_single(img: torch.Tensor, scale: int, omega=3., rgb_range=255):
    """
    Upsample single-channel images with a separable sinc (ideal low-pass) kernel.
    Slower than cubic or other polynomial interpolation.

    The images are reflect-padded, zero-interpolated by ``scale`` and convolved with the
    sinc kernel. Each output image whose input maximum is at least 0.01 is then min-max
    rescaled to its input's value range; darker images are left unscaled.

    :param img: Nx1xHxW batch (the kernel has a single input channel).
    :param scale: integer upsampling factor.
    :param omega: factor adjusting the scale of the sinc function.
    :param rgb_range: currently unused; kept for interface compatibility.
    :return: the interpolated images of spatial size (H*scale)x(W*scale). When N == 1 the
             batch dimension is dropped, because zero_interpolate_torch squeezes it.
    """
    h, w = img.shape[2], img.shape[3]
    # NOTE(review): F.pad's tuple is (left, right, top, bottom), so the width is padded by
    # h // 2 and the height by w // 2. The kernel below uses the same swap (indexing="xy"
    # puts h_grid along the width), so the valid convolution still returns exactly
    # (h*scale, w*scale). Reflect padding therefore requires h // 2 < w and w // 2 < h.
    img_pad = F.pad(input=img, pad=(h // 2, h // 2, w // 2, w // 2), mode="reflect")
    target = zero_interpolate_torch(img_pad, scale)
    h_grid = torch.linspace(-1, 1, (h // 2) * scale * 2 + 1)
    w_grid = torch.linspace(-1, 1, (w // 2) * scale * 2 + 1)
    grid_x, grid_y = torch.meshgrid([h_grid, w_grid], indexing='xy')
    kernel = sinc(grid_x, omega) * sinc(grid_y, omega)
    # Match the dtype of the signal being filtered so that e.g. float64 inputs also work.
    kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(device=img.device, dtype=target.dtype)
    target = F.conv2d(input=target, weight=kernel, stride=1, padding="valid")
    for i in range(target.shape[0]):
        if torch.max(img[i]) >= 0.01:  # skip (near) all-zero images
            t_min, t_max = torch.min(target[i]), torch.max(target[i])
            i_min, i_max = torch.min(img[i]), torch.max(img[i])
            # Clamp the denominator so a flat output (e.g. a constant input) does not give
            # 0/0 = NaN; it then maps to the input's minimum value.
            target[i] = (target[i] - t_min) / (t_max - t_min).clamp_min(1e-12) * (i_max - i_min) + i_min
    return target
68 |
69 |
70 |
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """
    PSNR in dB between ``sr`` and ``hr`` after cropping a border from every side.

    When ``dataset`` is given and ``dataset.dataset.benchmark`` is truthy, ``scale`` pixels
    are cropped and multi-channel differences are first reduced to luma with the
    (65.738, 129.057, 25.064) / 256 weights; otherwise ``scale + 6`` pixels are cropped.

    :param sr: super-resolved image tensor.
    :param hr: reference image tensor.
    :param scale: SR scale, used to size the cropped border.
    :param rgb_range: pixel value range used to normalize the difference.
    :param dataset: optional loader whose ``dataset.benchmark`` selects the benchmark protocol.
    :return: 0 if ``hr`` has a single element, 100 if the MSE is below 1e-6,
             otherwise -10 * log10(MSE).
    """
    if hr.nelement() == 1:
        return 0

    diff = (sr - hr) / rgb_range
    use_benchmark_protocol = dataset and dataset.dataset.benchmark
    if not use_benchmark_protocol:
        shave = scale + 6
    else:
        shave = scale
        if diff.size(1) > 1:
            luma_weights = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
            diff = (diff * luma_weights).sum(dim=1)

    cropped = diff[..., shave:-shave, shave:-shave]
    mse = cropped.pow(2).mean()
    if abs(mse) < 1e-6:
        return 100
    return -10 * math.log10(mse)
89 |
90 |
def psnr(a, b, range=255, scale=2):
    """
    PSNR in dB via calc_psnr without a benchmark dataset, so ``scale + 6`` border pixels
    are cropped from every side before the MSE is computed.

    :param a: image tensor, passed to calc_psnr as ``sr``.
    :param b: image tensor, passed to calc_psnr as ``hr``.
    :param range: pixel value range (e.g. 255 or 1). The name shadows the builtin inside
                  this function but is kept for keyword-argument compatibility.
    :param scale: SR scale used to size the cropped border; defaults to 2, the value that
                  used to be hard-coded here.
    :return: the value returned by calc_psnr.
    """
    return calc_psnr(a, b, scale, range)
93 |
94 |
def lpf_sr(img: torch.Tensor, scale: int, omega=3., rgb_range=1):
    """
    Interpolate image(s) with the sinc (ideal low-pass) kernel of lpf_sr_single.
    Slower than cubic or other polynomial interpolation.

    Every channel is upsampled independently; the last two dimensions are the spatial ones.

    :param img: CxHxW image or NxCxHxW batch.
    :param scale: integer upsampling factor.
    :param omega: factor adjusting the scale of the sinc function.
    :param rgb_range: forwarded to lpf_sr_single.
    :return: tensor with the same leading dimensions as ``img`` and H, W multiplied by ``scale``.
    """
    origin_shape = img.shape
    h, w = origin_shape[-2], origin_shape[-1]
    # Fold all channels (and images) into the batch dimension of single-channel planes.
    # reshape, unlike view, also accepts non-contiguous inputs (permuted/sliced tensors).
    planes = img.reshape(-1, 1, h, w)
    out = lpf_sr_single(planes, scale, omega, rgb_range=rgb_range)
    return out.reshape(*origin_shape[:-2], h * scale, w * scale)
114 |
115 |
def quantize(img, rgb_range):
    """
    Snap ``img`` to the 8-bit grid: map [0, rgb_range] onto [0, 255], clamp, round to
    the nearest level and map back to the original range.

    :param img: image tensor with values nominally in [0, rgb_range].
    :param rgb_range: maximum pixel value of ``img`` (e.g. 1 or 255).
    :return: quantized tensor in the same range as ``img``.
    """
    levels_per_unit = 255 / rgb_range
    on_grid = torch.clamp(img * levels_per_unit, 0, 255).round()
    return on_grid / levels_per_unit
119 |
120 |
class FileLogger:
    """
    Write timestamped log lines, scalar curves and checkpoints under
    ExperimentLogs/<exp_name>/ (relative to the current working directory).
    """

    # Prefix for each log line, e.g. "2024-05-17 13:45:02--->" (%m is the month, %M minutes).
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S--->"

    def __init__(self, exp_name):
        """
        :param exp_name: name of the experiment sub-directory; created if missing.
        """
        self.exp_name = exp_name
        log_dir = os.path.join("ExperimentLogs", exp_name)
        # exist_ok avoids the race between an existence check and the directory creation.
        os.makedirs(log_dir, exist_ok=True)
        self.filename = os.path.join(log_dir, "log output.txt")
        self.figure_file = os.path.join(log_dir, "figure_curves.pt")
        self.figures = {}

    def print(self, text, append=True, sync_with_screen=True):
        """
        Write ``text`` with a timestamp prefix to the log file, optionally echoing it.

        :param text: message to log (a str; it is concatenated to the prefix).
        :param append: append to the log file (True) or truncate it first (False).
        :param sync_with_screen: also print the same line to stdout.
        """
        # Take one timestamp so the file and console lines are identical.
        line = datetime.now().strftime(self._TIME_FORMAT) + text
        with open(self.filename, "a" if append else "w", encoding="utf-8") as f:
            print(line, file=f)
        if sync_with_screen:
            print(line)

    def log_figure(self, figure_name, figure: float):
        """
        Append ``figure`` to the curve ``figure_name`` and save all curves to figure_file.
        """
        self.figures.setdefault(figure_name, []).append(figure)
        # The whole dict is re-saved on every call, so the file always holds every curve.
        torch.save(self.figures, self.figure_file)

    def save_model(self, obj, attribute: str = ""):
        """
        Save ``obj`` with torch.save to ExperimentLogs/<exp_name>/check_point_<attribute>.
        """
        torch.save(obj, os.path.join("ExperimentLogs", self.exp_name, f"check_point_{attribute}"))
145 |
def torch2np(tensor):
    """
    Convert a CxHxW tensor into an HxWxC numpy array, e.g. for plotting.

    The tensor is detached first, so tensors that require grad can be converted too,
    and it is moved to the CPU before ``.numpy()`` is called.

    :param tensor: 3D tensor in CxHxW layout.
    :return: numpy array in HxWxC layout.
    """
    return tensor.detach().permute(1, 2, 0).cpu().numpy()
148 |
def min_max_normalization(tensor):
    """
    Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    A constant tensor has zero range; it is returned as all zeros instead of the NaNs
    that 0/0 would produce.
    """
    lo = torch.min(tensor)
    span = torch.max(tensor) - lo
    if span == 0:
        # tensor - lo is exactly 0 everywhere here; dividing by 1 keeps the usual result dtype.
        span = torch.ones_like(span)
    return (tensor - lo) / span
151 |
def mean_normalization(tensor):
    """
    Center ``tensor`` on its mean and divide by its value range (max - min).

    A constant tensor has zero range; it is returned as all zeros (every element equals
    the mean) instead of the NaNs that 0/0 would produce.
    """
    span = torch.max(tensor) - torch.min(tensor)
    if span == 0:
        return torch.zeros_like(tensor)
    return (tensor - torch.mean(tensor)) / span
--------------------------------------------------------------------------------