├── results
│   └── .gitkeep
├── samples
│   └── .gitkeep
├── assets
│   └── result.png
├── requirements.txt
├── scripts
│   ├── run.py
│   └── prepare_dataset.py
├── data
│   └── README.md
├── .gitignore
├── config.py
├── test.py
├── model.py
├── README.md
├── utils.py
├── dataset.py
├── LICENSE
├── train.py
├── imgproc.py
└── image_quality_assessment.py
/results/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/samples/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lornatang/RCAN-PyTorch/HEAD/assets/result.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python>=4.6.0.66
2 | numpy>=1.23.5
3 | tqdm>=4.63.1
4 | torch>=1.13.0+cu117
5 | natsort>=8.1.0
6 | typing>=3.7.4.3
7 | scipy>=1.9.3
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # Prepare dataset
4 | os.system("python ./prepare_dataset.py --images_dir ../data/DIV2K/original/train --output_dir ../data/DIV2K/RCAN/train --image_size 450 --step 225 --num_workers 10")
5 | os.system("python ./prepare_dataset.py --images_dir ../data/DIV2K/original/valid --output_dir ../data/DIV2K/RCAN/valid --image_size 450 --step 225 --num_workers 10")
6 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | ## Download datasets
4 |
5 | ### Download train dataset
6 |
7 | #### DIV2K
8 |
9 | - Image format
10 | - [Baidu Drive](https://pan.baidu.com/s/1EXXbhxxRDtqPosT2WL8NkA) access: `llot`
11 |
12 | ### Download valid dataset
13 |
14 | #### Set5
15 |
16 | - Image format
17 | - [Google Drive](https://drive.google.com/file/d/1GtQuoEN78q3AIP8vkh-17X90thYp_FfU/view?usp=sharing)
18 | - [Baidu Drive](https://pan.baidu.com/s/1dlPcpwRPUBOnxlfW5--S5g) access: `llot`
19 |
20 | #### Set14
21 |
22 | - Image format
23 | - [Google Drive](https://drive.google.com/file/d/1CzwwAtLSW9sog3acXj8s7Hg3S7kr2HiZ/view?usp=sharing)
24 | - [Baidu Drive](https://pan.baidu.com/s/1KBS38UAjM7bJ_e6a54eHaA) access: `llot`
25 |
26 | #### BSD100
27 |
28 | - Image format
29 | - [Google Drive](https://drive.google.com/file/d/1xkjWJGZgwWjDZZFN6KWlNMvHXmRORvdG/view?usp=sharing)
30 | - [Baidu Drive](https://pan.baidu.com/s/1EBVulUpsQrDmZfqnm4jOZw) access: `llot`
31 |
32 | #### BSD200
33 |
34 | - Image format
35 | - [Google Drive](https://drive.google.com/file/d/1cdMYTPr77RdOgyAvJPMQqaJHWrD5ma5n/view?usp=sharing)
36 | - [Baidu Drive](https://pan.baidu.com/s/1xahPw4dNNc3XspMMOuw1Bw) access: `llot`
37 |
38 | ## Train dataset structure
39 |
40 | ### Image format
41 |
42 | ```text
43 | - DIV2K
44 | - RCAN
45 | - train
46 | - valid
47 | ```
48 |
49 | ## Test dataset structure
50 |
51 | ### Image format
52 |
53 | ```text
54 | - Set5
55 | - GTmod12
56 | - baby.png
57 | - bird.png
58 | - ...
59 | - LRbicx4
60 | - baby.png
61 | - bird.png
62 | - ...
63 | - Set14
64 | - GTmod12
65 | - baboon.png
66 | - barbara.png
67 | - ...
68 | - LRbicx4
69 | - baboon.png
70 | - barbara.png
71 | - ...
72 | ```
73 |
--------------------------------------------------------------------------------
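
A quick way to confirm this layout before training or testing is to list the directories that `config.py` expects. This is only a sketch; the paths mirror the defaults in `config.py` and may differ on your machine.

```python
# Sanity-check the dataset layout described in data/README.md.
import os

expected_dirs = [
    "./data/DIV2K/RCAN/train",   # training crops produced by scripts/prepare_dataset.py
    "./data/Set5/GTmod12",       # ground-truth test images
    "./data/Set5/LRbicx4",       # bicubic x4 low-resolution test images
]

for dir_path in expected_dirs:
    num_files = len(os.listdir(dir_path)) if os.path.isdir(dir_path) else 0
    status = "OK" if num_files > 0 else "MISSING"
    print(f"{dir_path}: {status} ({num_files} files)")
```
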
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
124 |
125 | # custom
126 | .idea
127 | .vscode
128 |
129 | # macOS configuration file.
130 | .DS_Store
131 |
132 | # Directories created at program run time.
133 | data
134 | results
135 | samples
136 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import random
15 |
16 | import numpy as np
17 | import torch
18 | from torch.backends import cudnn
19 |
20 | # Random seed to maintain reproducible results
21 | random.seed(0)
22 | torch.manual_seed(0)
23 | np.random.seed(0)
24 | # Use GPU for training by default
25 | device = torch.device("cuda", 0)
26 | # Enabling cudnn benchmark can speed up training when the input image size does not change
27 | cudnn.benchmark = True
28 | # When evaluating the performance of the SR model, whether to verify only the Y channel image data
29 | only_test_y_channel = True
30 | # Model architecture name
31 | model_arch_name = "rcan_x4"
32 | # Upscale factor
33 | upscale_factor = 4
34 | # Current configuration mode ("train" or "test")
35 | mode = "train"
36 | # Experiment name, easy to save weights and log files
37 | exp_name = "RCAN_x4-DIV2K"
38 |
39 | if mode == "train":
40 | train_gt_images_dir = f"./data/DIV2K/RCAN/train"
41 |
42 | test_gt_images_dir = f"./data/Set5/GTmod12"
43 | test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}"
44 |
45 | train_gt_image_size = int(upscale_factor * 48)
46 | batch_size = 16
47 | num_workers = 4
48 |
49 |     # Path to the pretrained model weights
50 | pretrained_model_weights_path = f""
51 |
52 |     # Path to the checkpoint used to resume or fine-tune training
53 | resume_model_weights_path = f""
54 |
55 | # Total num epochs
56 | epochs = 1000
57 |
58 | # Loss function weight
59 | loss_weight = [1.0]
60 |
61 | # Optimizer parameter
62 | model_lr = 1e-4
63 | model_betas = (0.9, 0.99)
64 | model_eps = 1e-4 # Keep no nan
65 | model_weight_decay = 0.0
66 |
67 | # EMA parameter
68 | model_ema_decay = 0.999
69 |
70 | # StepLR scheduler parameter
71 | lr_scheduler_step_size = epochs // 5
72 | lr_scheduler_gamma = 0.5
73 |
74 | # How many iterations to print the training result
75 | train_print_frequency = 100
76 | test_print_frequency = 1
77 |
78 | if mode == "test":
79 | test_gt_images_dir = f"./data/Set5/GTmod12"
80 | test_sr_images_dir = f"./results/test/{exp_name}"
81 | test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}"
82 |
83 | model_weights_path = f"./results/pretrained_models/RCAN_x4-DIV2K-2dfffdd2.pth.tar"
84 |
--------------------------------------------------------------------------------
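
The `StepLR` settings above halve the learning rate every `epochs // 5 = 200` epochs, starting from `model_lr = 1e-4`. A minimal sketch of the resulting schedule, using only values defined in `config.py`:

```python
# Learning-rate schedule implied by the StepLR parameters in config.py.
model_lr = 1e-4
epochs = 1000
lr_scheduler_step_size = epochs // 5  # 200
lr_scheduler_gamma = 0.5

for epoch in range(0, epochs, lr_scheduler_step_size):
    lr = model_lr * lr_scheduler_gamma ** (epoch // lr_scheduler_step_size)
    print(f"epoch {epoch:4d}: lr = {lr:.2e}")
# epoch    0: lr = 1.00e-04
# epoch  200: lr = 5.00e-05
# epoch  400: lr = 2.50e-05
# epoch  600: lr = 1.25e-05
# epoch  800: lr = 6.25e-06
```
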
/scripts/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import argparse
15 | import multiprocessing
16 | import os
17 | import shutil
18 |
19 | import cv2
20 | import numpy as np
21 | from tqdm import tqdm
22 |
23 |
24 | def main(args) -> None:
25 | if os.path.exists(args.output_dir):
26 | shutil.rmtree(args.output_dir)
27 | os.makedirs(args.output_dir)
28 |
29 | # Get all image paths
30 | image_file_names = os.listdir(args.images_dir)
31 |
32 |     # Split images using a multiprocessing pool
33 | progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Prepare split image")
34 | workers_pool = multiprocessing.Pool(args.num_workers)
35 | for image_file_name in image_file_names:
36 | workers_pool.apply_async(worker, args=(image_file_name, args), callback=lambda arg: progress_bar.update(1))
37 | workers_pool.close()
38 | workers_pool.join()
39 | progress_bar.close()
40 |
41 |
42 | def worker(image_file_name, args) -> None:
43 | image = cv2.imread(f"{args.images_dir}/{image_file_name}", cv2.IMREAD_UNCHANGED)
44 |
45 | image_height, image_width = image.shape[0:2]
46 |
47 | index = 1
48 | if image_height >= args.image_size and image_width >= args.image_size:
49 | for pos_y in range(0, image_height - args.image_size + 1, args.step):
50 | for pos_x in range(0, image_width - args.image_size + 1, args.step):
51 | # Crop
52 | crop_image = image[pos_y: pos_y + args.image_size, pos_x:pos_x + args.image_size, ...]
53 | crop_image = np.ascontiguousarray(crop_image)
54 | # Save image
55 | cv2.imwrite(f"{args.output_dir}/{image_file_name.split('.')[-2]}_{index:04d}.{image_file_name.split('.')[-1]}", crop_image)
56 |
57 | index += 1
58 |
59 |
60 | if __name__ == "__main__":
61 |     parser = argparse.ArgumentParser(description="Prepare dataset script.")
62 |     parser.add_argument("--images_dir", type=str, help="Path to input image directory.")
63 |     parser.add_argument("--output_dir", type=str, help="Path to output image directory.")
64 |     parser.add_argument("--image_size", type=int, help="Size of the cropped image patches.")
65 |     parser.add_argument("--step", type=int, help="Sliding window step size used when cropping.")
66 |     parser.add_argument("--num_workers", type=int, help="How many processes to open at the same time.")
67 | args = parser.parse_args()
68 |
69 | main(args)
70 |
--------------------------------------------------------------------------------
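
The worker above slides a window of `--image_size` pixels with a stride of `--step` pixels over each image. A small sketch of the crop count this produces with the values used by `scripts/run.py`; the 2040x1356 resolution is only an illustrative DIV2K-like size.

```python
# Count the crops produced for one image by the sliding window in prepare_dataset.py.
image_size, step = 450, 225             # values passed by scripts/run.py
image_width, image_height = 2040, 1356  # illustrative DIV2K-like image size

rows = len(range(0, image_height - image_size + 1, step))  # vertical window positions
cols = len(range(0, image_width - image_size + 1, step))   # horizontal window positions
print(rows * cols)  # 5 * 8 = 40 crops for this example
```
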
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import time
15 |
16 | import torch
17 | from torch import nn
18 | from torch.utils.data import DataLoader
19 |
20 | import config
21 | import model
22 | from dataset import TestImageDataset, CUDAPrefetcher
23 | from utils import build_iqa_model, load_state_dict, make_directory, AverageMeter, ProgressMeter
24 |
25 |
26 | def load_dataset(test_gt_images_dir: str, test_lr_images_dir: str, device: torch.device) -> CUDAPrefetcher:
27 | test_datasets = TestImageDataset(test_gt_images_dir, test_lr_images_dir)
28 | test_dataloader = DataLoader(test_datasets,
29 | batch_size=1,
30 | shuffle=False,
31 | num_workers=1,
32 | pin_memory=True,
33 | drop_last=False,
34 | persistent_workers=False)
35 | test_data_prefetcher = CUDAPrefetcher(test_dataloader, device)
36 |
37 | return test_data_prefetcher
38 |
39 |
40 | def build_model(model_arch_name: str, device: torch.device) -> nn.Module:
41 | # Build model
42 | sr_model = model.__dict__[model_arch_name]()
43 | sr_model = sr_model.to(device=device)
44 | # Set the model to evaluation mode
45 | sr_model.eval()
46 |
47 | return sr_model
48 |
49 |
50 | def test(
51 | sr_model: nn.Module,
52 | test_data_prefetcher: CUDAPrefetcher,
53 | psnr_model: nn.Module,
54 | ssim_model: nn.Module,
55 | device: torch.device = torch.device("cpu"),
56 | print_frequency: int = 1,
57 | ) -> [float, float]:
58 | # The information printed by the progress bar
59 | batch_time = AverageMeter("Time", ":6.3f")
60 | psnres = AverageMeter("PSNR", ":4.2f")
61 | ssimes = AverageMeter("SSIM", ":4.4f")
62 | progress = ProgressMeter(len(test_data_prefetcher), [batch_time, psnres, ssimes], prefix=f"Test: ")
63 |
64 | # Set the model as validation model
65 | sr_model.eval()
66 |
67 | # Initialize data batches
68 | batch_index = 0
69 |
70 | # Set the data set iterator pointer to 0 and load the first batch of data
71 | test_data_prefetcher.reset()
72 | batch_data = test_data_prefetcher.next()
73 |
74 | # Record the start time of verifying a batch
75 | end = time.time()
76 |
77 | while batch_data is not None:
78 | # Load batches of data
79 | gt = batch_data["gt"].to(device=device, non_blocking=True)
80 | lr = batch_data["lr"].to(device=device, non_blocking=True)
81 |
82 | # inference
83 | with torch.no_grad():
84 | sr = sr_model(lr)
85 |
86 | # Calculate the image IQA
87 | psnr = psnr_model(sr, gt)
88 | ssim = ssim_model(sr, gt)
89 | psnres.update(psnr.item(), lr.size(0))
90 | ssimes.update(ssim.item(), lr.size(0))
91 |
92 | # Record the total time to verify a batch
93 | batch_time.update(time.time() - end)
94 | end = time.time()
95 |
96 | # Output a verification log information
97 | if batch_index % print_frequency == 0:
98 | progress.display(batch_index + 1)
99 |
100 | # Preload the next batch of data
101 | batch_data = test_data_prefetcher.next()
102 |
103 | # Add 1 to the number of data batches
104 | batch_index += 1
105 |
106 | # Print the performance index of the model at the current epoch
107 | progress.display_summary()
108 |
109 | return psnres.avg, ssimes.avg
110 |
111 |
112 | def main() -> None:
113 | test_data_prefetcher = load_dataset(config.test_gt_images_dir, config.test_lr_images_dir, config.device)
114 | sr_model = build_model(config.model_arch_name, config.device)
115 | psnr_model, ssim_model = build_iqa_model(config.upscale_factor, config.only_test_y_channel, config.device)
116 |
117 | # Load the super-resolution bsrgan_model weights
118 | sr_model = load_state_dict(sr_model, config.model_weights_path)
119 |
120 | # Create a folder of super-resolution experiment results
121 | make_directory(config.test_sr_images_dir)
122 |
123 | psnr, ssim = test(sr_model,
124 | test_data_prefetcher,
125 | psnr_model,
126 | ssim_model,
127 | config.device)
128 |
129 | print(f"PSNR: {psnr:.2f} dB"
130 | f"SSIM: {ssim:.4f} [u]")
131 |
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
--------------------------------------------------------------------------------
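
For single-image inference outside the Set5 evaluation loop, the same pieces can be reassembled as below. This is a minimal sketch: it assumes `mode = "test"` in `config.py` (so `config.model_weights_path` is defined), a CUDA device, and a hypothetical input file `./lr.png`.

```python
# Minimal single-image super-resolution sketch built from the components used in test.py.
import cv2
import numpy as np
import torch

import config
import model
from utils import load_state_dict

# Build the model and load the trained weights
sr_model = model.__dict__[config.model_arch_name]().to(config.device)
sr_model = load_state_dict(sr_model, config.model_weights_path)
sr_model.eval()

# Read a BGR image, scale to [0, 1], and convert to an RGB NCHW tensor
lr_image = cv2.imread("./lr.png").astype(np.float32) / 255.
lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
lr_tensor = torch.from_numpy(np.ascontiguousarray(lr_image)).permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    sr_tensor = sr_model(lr_tensor.to(config.device)).clamp_(0.0, 1.0)

# Convert back to a BGR uint8 image and save
sr_image = sr_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
cv2.imwrite("./sr.png", cv2.cvtColor((sr_image * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR))
```
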
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import math
15 |
16 | import torch
17 | from torch import nn, Tensor
18 |
19 | __all__ = [
20 | "RCAN",
21 | "rcan_x2", "rcan_x3", "rcan_x4", "rcan_x8",
22 | ]
23 |
24 |
25 | class RCAN(nn.Module):
26 | def __init__(
27 | self,
28 | in_channels: int,
29 | out_channels: int,
30 | channels: int,
31 | reduction: int,
32 | num_rcab: int,
33 | num_rg: int,
34 | upscale_factor: int,
35 | rgb_mean: tuple = None,
36 | ) -> None:
37 | super(RCAN, self).__init__()
38 | if rgb_mean is None:
39 | rgb_mean = [0.4488, 0.4371, 0.4040]
40 |
41 | # The first layer of convolutional layer
42 | self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1))
43 |
44 | # Feature extraction backbone
45 | trunk = []
46 | for _ in range(num_rg):
47 | trunk.append(_ResidualGroup(channels, reduction, num_rcab))
48 | self.trunk = nn.Sequential(*trunk)
49 |
50 | # After the feature extraction network, reconnect a layer of convolutional blocks
51 | self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1))
52 |
53 | # Upsampling convolutional layer.
54 | upsampling = []
55 | if upscale_factor == 2 or upscale_factor == 4 or upscale_factor == 8:
56 | for _ in range(int(math.log(upscale_factor, 2))):
57 | upsampling.append(_UpsampleBlock(channels, 2))
58 | elif upscale_factor == 3:
59 | upsampling.append(_UpsampleBlock(channels, 3))
60 | self.upsampling = nn.Sequential(*upsampling)
61 |
62 | # Output layer.
63 | self.conv3 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))
64 |
65 | self.register_buffer("mean", Tensor(rgb_mean).view(1, 3, 1, 1))
66 |
67 | def forward(self, x: Tensor) -> Tensor:
68 | x = x.sub_(self.mean).mul_(1.)
69 |
70 | conv1 = self.conv1(x)
71 | x = self.trunk(conv1)
72 | x = self.conv2(x)
73 | x = torch.add(x, conv1)
74 | x = self.upsampling(x)
75 | x = self.conv3(x)
76 |
77 | x = x.div_(1.).add_(self.mean)
78 |
79 | return x
80 |
81 |
82 | class _ChannelAttentionLayer(nn.Module):
83 | def __init__(self, channel: int, reduction: int):
84 | super(_ChannelAttentionLayer, self).__init__()
85 | self.channel_attention_layer = nn.Sequential(
86 | nn.AdaptiveAvgPool2d(1),
87 | nn.Conv2d(channel, channel // reduction, (1, 1), (1, 1), (0, 0)),
88 | nn.ReLU(True),
89 | nn.Conv2d(channel // reduction, channel, (1, 1), (1, 1), (0, 0)),
90 | nn.Sigmoid(),
91 | )
92 |
93 | def forward(self, x: Tensor) -> Tensor:
94 | out = self.channel_attention_layer(x)
95 |
96 | out = torch.mul(out, x)
97 |
98 | return out
99 |
100 |
101 | class _ResidualChannelAttentionBlock(nn.Module):
102 | def __init__(self, channel: int, reduction: int):
103 | super(_ResidualChannelAttentionBlock, self).__init__()
104 | self.residual_channel_attention_block = nn.Sequential(
105 | nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1)),
106 | nn.ReLU(True),
107 | nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1)),
108 | _ChannelAttentionLayer(channel, reduction),
109 | )
110 |
111 | def forward(self, x: Tensor) -> Tensor:
112 | identity = x
113 |
114 | out = self.residual_channel_attention_block(x)
115 |
116 | out = torch.add(out, identity)
117 |
118 | return out
119 |
120 |
121 | class _ResidualGroup(nn.Module):
122 | def __init__(self, channel: int, reduction: int, num_rcab: int):
123 | super(_ResidualGroup, self).__init__()
124 | residual_group = []
125 |
126 | for _ in range(num_rcab):
127 | residual_group.append(_ResidualChannelAttentionBlock(channel, reduction))
128 | residual_group.append(nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1)))
129 |
130 | self.residual_group = nn.Sequential(*residual_group)
131 |
132 | def forward(self, x: Tensor) -> Tensor:
133 | identity = x
134 |
135 | out = self.residual_group(x)
136 |
137 | out = torch.add(out, identity)
138 |
139 | return out
140 |
141 |
142 | class _UpsampleBlock(nn.Module):
143 | def __init__(self, channels: int, upscale_factor: int) -> None:
144 | super(_UpsampleBlock, self).__init__()
145 | self.upsample_block = nn.Sequential(
146 | nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
147 | nn.PixelShuffle(upscale_factor),
148 | )
149 |
150 | def forward(self, x: Tensor) -> Tensor:
151 | x = self.upsample_block(x)
152 |
153 | return x
154 |
155 |
156 | def rcan_x2(**kwargs) -> RCAN:
157 | model = RCAN(3, 3, 64, 16, 20, 10, 2, **kwargs)
158 |
159 | return model
160 |
161 |
162 | def rcan_x3(**kwargs) -> RCAN:
163 | model = RCAN(3, 3, 64, 16, 20, 10, 3, **kwargs)
164 |
165 | return model
166 |
167 |
168 | def rcan_x4(**kwargs) -> RCAN:
169 | model = RCAN(3, 3, 64, 16, 20, 10, 4, **kwargs)
170 |
171 | return model
172 |
173 |
174 | def rcan_x8(**kwargs) -> RCAN:
175 | model = RCAN(3, 3, 64, 16, 20, 10, 8, **kwargs)
176 |
177 | return model
178 |
--------------------------------------------------------------------------------
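
A quick shape check of the x4 factory above, run on a randomly initialized model on the CPU (a sketch; the 48x48 input matches `train_gt_image_size / upscale_factor` from `config.py`):

```python
# Verify that rcan_x4 upscales a 48x48 LR patch to 192x192.
import torch
from model import rcan_x4

sr_model = rcan_x4()
lr = torch.rand(1, 3, 48, 48)
with torch.no_grad():
    sr = sr_model(lr)
print(sr.shape)  # torch.Size([1, 3, 192, 192])
```
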
/README.md:
--------------------------------------------------------------------------------
1 | # RCAN-PyTorch
2 |
3 | ### Overview
4 |
5 | This repository contains an op-for-op PyTorch reimplementation of [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758).
6 |
7 | ### Table of contents
8 |
9 | - [RCAN-PyTorch](#rcan-pytorch)
10 | - [Overview](#overview)
11 | - [Table of contents](#table-of-contents)
12 | - [About Image Super-Resolution Using Very Deep Residual Channel Attention Networks](#about-image-super-resolution-using-very-deep-residual-channel-attention-networks)
13 | - [Download weights](#download-weights)
14 | - [Download datasets](#download-datasets)
15 | - [Test](#test)
16 | - [Train](#train)
17 | - [Result](#result)
18 | - [Credit](#credit)
19 | - [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](#image-super-resolution-using-very-deep-residual-channel-attention-networks)
20 |
21 | ## About Image Super-Resolution Using Very Deep Residual Channel Attention Networks
22 |
23 | If you're new to RCAN, here's an abstract straight from the paper:
24 |
25 | Convolutional neural network (CNN) depth is of crucial importance for image super-resolution (SR). However, we observe that deeper networks for image
26 | SR are more difficult to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across
27 | channels, hence hindering the representational ability of CNNs. To solve these problems, we propose the very deep residual channel attention
28 | networks (RCAN). Specifically, we propose a residual in residual (RIR) structure to form very deep network, which consists of several residual groups
29 | with long skip connections. Each residual group contains some residual blocks with short skip connections. Meanwhile, RIR allows abundant
30 | low-frequency information to be bypassed through multiple skip connections, making the main network focus on learning high-frequency information.
31 | Furthermore, we propose a channel attention mechanism to adaptively rescale channel-wise features by considering interdependencies among channels.
32 | Extensive experiments show that our RCAN achieves better accuracy and visual improvements against state-of-the-art methods.
33 |
34 | ## Download weights
35 |
36 | - [Google Drive](https://drive.google.com/drive/folders/17ju2HN7Y6pyPK2CC_AqnAfTOe9_3hCQ8?usp=sharing)
37 | - [Baidu Drive](https://pan.baidu.com/s/1yNs4rqIb004-NKEdKBJtYg?pwd=llot)
38 |
39 | ## Download datasets
40 |
41 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc.
42 |
43 | - [Google Drive](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing)
44 | - [Baidu Drive](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot)
45 |
46 | ## Test
47 |
48 | Modify the contents of the `config.py` file as follows.
49 |
50 | - line 33: `upscale_factor` change to the upscale factor you need.
51 | - line 35: `mode` change to `test`.
52 | - line 83: `model_weights_path` change to the model weight path obtained after training.
53 |
54 | ## Train
55 |
56 | Modify the contents of the `config.py` file as follows.
57 |
58 | - line 33: `upscale_factor` change to the upscale factor you need.
59 | - line 35: `mode` change to `train`.
60 |
61 | If you want to load weights that you've trained before, modify the contents of the file as follows.
62 |
63 | ### Resume model
64 |
65 | - line 50: `pretrained_model_weights_path` change to the pretrained RCAN model weights path.
66 | - line 53: `resume_model_weights_path` change to the RCAN checkpoint path from which training resumes.
67 |
68 | ## Result
69 |
70 | Source of original paper results: https://arxiv.org/pdf/1807.02758.pdf
71 |
72 | In the following table, the value in `()` indicates the result of the project, and `-` indicates no test.
73 |
74 | | Dataset | Scale | PSNR |
75 | |:-------:|:-----:|:----------------:|
76 | | Set5 | 2 | 38.27(**38.09**) |
77 | | Set5 | 3 | 34.74(**34.56**) |
78 | | Set5 | 4 | 32.63(**32.41**) |
79 | | Set5 | 8 | 27.31(**26.97**) |
80 |
81 | Low Resolution / Super Resolution / High Resolution
82 |
83 |
84 | ### Credit
85 |
86 | #### Image Super-Resolution Using Very Deep Residual Channel Attention Networks
87 |
88 | _Yulun Zhang, Kunpeng Li, Kai Li, Lichen Wang, Bineng Zhong, Yun Fu_
89 |
90 | **Abstract**
91 | Convolutional neural network (CNN) depth is of crucial importance for image super-resolution (SR). However, we observe that deeper networks for image
92 | SR are more difficult to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across
93 | channels, hence hindering the representational ability of CNNs. To solve these problems, we propose the very deep residual channel attention
94 | networks (RCAN). Specifically, we propose a residual in residual (RIR) structure to form very deep network, which consists of several residual groups
95 | with long skip connections. Each residual group contains some residual blocks with short skip connections. Meanwhile, RIR allows abundant
96 | low-frequency information to be bypassed through multiple skip connections, making the main network focus on learning high-frequency information.
97 | Furthermore, we propose a channel attention mechanism to adaptively rescale channel-wise features by considering interdependencies among channels.
98 | Extensive experiments show that our RCAN achieves better accuracy and visual improvements against state-of-the-art methods.
99 |
100 | [[Code]](https://github.com/yulunzhang/RCAN) [[Paper]](https://arxiv.org/pdf/1807.02758)
101 |
102 | ```
103 | @article{DBLP:journals/corr/abs-1807-02758,
104 | author = {Yulun Zhang and
105 | Kunpeng Li and
106 | Kai Li and
107 | Lichen Wang and
108 | Bineng Zhong and
109 | Yun Fu},
110 | title = {Image Super-Resolution Using Very Deep Residual Channel Attention
111 | Networks},
112 | journal = {CoRR},
113 | volume = {abs/1807.02758},
114 | year = {2018},
115 | url = {http://arxiv.org/abs/1807.02758},
116 | eprinttype = {arXiv},
117 | eprint = {1807.02758},
118 | timestamp = {Tue, 20 Nov 2018 12:24:39 +0100},
119 | biburl = {https://dblp.org/rec/journals/corr/abs-1807-02758.bib},
120 | bibsource = {dblp computer science bibliography, https://dblp.org}
121 | }
122 | ```
123 |
--------------------------------------------------------------------------------
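
Putting the Train and Test sections together, the full workflow can be scripted in the same style as `scripts/run.py`. This is a sketch run from the repository root and assumes `config.py` has been edited as described above for each step.

```python
import os

# 1. Tile the DIV2K training images into crops (paths as in scripts/run.py, relative to the repository root)
os.system("python ./scripts/prepare_dataset.py --images_dir ./data/DIV2K/original/train --output_dir ./data/DIV2K/RCAN/train --image_size 450 --step 225 --num_workers 10")

# 2. Train (mode = "train" in config.py)
os.system("python ./train.py")

# 3. Evaluate on Set5 (mode = "test" in config.py)
os.system("python ./test.py")
```
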
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import os
15 | import shutil
16 | from enum import Enum
17 | from typing import Any
18 |
19 | import torch
20 | from torch import nn
21 | from torch.nn import Module
22 | from torch.optim import Optimizer
23 |
24 | from image_quality_assessment import PSNR, SSIM
25 |
26 | __all__ = [
27 | "build_iqa_model", "load_state_dict", "make_directory", "save_checkpoint",
28 | "Summary", "AverageMeter", "ProgressMeter"
29 | ]
30 |
31 |
32 | def build_iqa_model(upscale_factor: int, only_test_y_channel: bool, device: torch.device) -> tuple[PSNR, SSIM]:
33 | psnr_model = PSNR(upscale_factor, only_test_y_channel)
34 | ssim_model = SSIM(upscale_factor, only_test_y_channel)
35 | psnr_model = psnr_model.to(device=device)
36 | ssim_model = ssim_model.to(device=device)
37 |
38 | return psnr_model, ssim_model
39 |
40 |
41 | def load_state_dict(
42 | model: nn.Module,
43 | model_weights_path: str,
44 | ema_model: nn.Module = None,
45 | optimizer: torch.optim.Optimizer = None,
46 | scheduler: torch.optim.lr_scheduler = None,
47 | load_mode: str = None,
48 | ) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module:
49 | # Load model weights
50 | checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
51 |
52 | if load_mode == "resume":
53 | # Restore the parameters in the training node to this point
54 | start_epoch = checkpoint["epoch"]
55 | best_psnr = checkpoint["best_psnr"]
56 | best_ssim = checkpoint["best_ssim"]
57 | # Load model state dict. Extract the fitted model weights
58 | model_state_dict = model.state_dict()
59 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
60 | # Overwrite the model weights to the current model (base model)
61 | model_state_dict.update(state_dict)
62 | model.load_state_dict(model_state_dict)
63 | # Load the optimizer model
64 | optimizer.load_state_dict(checkpoint["optimizer"])
65 |
66 | if scheduler is not None:
67 | # Load the scheduler model
68 | scheduler.load_state_dict(checkpoint["scheduler"])
69 |
70 | if ema_model is not None:
71 | # Load ema model state dict. Extract the fitted model weights
72 | ema_model_state_dict = ema_model.state_dict()
73 | ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
74 | # Overwrite the model weights to the current model (ema model)
75 | ema_model_state_dict.update(ema_state_dict)
76 | ema_model.load_state_dict(ema_model_state_dict)
77 |
78 | return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
79 | else:
80 | # Load model state dict. Extract the fitted model weights
81 | model_state_dict = model.state_dict()
82 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
83 | k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
84 | # Overwrite the model weights to the current model
85 | model_state_dict.update(state_dict)
86 | model.load_state_dict(model_state_dict)
87 |
88 | return model
89 |
90 |
91 | def make_directory(dir_path: str) -> None:
92 | if not os.path.exists(dir_path):
93 | os.makedirs(dir_path)
94 |
95 |
96 | def save_checkpoint(
97 | state_dict: dict,
98 | file_name: str,
99 | samples_dir: str,
100 | results_dir: str,
101 | best_file_name: str,
102 | last_file_name: str,
103 | is_best: bool = False,
104 | is_last: bool = False,
105 | ) -> None:
106 | checkpoint_path = os.path.join(samples_dir, file_name)
107 | torch.save(state_dict, checkpoint_path)
108 |
109 | if is_best:
110 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
111 | if is_last:
112 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
113 |
114 |
115 | class Summary(Enum):
116 | NONE = 0
117 | AVERAGE = 1
118 | SUM = 2
119 | COUNT = 3
120 |
121 |
122 | class AverageMeter(object):
123 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
124 | self.name = name
125 | self.fmt = fmt
126 | self.summary_type = summary_type
127 | self.reset()
128 |
129 | def reset(self):
130 | self.val = 0
131 | self.avg = 0
132 | self.sum = 0
133 | self.count = 0
134 |
135 | def update(self, val, n=1):
136 | self.val = val
137 | self.sum += val * n
138 | self.count += n
139 | self.avg = self.sum / self.count
140 |
141 | def __str__(self):
142 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
143 | return fmtstr.format(**self.__dict__)
144 |
145 | def summary(self):
146 | if self.summary_type is Summary.NONE:
147 | fmtstr = ""
148 | elif self.summary_type is Summary.AVERAGE:
149 | fmtstr = "{name} {avg:.2f}"
150 | elif self.summary_type is Summary.SUM:
151 | fmtstr = "{name} {sum:.2f}"
152 | elif self.summary_type is Summary.COUNT:
153 | fmtstr = "{name} {count:.2f}"
154 | else:
155 | raise ValueError(f"Invalid summary type {self.summary_type}")
156 |
157 | return fmtstr.format(**self.__dict__)
158 |
159 |
160 | class ProgressMeter(object):
161 | def __init__(self, num_batches, meters, prefix=""):
162 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
163 | self.meters = meters
164 | self.prefix = prefix
165 |
166 | def display(self, batch):
167 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
168 | entries += [str(meter) for meter in self.meters]
169 | print("\t".join(entries))
170 |
171 | def display_summary(self):
172 | entries = [" *"]
173 | entries += [meter.summary() for meter in self.meters]
174 | print(" ".join(entries))
175 |
176 | def _get_batch_fmtstr(self, num_batches):
177 | num_digits = len(str(num_batches // 1))
178 | fmt = "{:" + str(num_digits) + "d}"
179 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
180 |
--------------------------------------------------------------------------------
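
A short usage sketch for `AverageMeter` and `ProgressMeter`, mirroring how `test.py` reports PSNR; the three values are illustrative only.

```python
from utils import AverageMeter, ProgressMeter, Summary

psnr_meter = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE)
progress = ProgressMeter(num_batches=3, meters=[psnr_meter], prefix="Demo: ")

for batch_index, value in enumerate([30.1, 31.4, 29.8]):  # illustrative PSNR values
    psnr_meter.update(value, n=1)
    progress.display(batch_index + 1)

progress.display_summary()  # prints something like " * PSNR 30.43"
```
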
/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import os
15 | import queue
16 | import threading
17 |
18 | import cv2
19 | import numpy as np
20 | import torch
21 | from torch import Tensor
22 | from torch.utils.data import Dataset, DataLoader
23 |
24 | import imgproc
25 |
26 | __all__ = [
27 | "TrainImageDataset", "TestImageDataset",
28 | "PrefetchGenerator", "PrefetchDataLoader", "CPUPrefetcher", "CUDAPrefetcher",
29 | ]
30 |
31 |
32 | class TrainImageDataset(Dataset):
33 | """Define train dataset loading methods.
34 |
35 | Args:
36 | train_gt_image_dir (str): Train ground-truth dataset address.
37 | train_gt_image_size (int): Train ground-truth resolution image size.
38 |         upscale_factor (int): Image upscale factor.
39 |
40 | """
41 |
42 | def __init__(
43 | self,
44 | train_gt_image_dir: str,
45 | train_gt_image_size: int,
46 | upscale_factor: int,
47 | ) -> None:
48 | super(TrainImageDataset, self).__init__()
49 | self.image_file_names = [os.path.join(train_gt_image_dir, image_file_name) for image_file_name in
50 | os.listdir(train_gt_image_dir)]
51 | self.train_gt_image_size = train_gt_image_size
52 | self.upscale_factor = upscale_factor
53 |
54 | def __getitem__(self, batch_index: int) -> [dict[str, Tensor], dict[str, Tensor]]:
55 | # Read a batch of image data
56 | gt_image = cv2.imread(self.image_file_names[batch_index]).astype(np.float32) / 255.
57 |
58 | # Image processing operations
59 | gt_crop_image = imgproc.random_crop(gt_image, self.train_gt_image_size)
60 | gt_crop_image = imgproc.random_rotate(gt_crop_image, [90, 180, 270])
61 | gt_crop_image = imgproc.random_horizontally_flip(gt_crop_image, 0.5)
62 | gt_crop_image = imgproc.random_vertically_flip(gt_crop_image, 0.5)
63 |
64 | lr_crop_image = imgproc.image_resize(gt_crop_image, 1 / self.upscale_factor)
65 |
66 | # BGR convert RGB
67 | gt_crop_image = cv2.cvtColor(gt_crop_image, cv2.COLOR_BGR2RGB)
68 | lr_crop_image = cv2.cvtColor(lr_crop_image, cv2.COLOR_BGR2RGB)
69 |
70 | # Convert image data into Tensor stream format (PyTorch).
71 | # Note: The range of input and output is between [0, 1]
72 | gt_crop_tensor = imgproc.image_to_tensor(gt_crop_image, False, False)
73 | lr_crop_tensor = imgproc.image_to_tensor(lr_crop_image, False, False)
74 |
75 | return {"gt": gt_crop_tensor, "lr": lr_crop_tensor}
76 |
77 | def __len__(self) -> int:
78 | return len(self.image_file_names)
79 |
80 |
81 | class TestImageDataset(Dataset):
82 | """Define Test dataset loading methods.
83 |
84 | Args:
85 |         test_gt_images_dir (str): Ground-truth image directory for testing.
86 |         test_lr_images_dir (str): Low-resolution image directory for testing.
87 | """
88 |
89 | def __init__(self, test_gt_images_dir: str, test_lr_images_dir: str) -> None:
90 | super(TestImageDataset, self).__init__()
91 | # Get all image file names in folder
92 | self.gt_image_file_names = [os.path.join(test_gt_images_dir, x) for x in os.listdir(test_gt_images_dir)]
93 | self.lr_image_file_names = [os.path.join(test_lr_images_dir, x) for x in os.listdir(test_lr_images_dir)]
94 |
95 | def __getitem__(self, batch_index: int) -> [torch.Tensor, torch.Tensor]:
96 | # Read a batch of image data
97 | gt_image = cv2.imread(self.gt_image_file_names[batch_index]).astype(np.float32) / 255.
98 | lr_image = cv2.imread(self.lr_image_file_names[batch_index]).astype(np.float32) / 255.
99 |
100 | # BGR convert RGB
101 | gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)
102 | lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
103 |
104 | # Convert image data into Tensor stream format (PyTorch).
105 | # Note: The range of input and output is between [0, 1]
106 | gt_tensor = imgproc.image_to_tensor(gt_image, False, False)
107 | lr_tensor = imgproc.image_to_tensor(lr_image, False, False)
108 |
109 | return {"gt": gt_tensor, "lr": lr_tensor}
110 |
111 | def __len__(self) -> int:
112 | return len(self.gt_image_file_names)
113 |
114 |
115 | class PrefetchGenerator(threading.Thread):
116 | """A fast data prefetch generator.
117 |
118 | Args:
119 | generator: Data generator.
120 | num_data_prefetch_queue (int): How many early data load queues.
121 | """
122 |
123 | def __init__(self, generator, num_data_prefetch_queue: int) -> None:
124 | threading.Thread.__init__(self)
125 | self.queue = queue.Queue(num_data_prefetch_queue)
126 | self.generator = generator
127 | self.daemon = True
128 | self.start()
129 |
130 | def run(self) -> None:
131 | for item in self.generator:
132 | self.queue.put(item)
133 | self.queue.put(None)
134 |
135 | def __next__(self):
136 | next_item = self.queue.get()
137 | if next_item is None:
138 | raise StopIteration
139 | return next_item
140 |
141 | def __iter__(self):
142 | return self
143 |
144 |
145 | class PrefetchDataLoader(DataLoader):
146 | """A fast data prefetch dataloader.
147 |
148 | Args:
149 | num_data_prefetch_queue (int): How many early data load queues.
150 | kwargs (dict): Other extended parameters.
151 | """
152 |
153 | def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None:
154 | self.num_data_prefetch_queue = num_data_prefetch_queue
155 | super(PrefetchDataLoader, self).__init__(**kwargs)
156 |
157 | def __iter__(self):
158 | return PrefetchGenerator(super().__iter__(), self.num_data_prefetch_queue)
159 |
160 |
161 | class CPUPrefetcher:
162 | """Use the CPU side to accelerate data reading.
163 |
164 | Args:
165 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
166 | """
167 |
168 | def __init__(self, dataloader: DataLoader) -> None:
169 | self.original_dataloader = dataloader
170 | self.data = iter(dataloader)
171 |
172 | def next(self):
173 | try:
174 | return next(self.data)
175 | except StopIteration:
176 | return None
177 |
178 | def reset(self):
179 | self.data = iter(self.original_dataloader)
180 |
181 | def __len__(self) -> int:
182 | return len(self.original_dataloader)
183 |
184 |
185 | class CUDAPrefetcher:
186 | """Use the CUDA side to accelerate data reading.
187 |
188 | Args:
189 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
190 | device (torch.device): Specify running device.
191 | """
192 |
193 | def __init__(self, dataloader: DataLoader, device: torch.device):
194 | self.batch_data = None
195 | self.original_dataloader = dataloader
196 | self.device = device
197 |
198 | self.data = iter(dataloader)
199 | self.stream = torch.cuda.Stream()
200 | self.preload()
201 |
202 | def preload(self):
203 | try:
204 | self.batch_data = next(self.data)
205 | except StopIteration:
206 | self.batch_data = None
207 | return None
208 |
209 | with torch.cuda.stream(self.stream):
210 | for k, v in self.batch_data.items():
211 | if torch.is_tensor(v):
212 | self.batch_data[k] = self.batch_data[k].to(self.device, non_blocking=True)
213 |
214 | def next(self):
215 | torch.cuda.current_stream().wait_stream(self.stream)
216 | batch_data = self.batch_data
217 | self.preload()
218 | return batch_data
219 |
220 | def reset(self):
221 | self.data = iter(self.original_dataloader)
222 | self.preload()
223 |
224 | def __len__(self) -> int:
225 | return len(self.original_dataloader)
226 |
--------------------------------------------------------------------------------
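
A wiring sketch for `TestImageDataset` and `CUDAPrefetcher`, matching `load_dataset()` in `test.py`; it assumes a CUDA device and the Set5 layout from `data/README.md`.

```python
import torch
from torch.utils.data import DataLoader

from dataset import TestImageDataset, CUDAPrefetcher

device = torch.device("cuda", 0)
test_dataset = TestImageDataset("./data/Set5/GTmod12", "./data/Set5/LRbicx4")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
prefetcher = CUDAPrefetcher(test_dataloader, device)

prefetcher.reset()
batch_data = prefetcher.next()
while batch_data is not None:
    gt, lr = batch_data["gt"], batch_data["lr"]  # tensors already moved to the CUDA device
    batch_data = prefetcher.next()
```
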
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ============================================================================
14 | import os
15 | import time
16 |
17 | import torch
18 | from torch import nn, optim
19 | from torch.cuda import amp
20 | from torch.optim import lr_scheduler
21 | from torch.optim.swa_utils import AveragedModel
22 | from torch.utils.data import DataLoader
23 | from torch.utils.tensorboard import SummaryWriter
24 |
25 | import config
26 | import model
27 | from dataset import CUDAPrefetcher, TrainImageDataset, TestImageDataset
28 | from test import test
29 | from utils import build_iqa_model, load_state_dict, make_directory, save_checkpoint, AverageMeter, ProgressMeter
30 |
31 |
32 | def main():
33 | # Initialize the gradient scaler
34 | scaler = amp.GradScaler()
35 |
36 | # Initialize the number of training epochs
37 | start_epoch = 0
38 |
39 |     # Initialize the training evaluation metrics
40 | best_psnr = 0.0
41 | best_ssim = 0.0
42 |
43 | train_data_prefetcher, test_data_prefetcher = load_dataset(config.train_gt_images_dir,
44 | config.train_gt_image_size,
45 | config.test_gt_images_dir,
46 | config.test_lr_images_dir,
47 | config.upscale_factor,
48 | config.batch_size,
49 | config.num_workers,
50 | config.device)
51 | print("Load all datasets successfully.")
52 |
53 | sr_model, ema_sr_model = build_model(config.model_arch_name, config.device)
54 | print(f"Build `{config.model_arch_name}` model successfully.")
55 |
56 | criterion = define_loss(config.device)
57 | print("Define all loss functions successfully.")
58 |
59 | optimizer = define_optimizer(sr_model)
60 | print("Define all optimizer functions successfully.")
61 |
62 | scheduler = define_scheduler(optimizer)
63 | print("Define all optimizer scheduler functions successfully.")
64 |
65 | # Create an IQA evaluation model
66 | psnr_model, ssim_model = build_iqa_model(config.upscale_factor, config.only_test_y_channel, config.device)
67 |
68 | print("Check whether to load pretrained model weights...")
69 | if config.pretrained_model_weights_path:
70 | sr_model = load_state_dict(sr_model, config.pretrained_model_weights_path)
71 | print(f"Loaded `{config.pretrained_model_weights_path}` pretrained model weights successfully.")
72 | else:
73 | print("Pretrained model weights not found.")
74 |
75 | print("Check whether the resume model is restored...")
76 | if config.resume_model_weights_path:
77 | sr_model, ema_sr_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler = load_state_dict(
78 | sr_model,
79 | config.resume_model_weights_path,
80 | ema_sr_model,
81 | optimizer,
82 | scheduler,
83 | "resume")
84 | print("Loaded resume model weights.")
85 | else:
86 | print("Resume training model not found. Start training from scratch.")
87 |
88 |     # Create experiment result folders
89 | samples_dir = os.path.join("samples", config.exp_name)
90 | results_dir = os.path.join("results", config.exp_name)
91 | make_directory(samples_dir)
92 | make_directory(results_dir)
93 |
94 | # Create training process log file
95 | writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name))
96 |
97 | for epoch in range(start_epoch, config.epochs):
98 | train(sr_model,
99 | ema_sr_model,
100 | train_data_prefetcher,
101 | criterion,
102 | optimizer,
103 | epoch,
104 | scaler,
105 | writer,
106 | config.device,
107 | config.train_print_frequency)
108 | psnr, ssim = test(sr_model,
109 | test_data_prefetcher,
110 | psnr_model,
111 | ssim_model,
112 | config.device,
113 | config.test_print_frequency)
114 |
115 | # Write the evaluation results to the tensorboard
116 |         writer.add_scalar("Test/PSNR", psnr, epoch + 1)
117 |         writer.add_scalar("Test/SSIM", ssim, epoch + 1)
118 |
119 | print("\n")
120 |
121 | # Update LR
122 | scheduler.step()
123 |
124 |         # Automatically save the model with the best evaluation metrics
125 | is_best = psnr > best_psnr and ssim > best_ssim
126 | is_last = (epoch + 1) == config.epochs
127 | best_psnr = max(psnr, best_psnr)
128 | best_ssim = max(ssim, best_ssim)
129 | save_checkpoint({"epoch": epoch + 1,
130 | "best_psnr": best_psnr,
131 | "best_ssim": best_ssim,
132 | "state_dict": sr_model.state_dict(),
133 | "ema_state_dict": ema_sr_model.state_dict(),
134 | "optimizer": optimizer.state_dict(),
135 | "scheduler": scheduler.state_dict()},
136 | f"epoch_{epoch + 1}.pth.tar",
137 | samples_dir,
138 | results_dir,
139 | "best.pth.tar",
140 | "last.pth.tar",
141 | is_best,
142 | is_last)
143 |
144 |
145 | def load_dataset(
146 | train_gt_images_dir: str,
147 | train_gt_image_size: int,
148 | test_gt_images_dir: str,
149 | test_lr_images_dir: str,
150 | upscale_factor: int,
151 | batch_size: int,
152 | num_workers: int,
153 | device: torch.device,
154 | ) -> [CUDAPrefetcher, CUDAPrefetcher]:
155 | # Load train, test and valid datasets
156 | train_datasets = TrainImageDataset(train_gt_images_dir, train_gt_image_size, upscale_factor)
157 | test_datasets = TestImageDataset(test_gt_images_dir, test_lr_images_dir)
158 |
159 |     # Generate all dataloaders
160 | train_dataloader = DataLoader(train_datasets,
161 | batch_size=batch_size,
162 | shuffle=True,
163 | num_workers=num_workers,
164 | pin_memory=True,
165 | drop_last=True,
166 | persistent_workers=True)
167 | test_dataloader = DataLoader(test_datasets,
168 | batch_size=1,
169 | shuffle=False,
170 | num_workers=1,
171 | pin_memory=True,
172 | drop_last=False,
173 | persistent_workers=False)
174 |
175 |     # Wrap the dataloaders with CUDA prefetchers to overlap data loading and computation
176 | train_prefetcher = CUDAPrefetcher(train_dataloader, device)
177 | test_prefetcher = CUDAPrefetcher(test_dataloader, device)
178 |
179 | return train_prefetcher, test_prefetcher
180 |
181 |
182 | def build_model(model_arch_name: str, device: torch.device) -> [nn.Module, nn.Module]:
183 | # Build model
184 | sr_model = model.__dict__[model_arch_name]()
185 |     # Create an exponential moving average (EMA) model to stabilize training
186 | ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: (1 - config.model_ema_decay) * averaged_model_parameter + config.model_ema_decay * model_parameter
187 | ema_sr_model = AveragedModel(sr_model, avg_fn=ema_avg_fn)
188 |
189 | sr_model = sr_model.to(device=device)
190 | ema_sr_model = ema_sr_model.to(device=device)
191 |
192 | return sr_model, ema_sr_model
193 |
194 |
195 | def define_loss(device) -> nn.L1Loss:
196 | criterion = nn.L1Loss().to(device=device)
197 |
198 | return criterion
199 |
200 |
201 | def define_optimizer(sr_model: nn.Module) -> optim.Adam:
202 | optimizer = optim.Adam(sr_model.parameters(),
203 | config.model_lr,
204 | config.model_betas,
205 | config.model_eps)
206 |
207 | return optimizer
208 |
209 |
210 | def define_scheduler(optimizer) -> lr_scheduler.StepLR:
211 | scheduler = lr_scheduler.StepLR(optimizer,
212 | config.lr_scheduler_step_size,
213 | config.lr_scheduler_gamma)
214 |
215 | return scheduler
216 |
217 |
218 | def train(
219 | sr_model: nn.Module,
220 | ema_sr_model: nn.Module,
221 | train_data_prefetcher: CUDAPrefetcher,
222 | criterion: nn.L1Loss,
223 | optimizer: optim.Adam,
224 | epoch: int,
225 | scaler: amp.GradScaler,
226 | writer: SummaryWriter,
227 | device: torch.device = torch.device("cpu"),
228 | print_frequency: int = 1,
229 | ) -> None:
230 |     # Calculate how many batches there are per epoch
231 |     batches = len(train_data_prefetcher)
232 |     # Progress bar information
233 | batch_time = AverageMeter("Time", ":6.3f")
234 | data_time = AverageMeter("Data", ":6.3f")
235 | losses = AverageMeter("Loss", ":6.6f")
236 | progress = ProgressMeter(batches, [batch_time, data_time, losses], prefix=f"Epoch: [{epoch}]")
237 |
238 |     # Put the model in training mode
239 | sr_model.train()
240 |
241 | # Define loss function weights
242 | loss_weight = torch.Tensor(config.loss_weight).to(device=device)
243 |
244 | # Initialize data batches
245 | batch_index = 0
246 | # Set the dataset iterator pointer to 0
247 | train_data_prefetcher.reset()
248 | # Record the start time of training a batch
249 | end = time.time()
250 | # load the first batch of data
251 | batch_data = train_data_prefetcher.next()
252 |
253 | while batch_data is not None:
254 | gt = batch_data["gt"].to(config.device, non_blocking=True)
255 | lr = batch_data["lr"].to(config.device, non_blocking=True)
256 |
257 | # Record the data loading time for training a batch
258 | data_time.update(time.time() - end)
259 |
260 |         # Zero the model gradients
261 | sr_model.zero_grad(set_to_none=True)
262 |
263 | # Mixed precision training
264 | with amp.autocast():
265 | sr = sr_model(lr)
266 | loss = criterion(sr, gt)
267 | loss = torch.sum(torch.mul(loss_weight, loss))
268 |
269 |         # Scale the gradients for mixed-precision training
270 |         scaler.scale(loss).backward()
271 |         scaler.unscale_(optimizer)
272 |         # Update the model weights
273 | scaler.step(optimizer)
274 | scaler.update()
275 |
276 |         # Update the exponential moving average (EMA) model weights
277 | ema_sr_model.update_parameters(sr_model)
278 |
279 | # record the loss value
280 | losses.update(loss.item(), lr.size(0))
281 |
282 | # Record the total time of training a batch
283 | batch_time.update(time.time() - end)
284 | end = time.time()
285 |
286 | # Output training log information once
287 | if batch_index % print_frequency == 0:
288 | # Write training log information to tensorboard
289 | writer.add_scalar("Train/Loss", loss.item(), batch_index + epoch * batches + 1)
290 | progress.display(batch_index)
291 |
292 | # Preload the next batch of data
293 | batch_data = train_data_prefetcher.next()
294 |
295 | # Add 1 to the number of data batches
296 | batch_index += 1
297 |
298 |
299 | if __name__ == "__main__":
300 | main()
301 |
--------------------------------------------------------------------------------
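Note on `train.py`: the EMA model created in `build_model()` keeps a shadow copy of the network whose parameters are blended with the live parameters after every `optimizer.step()` through `ema_sr_model.update_parameters(sr_model)`. The snippet below is a minimal, self-contained sketch of that mechanism, not repository code; the decay value and the stand-in convolution layer are illustrative assumptions (the repository takes the decay from `config.model_ema_decay` and the network from `model.py`).

```python
# Sketch only: the EMA update used by build_model() / train() in train.py.
# The decay value and the stand-in layer are illustrative assumptions.
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

model_ema_decay = 0.999  # assumed; the repository reads this from config.py

sr_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in for the RCAN model
ema_avg_fn = lambda averaged_parameter, parameter, num_averaged: \
    (1 - model_ema_decay) * averaged_parameter + model_ema_decay * parameter
ema_sr_model = AveragedModel(sr_model, avg_fn=ema_avg_fn)

optimizer = torch.optim.Adam(sr_model.parameters(), lr=1e-4)

# One training step followed by an EMA refresh, mirroring the loop in train()
x = torch.randn(1, 3, 32, 32)
loss = sr_model(x).mean()
loss.backward()
optimizer.step()
ema_sr_model.update_parameters(sr_model)
```

With the `avg_fn` written this way, the averaged parameters are weighted toward the most recent model parameters by `model_ema_decay`, matching the lambda defined in `build_model()`.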
/imgproc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import math
15 | import random
16 | from typing import Any
17 |
18 | import cv2
19 | import numpy as np
20 | import torch
21 | from numpy import ndarray
22 | from torch import Tensor
23 |
24 | __all__ = [
25 | "image_to_tensor", "tensor_to_image",
26 | "image_resize", "preprocess_one_image",
27 | "expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb",
28 | "rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch",
29 | "center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip",
30 | ]
31 |
32 |
33 | def _cubic(x: Any) -> Any:
34 |     """Python implementation of Matlab's `cubic` function.
35 |
36 | Args:
37 | x: Element vector.
38 |
39 | Returns:
40 | Bicubic interpolation
41 |
42 | """
43 | absx = torch.abs(x)
44 | absx2 = absx ** 2
45 | absx3 = absx ** 3
46 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
47 | -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (
48 | ((absx > 1) * (absx <= 2)).type_as(absx))
49 |
50 |
51 | def _calculate_weights_indices(in_length: int,
52 | out_length: int,
53 | scale: float,
54 | kernel_width: int,
55 | antialiasing: bool) -> [np.ndarray, np.ndarray, int, int]:
56 |     """Python implementation of Matlab's `calculate_weights_indices` function.
57 |
58 | Args:
59 | in_length (int): Input length.
60 | out_length (int): Output length.
61 | scale (float): Scale factor.
62 | kernel_width (int): Kernel width.
63 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
64 | Caution: Bicubic down-sampling in PIL uses antialiasing by default.
65 |
66 | Returns:
67 | weights, indices, sym_len_s, sym_len_e
68 |
69 | """
70 | if (scale < 1) and antialiasing:
71 | # Use a modified kernel (larger kernel width) to simultaneously
72 |         # interpolate and antialias
73 | kernel_width = kernel_width / scale
74 |
75 | # Output-space coordinates
76 | x = torch.linspace(1, out_length, out_length)
77 |
78 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
79 | # in output space maps to 0.5 in input space, and 0.5 + scale in output
80 | # space maps to 1.5 in input space.
81 | u = x / scale + 0.5 * (1 - 1 / scale)
82 |
83 | # What is the left-most pixel that can be involved in the computation?
84 | left = torch.floor(u - kernel_width / 2)
85 |
86 | # What is the maximum number of pixels that can be involved in the
87 | # computation? Note: it's OK to use an extra pixel here; if the
88 | # corresponding weights are all zero, it will be eliminated at the end
89 | # of this function.
90 | p = math.ceil(kernel_width) + 2
91 |
92 | # The indices of the input pixels involved in computing the k-th output
93 | # pixel are in row k of the indices matrix.
94 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
95 | out_length, p)
96 |
97 | # The weights used to compute the k-th output pixel are in row k of the
98 | # weights matrix.
99 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
100 |
101 | # apply cubic kernel
102 | if (scale < 1) and antialiasing:
103 | weights = scale * _cubic(distance_to_center * scale)
104 | else:
105 | weights = _cubic(distance_to_center)
106 |
107 | # Normalize the weights matrix so that each row sums to 1.
108 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
109 | weights = weights / weights_sum.expand(out_length, p)
110 |
111 | # If a column in weights is all zero, get rid of it. only consider the
112 | # first and last column.
113 | weights_zero_tmp = torch.sum((weights == 0), 0)
114 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
115 | indices = indices.narrow(1, 1, p - 2)
116 | weights = weights.narrow(1, 1, p - 2)
117 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
118 | indices = indices.narrow(1, 0, p - 2)
119 | weights = weights.narrow(1, 0, p - 2)
120 | weights = weights.contiguous()
121 | indices = indices.contiguous()
122 | sym_len_s = -indices.min() + 1
123 | sym_len_e = indices.max() - in_length
124 | indices = indices + sym_len_s - 1
125 | return weights, indices, int(sym_len_s), int(sym_len_e)
126 |
127 |
128 | def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor:
129 |     """Convert image data (np.ndarray, HWC) to the Tensor (CHW) data type supported by PyTorch
130 | 
131 |     Args:
132 |         image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0, 255] or [0, 1]
133 |         range_norm (bool): Scale [0, 1] data to between [-1, 1]
134 |         half (bool): Whether to convert the data from torch.float32 to torch.half
135 |
136 | Returns:
137 | tensor (Tensor): Data types supported by PyTorch
138 |
139 | Examples:
140 | >>> example_image = cv2.imread("lr_image.bmp")
141 | >>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False)
142 |
143 | """
144 | # Convert image data type to Tensor data type
145 | tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
146 |
147 | # Scale the image data from [0, 1] to [-1, 1]
148 | if range_norm:
149 | tensor = tensor.mul(2.0).sub(1.0)
150 |
151 | # Convert torch.float32 image data type to torch.half image data type
152 | if half:
153 | tensor = tensor.half()
154 |
155 | return tensor
156 |
157 |
158 | def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any:
159 |     """Convert the Tensor (NCHW) data type supported by PyTorch to the np.ndarray (HWC) image data type
160 | 
161 |     Args:
162 |         tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1]
163 |         range_norm (bool): Scale [-1, 1] data to between [0, 1]
164 |         half (bool): Whether to convert the data from torch.float32 to torch.half
165 | 
166 |     Returns:
167 |         image (np.ndarray): Data types supported by PIL or OpenCV
168 | 
169 |     Examples:
170 |         >>> example_tensor = torch.randn([1, 3, 256, 256], dtype=torch.float32)
171 |         >>> example_image = tensor_to_image(example_tensor, range_norm=False, half=False)
172 |
173 | """
174 | if range_norm:
175 | tensor = tensor.add(1.0).div(2.0)
176 | if half:
177 | tensor = tensor.half()
178 |
179 | image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
180 |
181 | return image
182 |
183 |
184 | def preprocess_one_image(image_path: str, device: torch.device) -> Tensor:
185 | image = cv2.imread(image_path).astype(np.float32) / 255.0
186 |
187 | # BGR to RGB
188 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
189 |
190 | # Convert image data to pytorch format data
191 | tensor = image_to_tensor(image, False, False).unsqueeze_(0)
192 |
193 | # Transfer tensor channel image format data to CUDA device
194 | tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)
195 |
196 | return tensor
197 |
198 |
199 | def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
200 |     """Python implementation of Matlab's `imresize` function.
201 |
202 | Args:
203 | image: The input image.
204 | scale_factor (float): Scale factor. The same scale applies for both height and width.
205 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
206 | Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.
207 |
208 | Returns:
209 | out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round
210 |
211 | """
212 | squeeze_flag = False
213 | if type(image).__module__ == np.__name__: # numpy type
214 | numpy_type = True
215 | if image.ndim == 2:
216 | image = image[:, :, None]
217 | squeeze_flag = True
218 | image = torch.from_numpy(image.transpose(2, 0, 1)).float()
219 | else:
220 | numpy_type = False
221 | if image.ndim == 2:
222 | image = image.unsqueeze(0)
223 | squeeze_flag = True
224 |
225 | in_c, in_h, in_w = image.size()
226 | out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
227 | kernel_width = 4
228 |
229 | # get weights and indices
230 | weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width,
231 | antialiasing)
232 | weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width,
233 | antialiasing)
234 | # process H dimension
235 | # symmetric copying
236 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
237 | img_aug.narrow(1, sym_len_hs, in_h).copy_(image)
238 |
239 | sym_patch = image[:, :sym_len_hs, :]
240 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
241 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
242 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
243 |
244 | sym_patch = image[:, -sym_len_he:, :]
245 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
246 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
247 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
248 |
249 | out_1 = torch.FloatTensor(in_c, out_h, in_w)
250 | kernel_width = weights_h.size(1)
251 | for i in range(out_h):
252 | idx = int(indices_h[i][0])
253 | for j in range(in_c):
254 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
255 |
256 | # process W dimension
257 | # symmetric copying
258 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
259 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
260 |
261 | sym_patch = out_1[:, :, :sym_len_ws]
262 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
263 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
264 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
265 |
266 | sym_patch = out_1[:, :, -sym_len_we:]
267 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
268 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
269 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
270 |
271 | out_2 = torch.FloatTensor(in_c, out_h, out_w)
272 | kernel_width = weights_w.size(1)
273 | for i in range(out_w):
274 | idx = int(indices_w[i][0])
275 | for j in range(in_c):
276 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
277 |
278 | if squeeze_flag:
279 | out_2 = out_2.squeeze(0)
280 | if numpy_type:
281 | out_2 = out_2.numpy()
282 | if not squeeze_flag:
283 | out_2 = out_2.transpose(1, 2, 0)
284 |
285 | return out_2
286 |
287 |
288 | def expand_y(image: np.ndarray) -> np.ndarray:
289 |     """Convert a BGR image to YCbCr format,
290 |     and expand the Y channel data from HW to HWC
291 | 
292 |     Args:
293 |         image (np.ndarray): BGR image data
294 |
295 | Returns:
296 | y_image (np.ndarray): Y-channel image data in HWC form
297 |
298 | """
299 | # Normalize image data to [0, 1]
300 | image = image.astype(np.float32) / 255.
301 |
302 | # Convert BGR to YCbCr, and extract only Y channel
303 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True)
304 |
305 | # Expand Y channel
306 | y_image = y_image[..., None]
307 |
308 |     # Rescale the Y channel data back to [0, 255]
309 | y_image = y_image.astype(np.float64) * 255.0
310 |
311 | return y_image
312 |
313 |
314 | def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
315 |     """Python implementation of Matlab's rgb2ycbcr function.
316 |
317 | Args:
318 | image (np.ndarray): Image input in RGB format.
319 | only_use_y_channel (bool): Extract Y channel separately
320 |
321 | Returns:
322 | image (np.ndarray): YCbCr image array data
323 |
324 | """
325 | if only_use_y_channel:
326 | image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
327 | else:
328 | image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
329 | 16, 128, 128]
330 |
331 | image /= 255.
332 | image = image.astype(np.float32)
333 |
334 | return image
335 |
336 |
337 | def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
338 |     """Python implementation of Matlab's bgr2ycbcr function.
339 |
340 | Args:
341 | image (np.ndarray): Image input in BGR format
342 | only_use_y_channel (bool): Extract Y channel separately
343 |
344 | Returns:
345 | image (np.ndarray): YCbCr image array data
346 |
347 | """
348 | if only_use_y_channel:
349 | image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
350 | else:
351 | image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
352 | 16, 128, 128]
353 |
354 | image /= 255.
355 | image = image.astype(np.float32)
356 |
357 | return image
358 |
359 |
360 | def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray:
361 |     """Python implementation of Matlab's ycbcr2rgb function.
362 |
363 | Args:
364 | image (np.ndarray): Image input in YCbCr format.
365 |
366 | Returns:
367 | image (np.ndarray): RGB image array data
368 |
369 | """
370 | image_dtype = image.dtype
371 | image *= 255.
372 |
373 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
374 | [0, -0.00153632, 0.00791071],
375 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
376 |
377 | image /= 255.
378 | image = image.astype(image_dtype)
379 |
380 | return image
381 |
382 |
383 | def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray:
384 |     """Python implementation of Matlab's ycbcr2bgr function.
385 |
386 | Args:
387 | image (np.ndarray): Image input in YCbCr format.
388 |
389 | Returns:
390 | image (np.ndarray): BGR image array data
391 |
392 | """
393 | image_dtype = image.dtype
394 | image *= 255.
395 |
396 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
397 | [0.00791071, -0.00153632, 0],
398 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]
399 |
400 | image /= 255.
401 | image = image.astype(image_dtype)
402 |
403 | return image
404 |
405 |
406 | def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
407 |     """PyTorch implementation of Matlab's rgb2ycbcr function.
408 | 
409 |     Reference: `https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
410 |
411 | Args:
412 | tensor (Tensor): Image data in PyTorch format
413 | only_use_y_channel (bool): Extract only Y channel
414 |
415 | Returns:
416 | tensor (Tensor): YCbCr image data in PyTorch format
417 |
418 | """
419 | if only_use_y_channel:
420 | weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor)
421 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
422 | else:
423 | weight = Tensor([[65.481, -37.797, 112.0],
424 | [128.553, -74.203, -93.786],
425 | [24.966, 112.0, -18.214]]).to(tensor)
426 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
427 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
428 |
429 | tensor /= 255.
430 |
431 | return tensor
432 |
433 |
434 | def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
435 |     """PyTorch implementation of Matlab's bgr2ycbcr function.
436 | 
437 |     Reference: `https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
438 |
439 | Args:
440 | tensor (Tensor): Image data in PyTorch format
441 | only_use_y_channel (bool): Extract only Y channel
442 |
443 | Returns:
444 | tensor (Tensor): YCbCr image data in PyTorch format
445 |
446 | """
447 | if only_use_y_channel:
448 | weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor)
449 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
450 | else:
451 | weight = Tensor([[24.966, 112.0, -18.214],
452 | [128.553, -74.203, -93.786],
453 | [65.481, -37.797, 112.0]]).to(tensor)
454 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
455 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
456 |
457 | tensor /= 255.
458 |
459 | return tensor
460 |
461 |
462 | def center_crop(image: np.ndarray, image_size: int) -> np.ndarray:
463 | """Crop small image patches from one image center area.
464 |
465 | Args:
466 | image (np.ndarray): The input image for `OpenCV.imread`.
467 | image_size (int): The size of the captured image area.
468 |
469 | Returns:
470 | patch_image (np.ndarray): Small patch image
471 |
472 | """
473 | image_height, image_width = image.shape[:2]
474 |
475 | # Just need to find the top and left coordinates of the image
476 | top = (image_height - image_size) // 2
477 | left = (image_width - image_size) // 2
478 |
479 | # Crop image patch
480 | patch_image = image[top:top + image_size, left:left + image_size, ...]
481 |
482 | return patch_image
483 |
484 |
485 | def random_crop(image: np.ndarray, image_size: int) -> np.ndarray:
486 | """Crop small image patches from one image.
487 |
488 | Args:
489 | image (np.ndarray): The input image for `OpenCV.imread`.
490 | image_size (int): The size of the captured image area.
491 |
492 | Returns:
493 | patch_image (np.ndarray): Small patch image
494 |
495 | """
496 | image_height, image_width = image.shape[:2]
497 |
498 | # Just need to find the top and left coordinates of the image
499 | top = random.randint(0, image_height - image_size)
500 | left = random.randint(0, image_width - image_size)
501 |
502 | # Crop image patch
503 | patch_image = image[top:top + image_size, left:left + image_size, ...]
504 |
505 | return patch_image
506 |
507 |
508 | def random_rotate(image,
509 | angles: list,
510 | center: tuple[int, int] = None,
511 | scale_factor: float = 1.0) -> np.ndarray:
512 | """Rotate an image by a random angle
513 |
514 | Args:
515 | image (np.ndarray): Image read with OpenCV
516 | angles (list): Rotation angle range
517 |         center (optional, tuple[int, int]): Rotation center point. Default: ``None``
518 | scale_factor (optional, float): scaling factor. Default: 1.0
519 |
520 | Returns:
521 | rotated_image (np.ndarray): image after rotation
522 |
523 | """
524 | image_height, image_width = image.shape[:2]
525 |
526 | if center is None:
527 | center = (image_width // 2, image_height // 2)
528 |
529 |     # Randomly select a rotation angle
530 | angle = random.choice(angles)
531 | matrix = cv2.getRotationMatrix2D(center, angle, scale_factor)
532 | rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height))
533 |
534 | return rotated_image
535 |
536 |
537 | def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
538 |     """Randomly flip the image horizontally (left to right)
539 |
540 | Args:
541 | image (np.ndarray): Image read with OpenCV
542 | p (optional, float): Horizontally flip probability. Default: 0.5
543 |
544 | Returns:
545 | horizontally_flip_image (np.ndarray): image after horizontally flip
546 |
547 | """
548 | if random.random() < p:
549 | horizontally_flip_image = cv2.flip(image, 1)
550 | else:
551 | horizontally_flip_image = image
552 |
553 | return horizontally_flip_image
554 |
555 |
556 | def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
557 |     """Randomly flip the image vertically (upside down)
558 |
559 | Args:
560 | image (np.ndarray): Image read with OpenCV
561 | p (optional, float): Vertically flip probability. Default: 0.5
562 |
563 | Returns:
564 | vertically_flip_image (np.ndarray): image after vertically flip
565 |
566 | """
567 | if random.random() < p:
568 | vertically_flip_image = cv2.flip(image, 0)
569 | else:
570 | vertically_flip_image = image
571 |
572 | return vertically_flip_image
573 |
--------------------------------------------------------------------------------
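Note on `imgproc.py`: the helpers above are typically chained: read a BGR image with OpenCV, downscale it with the Matlab-style `image_resize`, and convert between `np.ndarray` (HWC) and `Tensor` (CHW) with `image_to_tensor` / `tensor_to_image`. The snippet below is a minimal round-trip sketch, not repository code; the file name is a placeholder.

```python
# Sketch only: a typical round trip through the imgproc.py helpers.
# "example.png" is a placeholder path, not a file shipped with the repository.
import cv2
import numpy as np

from imgproc import image_resize, image_to_tensor, tensor_to_image

# HWC, BGR, float32 in [0, 1], matching what the helpers expect
image = cv2.imread("example.png").astype(np.float32) / 255.0

# Matlab-style bicubic 4x downscale with antialiasing
lr_image = image_resize(image, 1 / 4, antialiasing=True)

# HWC ndarray -> CHW tensor, then add the batch dimension expected by the model
lr_tensor = image_to_tensor(lr_image, range_norm=False, half=False).unsqueeze(0)

# NCHW tensor in [0, 1] -> HWC uint8 ndarray in [0, 255]
restored_image = tensor_to_image(lr_tensor, range_norm=False, half=False)
```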
/image_quality_assessment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import collections.abc
15 | import math
16 | import typing
17 | import warnings
18 | from itertools import repeat
19 | from typing import Any
20 |
21 | import cv2
22 | import numpy as np
23 | import torch
24 | from numpy import ndarray
25 | from scipy.io import loadmat
26 | from scipy.ndimage import convolve
27 | from scipy.special import gamma
28 | from torch import nn, Tensor
29 | from torch.nn import functional as F
30 |
31 | from imgproc import image_resize, expand_y, bgr_to_ycbcr, rgb_to_ycbcr_torch
32 |
33 | __all__ = [
34 | "psnr", "ssim", "niqe",
35 | "PSNR", "SSIM", "NIQE",
36 | ]
37 |
38 | _I = typing.Optional[int]
39 | _D = typing.Optional[torch.dtype]
40 |
41 |
42 | # The following is the implementation of IQA method in Python, using CPU as processing device
43 | def _check_image(raw_image: np.ndarray, dst_image: np.ndarray):
44 | """Check whether the size and type of the two images are the same
45 |
46 | Args:
47 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255]
48 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255]
49 |
50 | """
51 | # check image scale
52 | assert raw_image.shape == dst_image.shape, \
53 | f"Supplied images have different sizes {str(raw_image.shape)} and {str(dst_image.shape)}"
54 |
55 | # check image type
56 | if raw_image.dtype != dst_image.dtype:
57 |         warnings.warn(f"Supplied images have different dtypes {str(raw_image.dtype)} and {str(dst_image.dtype)}")
58 |
59 |
60 | def psnr(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float:
61 |     """Python implementation of the PSNR (Peak Signal-to-Noise Ratio) metric
62 |
63 | Args:
64 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255]
65 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255]
66 | crop_border (int): crop border a few pixels
67 | only_test_y_channel (bool): Whether to test only the Y channel of the image.
68 |
69 | Returns:
70 | psnr_metrics (np.float64): PSNR metrics
71 |
72 | """
73 | # Check if two images are similar in scale and type
74 | _check_image(raw_image, dst_image)
75 |
76 | # crop border pixels
77 | if crop_border > 0:
78 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...]
79 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...]
80 |
81 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately
82 | if only_test_y_channel:
83 | raw_image = expand_y(raw_image)
84 | dst_image = expand_y(dst_image)
85 |
86 | # Convert data type to numpy.float64 bit
87 | raw_image = raw_image.astype(np.float64)
88 | dst_image = dst_image.astype(np.float64)
89 |
90 |     psnr_metrics = 10 * np.log10((255.0 ** 2) / (np.mean((raw_image - dst_image) ** 2) + 1e-8))
91 |
92 | return psnr_metrics
93 |
94 |
95 | def _ssim(raw_image: np.ndarray, dst_image: np.ndarray) -> float:
96 |     """Python implementation of the SSIM (Structural Similarity) function for single-channel data
97 |
98 | Args:
99 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255]
100 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255]
101 |
102 | Returns:
103 | ssim_metrics (float): SSIM metrics for single channel
104 |
105 | """
106 | c1 = (0.01 * 255.0) ** 2
107 | c2 = (0.03 * 255.0) ** 2
108 |
109 | kernel = cv2.getGaussianKernel(11, 1.5)
110 | kernel_window = np.outer(kernel, kernel.transpose())
111 |
112 | raw_mean = cv2.filter2D(raw_image, -1, kernel_window)[5:-5, 5:-5]
113 | dst_mean = cv2.filter2D(dst_image, -1, kernel_window)[5:-5, 5:-5]
114 | raw_mean_square = raw_mean ** 2
115 | dst_mean_square = dst_mean ** 2
116 | raw_dst_mean = raw_mean * dst_mean
117 | raw_variance = cv2.filter2D(raw_image ** 2, -1, kernel_window)[5:-5, 5:-5] - raw_mean_square
118 | dst_variance = cv2.filter2D(dst_image ** 2, -1, kernel_window)[5:-5, 5:-5] - dst_mean_square
119 | raw_dst_covariance = cv2.filter2D(raw_image * dst_image, -1, kernel_window)[5:-5, 5:-5] - raw_dst_mean
120 |
121 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2)
122 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2)
123 |
124 | ssim_metrics = ssim_molecular / ssim_denominator
125 | ssim_metrics = float(np.mean(ssim_metrics))
126 |
127 | return ssim_metrics
128 |
129 |
130 | def ssim(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float:
131 |     """Python implementation of the SSIM (Structural Similarity) function for single/multi-channel data
132 |
133 | Args:
134 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255]
135 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255]
136 | crop_border (int): crop border a few pixels
137 | only_test_y_channel (bool): Whether to test only the Y channel of the image
138 |
139 | Returns:
140 | ssim_metrics (float): SSIM metrics for single channel
141 |
142 | """
143 | # Check if two images are similar in scale and type
144 | _check_image(raw_image, dst_image)
145 |
146 | # crop border pixels
147 | if crop_border > 0:
148 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...]
149 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...]
150 |
151 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately
152 | if only_test_y_channel:
153 | raw_image = expand_y(raw_image)
154 | dst_image = expand_y(dst_image)
155 |
156 | # Convert data type to numpy.float64 bit
157 | raw_image = raw_image.astype(np.float64)
158 | dst_image = dst_image.astype(np.float64)
159 |
160 | channels_ssim_metrics = []
161 | for channel in range(raw_image.shape[2]):
162 | ssim_metrics = _ssim(raw_image[..., channel], dst_image[..., channel])
163 | channels_ssim_metrics.append(ssim_metrics)
164 | ssim_metrics = np.mean(np.asarray(channels_ssim_metrics))
165 |
166 | return float(ssim_metrics)
167 |
168 |
169 | def _estimate_aggd_parameters(vector: np.ndarray) -> [np.ndarray, float, float]:
170 |     """Helper for the NIQE (Natural Image Quality Evaluator) metric.
171 |     This function estimates the parameters of an asymmetric generalized Gaussian distribution (AGGD)
172 |
173 | Reference papers:
174 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
175 |
176 | Args:
177 | vector (np.ndarray): data vector
178 |
179 | Returns:
180 | aggd_parameters (np.ndarray): asymmetric generalized Gaussian distribution
181 | left_beta (float): symmetric left data vector variance mean product
182 | right_beta (float): symmetric right side data vector variance mean product
183 |
184 | """
185 |     # The following follows the formula from the reference paper and the Wikipedia article on the generalized Gaussian distribution
186 | vector = vector.flatten()
187 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801
188 | gam_reciprocal = np.reciprocal(gam)
189 | r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
190 |
191 | left_std = np.sqrt(np.mean(vector[vector < 0] ** 2))
192 | right_std = np.sqrt(np.mean(vector[vector > 0] ** 2))
193 | gamma_hat = left_std / right_std
194 | rhat = (np.mean(np.abs(vector))) ** 2 / np.mean(vector ** 2)
195 | rhat_norm = (rhat * (gamma_hat ** 3 + 1) * (gamma_hat + 1)) / ((gamma_hat ** 2 + 1) ** 2)
196 | array_position = np.argmin((r_gam - rhat_norm) ** 2)
197 |
198 | aggd_parameters = gam[array_position]
199 | left_beta = left_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters))
200 | right_beta = right_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters))
201 |
202 | return aggd_parameters, left_beta, right_beta
203 |
204 |
205 | def _get_mscn_feature(image: np.ndarray) -> list[float | Any]:
206 |     """Helper for the NIQE (Natural Image Quality Evaluator) metric.
207 |     This function calculates the MSCN feature vector of an image block
208 |
209 | Reference papers:
210 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
211 |
212 | Args:
213 | image (np.ndarray): Grayscale image of MSCN feature to be calculated, BGR format, data range is [0, 255]
214 |
215 | Returns:
216 |         mscn_feature (list): MSCN feature vector of the image
217 |
218 | """
219 | mscn_feature = []
220 | # Calculate the asymmetric generalized Gaussian distribution
221 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters(image)
222 | mscn_feature.extend([aggd_parameters, (left_beta + right_beta) / 2])
223 |
224 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
225 | for i in range(len(shifts)):
226 | shifted_block = np.roll(image, shifts[i], axis=(0, 1))
227 | # Calculate the asymmetric generalized Gaussian distribution
228 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters(image * shifted_block)
229 | mean = (right_beta - left_beta) * (gamma(2 / aggd_parameters) / gamma(1 / aggd_parameters))
230 | mscn_feature.extend([aggd_parameters, mean, left_beta, right_beta])
231 |
232 | return mscn_feature
233 |
234 |
235 | def _fit_mscn_ipac(image: np.ndarray,
236 | mu_pris_param: np.ndarray,
237 | cov_pris_param: np.ndarray,
238 | gaussian_window: np.ndarray,
239 | block_size_height: int,
240 | block_size_width: int) -> float:
241 |     """Helper for the NIQE (Natural Image Quality Evaluator) metric.
242 |     This function fits the inner product of adjacent MSCN coefficients
243 |
244 | Reference papers:
245 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
246 |
247 | Args:
248 | image (np.ndarray): The image data of the NIQE to be tested, in BGR format, the data range is [0, 255]
249 | mu_pris_param (np.ndarray): Mean of predefined multivariate Gaussians, model computed on original dataset.
250 | cov_pris_param (np.ndarray): Covariance of predefined multivariate Gaussian model computed on original dataset.
251 | gaussian_window (np.ndarray): 7x7 Gaussian window for smoothing the image
252 | block_size_height (int): the height of the block into which the image is divided
253 | block_size_width (int): The width of the block into which the image is divided
254 |
255 | Returns:
256 | niqe_metric (np.ndarray): NIQE score
257 |
258 | """
259 | image_height, image_width = image.shape
260 | num_block_height = math.floor(image_height / block_size_height)
261 | num_block_width = math.floor(image_width / block_size_width)
262 | image = image[0:num_block_height * block_size_height, 0:num_block_width * block_size_width]
263 |
264 | features_parameters = []
265 | for scale in (1, 2):
266 | mu = convolve(image, gaussian_window, mode="nearest")
267 | sigma = np.sqrt(np.abs(convolve(np.square(image), gaussian_window, mode="nearest") - np.square(mu)))
268 | image_norm = (image - mu) / (sigma + 1)
269 |
270 | feature = []
271 | for idx_w in range(num_block_width):
272 | for idx_h in range(num_block_height):
273 | vector = image_norm[
274 | idx_h * block_size_height // scale:(idx_h + 1) * block_size_height // scale,
275 | idx_w * block_size_width // scale:(idx_w + 1) * block_size_width // scale]
276 | feature.append(_get_mscn_feature(vector))
277 |
278 | features_parameters.append(np.array(feature))
279 |
280 | if scale == 1:
281 | image = image_resize(image / 255., scale_factor=0.5, antialiasing=True)
282 | image = image * 255.
283 |
284 | features_parameters = np.concatenate(features_parameters, axis=1)
285 |
286 | # Fitting a multivariate Gaussian kernel model to distorted patch features
287 | mu_distparam = np.nanmean(features_parameters, axis=0)
288 | distparam_no_nan = features_parameters[~np.isnan(features_parameters).any(axis=1)]
289 | cov_distparam = np.cov(distparam_no_nan, rowvar=False)
290 |
291 | invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
292 | niqe_metric = np.matmul(np.matmul((mu_pris_param - mu_distparam), invcov_param),
293 | np.transpose((mu_pris_param - mu_distparam)))
294 |
295 | niqe_metric = np.sqrt(niqe_metric)
296 | niqe_metric = float(np.squeeze(niqe_metric))
297 |
298 | return niqe_metric
299 |
300 |
301 | def niqe(image: np.ndarray,
302 | crop_border: int,
303 | niqe_model_path: str,
304 | block_size_height: int = 96,
305 | block_size_width: int = 96) -> float:
306 |     """Python implementation of the NIQE (Natural Image Quality Evaluator) metric.
307 |     This function computes the NIQE score for single/multi-channel data
308 |
309 | Args:
310 | image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255]
311 | crop_border (int): crop border a few pixels
312 |         niqe_model_path (str): Path to the pretrained NIQE model parameters
313 | block_size_height (int): The height of the block the image is divided into. Default: 96
314 | block_size_width (int): The width of the block the image is divided into. Default: 96
315 |
316 | Returns:
317 | niqe_metrics (float): NIQE indicator under single channel
318 |
319 | """
320 | # crop border pixels
321 | if crop_border > 0:
322 | image = image[crop_border:-crop_border, crop_border:-crop_border, ...]
323 |
324 | # Defining the NIQE Feature Extraction Model
325 | niqe_model = np.load(niqe_model_path)
326 |
327 | mu_pris_param = niqe_model["mu_pris_param"]
328 | cov_pris_param = niqe_model["cov_pris_param"]
329 | gaussian_window = niqe_model["gaussian_window"]
330 |
331 | # NIQE only tests on Y channel images and needs to convert the images
332 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True)
333 |
334 | # Convert data type to numpy.float64 bit
335 | y_image = y_image.astype(np.float64)
336 |
337 | niqe_metric = _fit_mscn_ipac(y_image,
338 | mu_pris_param,
339 | cov_pris_param,
340 | gaussian_window,
341 | block_size_height,
342 | block_size_width)
343 |
344 | return niqe_metric
345 |
346 |
347 | # The following is the IQA method implemented by PyTorch, using CUDA as the processing device
348 | def _check_tensor_shape(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor):
349 | """Check if the dimensions of the two tensors are the same
350 |
351 | Args:
352 |         raw_tensor (np.ndarray or torch.Tensor): image tensor to be compared, RGB format, data range [0, 1]
353 |         dst_tensor (np.ndarray or torch.Tensor): reference image tensor, RGB format, data range [0, 1]
354 |
355 | """
356 | # Check if tensor scales are consistent
357 | assert raw_tensor.shape == dst_tensor.shape, \
358 | f"Supplied images have different sizes {str(raw_tensor.shape)} and {str(dst_tensor.shape)}"
359 |
360 |
361 | def _psnr_torch(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor, crop_border: int,
362 | only_test_y_channel: bool) -> float:
363 |     """PyTorch implementation of the PSNR (Peak Signal-to-Noise Ratio) metric
364 | 
365 |     Args:
366 |         raw_tensor (torch.Tensor): image tensor to be compared, RGB format, data range [0, 1]
367 |         dst_tensor (torch.Tensor): reference image tensor, RGB format, data range [0, 1]
368 | crop_border (int): crop border a few pixels
369 | only_test_y_channel (bool): Whether to test only the Y channel of the image
370 |
371 | Returns:
372 | psnr_metrics (torch.Tensor): PSNR metrics
373 |
374 | """
375 | # Check if two tensor scales are similar
376 | _check_tensor_shape(raw_tensor, dst_tensor)
377 |
378 | # crop border pixels
379 | if crop_border > 0:
380 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
381 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
382 |
383 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data
384 | if only_test_y_channel:
385 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True)
386 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True)
387 |
388 | # Convert data type to torch.float64 bit
389 | raw_tensor = raw_tensor.to(torch.float64)
390 | dst_tensor = dst_tensor.to(torch.float64)
391 |
392 | mse_value = torch.mean((raw_tensor * 255.0 - dst_tensor * 255.0) ** 2 + 1e-8, dim=[1, 2, 3])
393 | psnr_metrics = 10 * torch.log10_(255.0 ** 2 / mse_value)
394 |
395 | return psnr_metrics
396 |
397 |
398 | class PSNR(nn.Module):
399 |     """PyTorch implementation of the PSNR (Peak Signal-to-Noise Ratio) metric
400 |
401 | Attributes:
402 | crop_border (int): crop border a few pixels
403 | only_test_y_channel (bool): Whether to test only the Y channel of the image
404 |
405 | Returns:
406 | psnr_metrics (torch.Tensor): PSNR metrics
407 |
408 | """
409 |
410 | def __init__(self, crop_border: int, only_test_y_channel: bool) -> None:
411 | super().__init__()
412 | self.crop_border = crop_border
413 | self.only_test_y_channel = only_test_y_channel
414 |
415 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float:
416 | psnr_metrics = _psnr_torch(raw_tensor, dst_tensor, self.crop_border, self.only_test_y_channel)
417 |
418 | return psnr_metrics
419 |
420 |
421 | def _ssim_torch(raw_tensor: torch.Tensor,
422 | dst_tensor: torch.Tensor,
423 | window_size: int,
424 | gaussian_kernel_window: np.ndarray) -> float:
425 |     """PyTorch implementation of the SSIM (Structural Similarity) function, computed per channel and averaged
426 | 
427 |     Args:
428 |         raw_tensor (torch.Tensor): image tensor to be compared, RGB format, data range [0, 255]
429 |         dst_tensor (torch.Tensor): reference image tensor, RGB format, data range [0, 255]
430 | window_size (int): Gaussian filter size
431 | gaussian_kernel_window (np.ndarray): Gaussian filter
432 |
433 | Returns:
434 | ssim_metrics (torch.Tensor): SSIM metrics
435 |
436 | """
437 | c1 = (0.01 * 255.0) ** 2
438 | c2 = (0.03 * 255.0) ** 2
439 |
440 | gaussian_kernel_window = torch.from_numpy(gaussian_kernel_window).view(1, 1, window_size, window_size)
441 | gaussian_kernel_window = gaussian_kernel_window.expand(raw_tensor.size(1), 1, window_size, window_size)
442 | gaussian_kernel_window = gaussian_kernel_window.to(device=raw_tensor.device, dtype=raw_tensor.dtype)
443 |
444 | raw_mean = F.conv2d(raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=raw_tensor.shape[1])
445 | dst_mean = F.conv2d(dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=dst_tensor.shape[1])
446 | raw_mean_square = raw_mean ** 2
447 | dst_mean_square = dst_mean ** 2
448 | raw_dst_mean = raw_mean * dst_mean
449 | raw_variance = F.conv2d(raw_tensor * raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0),
450 | groups=raw_tensor.shape[1]) - raw_mean_square
451 | dst_variance = F.conv2d(dst_tensor * dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0),
452 | groups=raw_tensor.shape[1]) - dst_mean_square
453 | raw_dst_covariance = F.conv2d(raw_tensor * dst_tensor, gaussian_kernel_window, stride=1, padding=(0, 0),
454 | groups=raw_tensor.shape[1]) - raw_dst_mean
455 |
456 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2)
457 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2)
458 |
459 | ssim_metrics = ssim_molecular / ssim_denominator
460 | ssim_metrics = torch.mean(ssim_metrics, [1, 2, 3]).float()
461 |
462 | return ssim_metrics
463 |
464 |
465 | def _ssim_single_torch(raw_tensor: torch.Tensor,
466 | dst_tensor: torch.Tensor,
467 | crop_border: int,
468 | only_test_y_channel: bool,
469 | window_size: int,
470 | gaussian_kernel_window: ndarray) -> float:
471 |     """PyTorch implementation of the SSIM (Structural Similarity) function for image tensors
472 | 
473 |     Args:
474 |         raw_tensor (Tensor): image tensor to be compared, RGB format, data range [0, 1]
475 |         dst_tensor (Tensor): reference image tensor, RGB format, data range [0, 1]
476 | crop_border (int): crop border a few pixels
477 | only_test_y_channel (bool): Whether to test only the Y channel of the image
478 | window_size (int): Gaussian filter size
479 | gaussian_kernel_window (ndarray): Gaussian filter
480 |
481 | Returns:
482 | ssim_metrics (torch.Tensor): SSIM metrics
483 |
484 | """
485 | # Check if two tensor scales are similar
486 | _check_tensor_shape(raw_tensor, dst_tensor)
487 |
488 | # crop border pixels
489 | if crop_border > 0:
490 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
491 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
492 |
493 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data
494 | if only_test_y_channel:
495 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True)
496 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True)
497 |
498 | # Convert data type to torch.float64 bit
499 | raw_tensor = raw_tensor.to(torch.float64)
500 | dst_tensor = dst_tensor.to(torch.float64)
501 |
502 | ssim_metrics = _ssim_torch(raw_tensor * 255.0, dst_tensor * 255.0, window_size, gaussian_kernel_window)
503 |
504 | return ssim_metrics
505 |
506 |
507 | class SSIM(nn.Module):
508 |     """PyTorch implementation of the SSIM (Structural Similarity) metric as an nn.Module
509 |
510 | Args:
511 | crop_border (int): crop border a few pixels
512 | only_only_test_y_channel (bool): Whether to test only the Y channel of the image
513 | window_size (int): Gaussian filter size
514 | gaussian_sigma (float): sigma parameter in Gaussian filter
515 |
516 | Returns:
517 | ssim_metrics (torch.Tensor): SSIM metrics
518 |
519 | """
520 |
521 | def __init__(self, crop_border: int,
522 | only_only_test_y_channel: bool,
523 | window_size: int = 11,
524 | gaussian_sigma: float = 1.5) -> None:
525 | super().__init__()
526 | self.crop_border = crop_border
527 | self.only_test_y_channel = only_only_test_y_channel
528 | self.window_size = window_size
529 |
530 | gaussian_kernel = cv2.getGaussianKernel(window_size, gaussian_sigma)
531 | self.gaussian_kernel_window = np.outer(gaussian_kernel, gaussian_kernel.transpose())
532 |
533 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float:
534 | ssim_metrics = _ssim_single_torch(raw_tensor,
535 | dst_tensor,
536 | self.crop_border,
537 | self.only_test_y_channel,
538 | self.window_size,
539 | self.gaussian_kernel_window)
540 |
541 | return ssim_metrics
542 |
543 |
544 | def _fspecial_gaussian_torch(window_size: int, sigma: float, channels: int):
545 |     """PyTorch implementation of MATLAB's fspecial_gaussian() function
546 |
547 | Args:
548 | window_size (int): Gaussian filter size
549 | sigma (float): sigma parameter in Gaussian filter
550 | channels (int): number of input image channels
551 |
552 | Returns:
553 | gaussian_kernel_window (torch.Tensor): Gaussian filter in Tensor format
554 |
555 | """
556 | if type(window_size) is int:
557 | shape = (window_size, window_size)
558 | else:
559 | shape = window_size
560 | m, n = [(ss - 1.) / 2. for ss in shape]
561 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
562 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
563 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
564 | sumh = h.sum()
565 |
566 | if sumh != 0:
567 | h /= sumh
568 |
569 | gaussian_kernel_window = torch.from_numpy(h).float().repeat(channels, 1, 1, 1)
570 |
571 | return gaussian_kernel_window
572 |
573 |
574 | def _to_tuple(n):
575 | def parse(x):
576 | if isinstance(x, collections.abc.Iterable):
577 | return x
578 | return tuple(repeat(x, n))
579 |
580 | return parse
581 |
582 |
583 | def _excact_padding_2d(tensor: torch.Tensor,
584 | kernel: torch.Tensor | tuple,
585 | stride: int = 1,
586 | dilation: int = 1,
587 | mode: str = "same") -> torch.Tensor:
588 | assert len(tensor.shape) == 4, f"Only support 4D tensor input, but got {tensor.shape}"
589 | kernel = _to_tuple(2)(kernel)
590 | stride = _to_tuple(2)(stride)
591 | dilation = _to_tuple(2)(dilation)
592 | b, c, h, w = tensor.shape
593 | h2 = math.ceil(h / stride[0])
594 | w2 = math.ceil(w / stride[1])
595 | pad_row = (h2 - 1) * stride[0] + (kernel[0] - 1) * dilation[0] + 1 - h
596 | pad_col = (w2 - 1) * stride[1] + (kernel[1] - 1) * dilation[1] + 1 - w
597 | pad_l, pad_r, pad_t, pad_b = (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2)
598 |
599 | mode = mode if mode != "same" else "constant"
600 | if mode != "symmetric":
601 | tensor = F.pad(tensor, (pad_l, pad_r, pad_t, pad_b), mode=mode)
602 | elif mode == "symmetric":
603 | sym_h = torch.flip(tensor, [2])
604 | sym_w = torch.flip(tensor, [3])
605 | sym_hw = torch.flip(tensor, [2, 3])
606 |
607 | row1 = torch.cat((sym_hw, sym_h, sym_hw), dim=3)
608 | row2 = torch.cat((sym_w, tensor, sym_w), dim=3)
609 | row3 = torch.cat((sym_hw, sym_h, sym_hw), dim=3)
610 |
611 | whole_map = torch.cat((row1, row2, row3), dim=2)
612 |
613 | tensor = whole_map[:, :, h - pad_t:2 * h + pad_b, w - pad_l:2 * w + pad_r, ]
614 |
615 | return tensor
616 |
617 |
618 | class ExactPadding2d(nn.Module):
619 |     r"""This module calculates exact padding values for 4D tensor inputs,
620 |     and supports the same padding modes as TensorFlow.
621 |
622 | Args:
623 | kernel (int or tuple): kernel size.
624 | stride (int or tuple): stride size.
625 | dilation (int or tuple): dilation size, default with 1.
626 |         mode (str): padding mode, one of ('same', 'symmetric', 'replicate', 'circular')
627 |
628 | """
629 |
630 | def __init__(self, kernel, stride=1, dilation=1, mode="same") -> None:
631 | super().__init__()
632 | self.kernel = _to_tuple(2)(kernel)
633 | self.stride = _to_tuple(2)(stride)
634 | self.dilation = _to_tuple(2)(dilation)
635 | self.mode = mode
636 |
637 | def forward(self, tensor: torch.Tensor) -> torch.Tensor:
638 | return _excact_padding_2d(tensor, self.kernel, self.stride, self.dilation, self.mode)
639 |
640 |
641 | def _image_filter(tensor: torch.Tensor,
642 | weight: torch.Tensor,
643 | bias=None,
644 | stride: int = 1,
645 | padding: str = "same",
646 | dilation: int = 1,
647 | groups: int = 1):
648 |     """PyTorch implementation of MATLAB's imfilter() function
649 |
650 | Args:
651 | tensor (torch.Tensor): Tensor image data
652 | weight (torch.Tensor): filter weight
653 | padding (str): how to pad pixels. Default: ``same``
654 | dilation (int): convolution dilation scale
655 | groups (int): number of grouped convolutions
656 |
657 | """
658 | kernel_size = weight.shape[2:]
659 | exact_padding_2d = ExactPadding2d(kernel_size, stride, dilation, mode=padding)
660 |
661 | return F.conv2d(exact_padding_2d(tensor), weight, bias, stride, dilation=dilation, groups=groups)
662 |
663 |
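# Illustrative usage (sketch): a local mean via the normalized Gaussian window,
# analogous to MATLAB's imfilter(I, fspecial('gaussian', 7, 7/6), 'replicate'):
#   kernel = _fspecial_gaussian_torch(7, 7. / 6, 1)
#   x = torch.rand(1, 1, 64, 64)
#   mu = _image_filter(x, kernel, padding="replicate")    # same spatial size as x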
664 | def _reshape_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
665 | if tensor.dim() == 4:
666 | b, c, h, w = tensor.size()
667 | elif tensor.dim() == 3:
668 | c, h, w = tensor.size()
669 | b = None
670 | elif tensor.dim() == 2:
671 | h, w = tensor.size()
672 | b = c = None
673 | else:
674 | raise ValueError('{}-dim Tensor is not supported!'.format(tensor.dim()))
675 |
676 | tensor = tensor.view(-1, 1, h, w)
677 | return tensor, b, c, h, w
678 |
679 |
680 | def _reshape_output_torch(tensor: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
681 | rh = tensor.size(-2)
682 | rw = tensor.size(-1)
683 | # Back to the original dimension
684 | if b is not None:
685 | tensor = tensor.view(b, c, rh, rw) # 4-dim
686 | else:
687 | if c is not None:
688 | tensor = tensor.view(c, rh, rw) # 3-dim
689 | else:
690 | tensor = tensor.view(rh, rw) # 2-dim
691 |
692 | return tensor
693 |
694 |
695 | def _cast_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
696 |     if tensor.dtype != torch.float32 and tensor.dtype != torch.float64:
697 | dtype = tensor.dtype
698 | tensor = tensor.float()
699 | else:
700 | dtype = None
701 |
702 | return tensor, dtype
703 |
704 |
705 | def _cast_output_torch(tensor: torch.Tensor, dtype: _D) -> torch.Tensor:
706 | if dtype is not None:
707 | if not dtype.is_floating_point:
708 | tensor = tensor.round()
709 | # To prevent over/underflow when converting types
710 | if dtype is torch.uint8:
711 | tensor = tensor.clamp(0, 255)
712 |
713 | tensor = tensor.to(dtype=dtype)
714 |
715 | return tensor
716 |
717 |
718 | def _cubic_contribution_torch(tensor: torch.Tensor, a: float = -0.5) -> torch.Tensor:
719 | ax = tensor.abs()
720 | ax2 = ax * ax
721 | ax3 = ax * ax2
722 |
723 | range_01 = ax.le(1)
724 | range_12 = torch.logical_and(ax.gt(1), ax.le(2))
725 |
726 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
727 | cont_01 = cont_01 * range_01.to(dtype=tensor.dtype)
728 |
729 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
730 | cont_12 = cont_12 * range_12.to(dtype=tensor.dtype)
731 |
732 | cont = cont_01 + cont_12
733 | return cont
734 |
735 |
736 | def _gaussian_contribution_torch(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
737 | range_3sigma = (x.abs() <= 3 * sigma + 1)
738 | # Normalization will be done after
739 | cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
740 | cont = cont * range_3sigma.to(dtype=x.dtype)
741 | return cont
742 |
743 |
744 | def _reflect_padding_torch(tensor: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:
745 | """
746 | Apply reflect padding to the given Tensor.
747 | Note that it is slightly different from the PyTorch functional.pad,
748 | where boundary elements are used only once.
749 | Instead, we follow the MATLAB implementation
750 | which uses boundary elements twice.
751 | For example,
752 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
753 | while our implementation yields [a, a, b, c, d, d].
754 | """
755 | b, c, h, w = tensor.size()
756 | if dim == 2 or dim == -2:
757 | padding_buffer = tensor.new_zeros(b, c, h + pad_pre + pad_post, w)
758 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(tensor)
759 | for p in range(pad_pre):
760 | padding_buffer[..., pad_pre - p - 1, :].copy_(tensor[..., p, :])
761 | for p in range(pad_post):
762 | padding_buffer[..., h + pad_pre + p, :].copy_(tensor[..., -(p + 1), :])
763 | else:
764 | padding_buffer = tensor.new_zeros(b, c, h, w + pad_pre + pad_post)
765 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(tensor)
766 | for p in range(pad_pre):
767 | padding_buffer[..., pad_pre - p - 1].copy_(tensor[..., p])
768 | for p in range(pad_post):
769 | padding_buffer[..., w + pad_pre + p].copy_(tensor[..., -(p + 1)])
770 |
771 | return padding_buffer
772 |
773 |
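# Illustrative check (sketch) of the MATLAB-style boundary handling described above:
#   t = torch.tensor([1., 2., 3., 4.]).view(1, 1, 1, 4)
#   _reflect_padding_torch(t, dim=-1, pad_pre=1, pad_post=1)
#   # -> [[[[1., 1., 2., 3., 4., 4.]]]]
#   # (PyTorch's built-in reflection padding would instead yield 2., 1., 2., 3., 4., 3.)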
774 | def _padding_torch(tensor: torch.Tensor,
775 | dim: int,
776 | pad_pre: int,
777 | pad_post: int,
778 | padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
779 | if padding_type is None:
780 | return tensor
781 | elif padding_type == 'reflect':
782 | x_pad = _reflect_padding_torch(tensor, dim, pad_pre, pad_post)
783 | else:
784 | raise ValueError('{} padding is not supported!'.format(padding_type))
785 |
786 | return x_pad
787 |
788 |
789 | def _get_padding_torch(tensor: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
790 | tensor = tensor.long()
791 | r_min = tensor.min()
792 | r_max = tensor.max() + kernel_size - 1
793 |
794 | if r_min <= 0:
795 | pad_pre = -r_min
796 | pad_pre = pad_pre.item()
797 | tensor += pad_pre
798 | else:
799 | pad_pre = 0
800 |
801 | if r_max >= x_size:
802 | pad_post = r_max - x_size + 1
803 | pad_post = pad_post.item()
804 | else:
805 | pad_post = 0
806 |
807 | return pad_pre, pad_post, tensor
808 |
809 |
810 | def _get_weight_torch(tensor: torch.Tensor,
811 | kernel_size: int,
812 | kernel: str = "cubic",
813 | sigma: float = 2.0,
814 | antialiasing_factor: float = 1) -> torch.Tensor:
815 | buffer_pos = tensor.new_zeros(kernel_size, len(tensor))
816 | for idx, buffer_sub in enumerate(buffer_pos):
817 | buffer_sub.copy_(tensor - idx)
818 |
819 | # Expand (downsampling) / Shrink (upsampling) the receptive field.
820 | buffer_pos *= antialiasing_factor
821 | if kernel == 'cubic':
822 | weight = _cubic_contribution_torch(buffer_pos)
823 | elif kernel == 'gaussian':
824 | weight = _gaussian_contribution_torch(buffer_pos, sigma=sigma)
825 | else:
826 | raise ValueError('{} kernel is not supported!'.format(kernel))
827 |
828 | weight /= weight.sum(dim=0, keepdim=True)
829 | return weight
830 |
831 |
832 | def _reshape_tensor_torch(tensor: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
833 | # Resize height
834 | if dim == 2 or dim == -2:
835 | k = (kernel_size, 1)
836 | h_out = tensor.size(-2) - kernel_size + 1
837 | w_out = tensor.size(-1)
838 | # Resize width
839 | else:
840 | k = (1, kernel_size)
841 | h_out = tensor.size(-2)
842 | w_out = tensor.size(-1) - kernel_size + 1
843 |
844 | unfold = F.unfold(tensor, k)
845 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
846 | return unfold
847 |
848 |
849 | def _resize_1d_torch(tensor: torch.Tensor,
850 | dim: int,
851 | size: int,
852 | scale: float,
853 | kernel: str = 'cubic',
854 | sigma: float = 2.0,
855 | padding_type: str = 'reflect',
856 | antialiasing: bool = True) -> torch.Tensor:
857 | """
858 | Args:
859 | tensor (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
860 |         dim (int): the dimension (-2 for height, -1 for width) to resize along
861 |         scale (float): resizing scale factor along ``dim``
862 |         size (int): output size along ``dim``
863 | Return:
864 | """
865 | # Identity case
866 | if scale == 1:
867 | return tensor
868 |
869 | # Default bicubic kernel with antialiasing (only when downsampling)
870 | if kernel == 'cubic':
871 | kernel_size = 4
872 | else:
873 | kernel_size = math.floor(6 * sigma)
874 |
875 | if antialiasing and (scale < 1):
876 | antialiasing_factor = scale
877 | kernel_size = math.ceil(kernel_size / antialiasing_factor)
878 | else:
879 | antialiasing_factor = 1
880 |
881 | # We allow margin to both sizes
882 | kernel_size += 2
883 |
884 | # Weights only depend on the shape of input and output,
885 | # so we do not calculate gradients here.
886 | with torch.no_grad():
887 | pos = torch.linspace(
888 | 0,
889 | size - 1,
890 | steps=size,
891 | dtype=tensor.dtype,
892 | device=tensor.device,
893 | )
894 | pos = (pos + 0.5) / scale - 0.5
895 | base = pos.floor() - (kernel_size // 2) + 1
896 | dist = pos - base
897 | weight = _get_weight_torch(
898 | dist,
899 | kernel_size,
900 | kernel=kernel,
901 | sigma=sigma,
902 | antialiasing_factor=antialiasing_factor,
903 | )
904 | pad_pre, pad_post, base = _get_padding_torch(base, kernel_size, tensor.size(dim))
905 |
906 | # To back-propagate through x
907 | x_pad = _padding_torch(tensor, dim, pad_pre, pad_post, padding_type=padding_type)
908 | unfold = _reshape_tensor_torch(x_pad, dim, kernel_size)
909 | # Subsampling first
910 | if dim == 2 or dim == -2:
911 | sample = unfold[..., base, :]
912 | weight = weight.view(1, kernel_size, sample.size(2), 1)
913 | else:
914 | sample = unfold[..., base]
915 | weight = weight.view(1, kernel_size, 1, sample.size(3))
916 |
917 | # Apply the kernel
918 | tensor = sample * weight
919 | tensor = tensor.sum(dim=1, keepdim=True)
920 | return tensor
921 |
922 |
923 | def _downsampling_2d_torch(tensor: torch.Tensor, k: torch.Tensor, scale: int,
924 | padding_type: str = 'reflect') -> torch.Tensor:
925 | c = tensor.size(1)
926 | k_h = k.size(-2)
927 | k_w = k.size(-1)
928 |
929 | k = k.to(dtype=tensor.dtype, device=tensor.device)
930 | k = k.view(1, 1, k_h, k_w)
931 | k = k.repeat(c, c, 1, 1)
932 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
933 | e = e.view(c, c, 1, 1)
934 | k = k * e
935 |
936 | pad_h = (k_h - scale) // 2
937 | pad_w = (k_w - scale) // 2
938 | tensor = _padding_torch(tensor, -2, pad_h, pad_h, padding_type=padding_type)
939 | tensor = _padding_torch(tensor, -1, pad_w, pad_w, padding_type=padding_type)
940 | y = F.conv2d(tensor, k, padding=0, stride=scale)
941 | return y
942 |
943 |
944 | def _cov_torch(tensor, rowvar=True, bias=False):
945 | r"""Estimate a covariance matrix (np.cov)
946 | Ref: https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110
947 | """
948 | tensor = tensor if rowvar else tensor.transpose(-1, -2)
949 | tensor = tensor - tensor.mean(dim=-1, keepdim=True)
950 | factor = 1 / (tensor.shape[-1] - int(not bool(bias)))
951 | return factor * tensor @ tensor.transpose(-1, -2)
952 |
953 |
954 | def _nancov_torch(x):
955 |     r"""Calculate nancov for a batched tensor; rows that contain NaN values
956 |     will be removed.
957 | Args:
958 | x (tensor): (B, row_num, feat_dim)
959 | Return:
960 | cov (tensor): (B, feat_dim, feat_dim)
961 | """
962 | assert len(x.shape) == 3, f'Shape of input should be (batch_size, row_num, feat_dim), but got {x.shape}'
963 | b, rownum, feat_dim = x.shape
964 | nan_mask = torch.isnan(x).any(dim=2, keepdim=True)
965 | x_no_nan = x.masked_select(~nan_mask).reshape(b, -1, feat_dim)
966 | cov_x = _cov_torch(x_no_nan, rowvar=False)
967 | return cov_x
968 |
969 |
970 | def _nanmean_torch(v, *args, inplace=False, **kwargs):
971 |     r"""Same as MATLAB's nanmean: compute the mean while ignoring all NaN values.
972 | """
973 | if not inplace:
974 | v = v.clone()
975 | is_nan = torch.isnan(v)
976 | v[is_nan] = 0
977 | return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
978 |
979 |
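# Illustrative usage (sketch):
#   v = torch.tensor([[1., float("nan"), 3.]])
#   _nanmean_torch(v, dim=1)                              # -> tensor([2.])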
980 | def _symm_pad_torch(im: torch.Tensor, padding: typing.Tuple[int, int, int, int]):
981 | """Symmetric padding same as tensorflow.
982 | Ref: https://discuss.pytorch.org/t/symmetric-padding/19866/3
983 | """
984 | h, w = im.shape[-2:]
985 | left, right, top, bottom = padding
986 |
987 | x_idx = np.arange(-left, w + right)
988 | y_idx = np.arange(-top, h + bottom)
989 |
990 | def reflect(x, minx, maxx):
991 | """ Reflects an array around two points making a triangular waveform that ramps up
992 | and down, allowing for pad lengths greater than the input length """
993 | rng = maxx - minx
994 | double_rng = 2 * rng
995 | mod = np.fmod(x - minx, double_rng)
996 | normed_mod = np.where(mod < 0, mod + double_rng, mod)
997 | out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
998 | return np.array(out, dtype=x.dtype)
999 |
1000 | x_pad = reflect(x_idx, -0.5, w - 0.5)
1001 | y_pad = reflect(y_idx, -0.5, h - 0.5)
1002 | xx, yy = np.meshgrid(x_pad, y_pad)
1003 | return im[..., yy, xx]
1004 |
1005 |
1006 | def _blockproc_torch(x, kernel: typing.Union[int, tuple, list], fun, border_size=None, pad_partial=False, pad_method='zero'):
1007 |     r"""Blockproc function like MATLAB's blockproc
1008 | 
1009 |     Difference:
1010 |         - Partial blocks are discarded (if any) for fast GPU processing.
1011 | 
1012 |     Args:
1013 |         x (tensor): shape (b, c, h, w)
1014 |         kernel (int or tuple): block size
1015 |         fun: function to process each block
1016 |         border_size (int or tuple): border pixels to each block
1017 |         pad_partial: pad partial blocks to make them full-sized, default False
1018 |         pad_method: [zero, replicate, symmetric] how to pad partial blocks when pad_partial is True
1019 |
1020 | Return:
1021 | results (tensor): concatenated results of each block
1022 |
1023 | """
1024 | assert len(x.shape) == 4, f'Shape of input has to be (b, c, h, w) but got {x.shape}'
1025 | kernel = _to_tuple(2)(kernel)
1026 | if pad_partial:
1027 | b, c, h, w = x.shape
1028 | stride = kernel
1029 | h2 = math.ceil(h / stride[0])
1030 | w2 = math.ceil(w / stride[1])
1031 | pad_row = (h2 - 1) * stride[0] + kernel[0] - h
1032 | pad_col = (w2 - 1) * stride[1] + kernel[1] - w
1033 | padding = (0, pad_col, 0, pad_row)
1034 | if pad_method == 'zero':
1035 | x = F.pad(x, padding, mode='constant')
1036 | elif pad_method == 'symmetric':
1037 | x = _symm_pad_torch(x, padding)
1038 | else:
1039 | x = F.pad(x, padding, mode=pad_method)
1040 |
1041 | if border_size is not None:
1042 | raise NotImplementedError('Blockproc with border is not implemented yet')
1043 | else:
1044 | b, c, h, w = x.shape
1045 | block_size_h, block_size_w = kernel
1046 | num_block_h = math.floor(h / block_size_h)
1047 | num_block_w = math.floor(w / block_size_w)
1048 |
1049 | # extract blocks in (row, column) manner, i.e., stored with column first
1050 | blocks = F.unfold(x, kernel, stride=kernel)
1051 | blocks = blocks.reshape(b, c, *kernel, num_block_h, num_block_w)
1052 | blocks = blocks.permute(5, 4, 0, 1, 2, 3).reshape(num_block_h * num_block_w * b, c, *kernel)
1053 |
1054 | results = fun(blocks)
1055 | results = results.reshape(num_block_h * num_block_w, b, *results.shape[1:]).transpose(0, 1)
1056 | return results
1057 |
1058 |
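# Illustrative usage (sketch): split a (b, c, h, w) tensor into 32x32 blocks and
# reduce each block to its mean; blocks from the whole batch are processed at once:
#   x = torch.rand(2, 1, 96, 64)
#   out = _blockproc_torch(x, 32, fun=lambda blk: blk.mean(dim=(1, 2, 3)))
#   out.shape                                             # -> (2, 6), i.e. 3x2 blocks per image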
1059 | def _image_resize_torch(tensor: torch.Tensor,
1060 | scale_factor: typing.Optional[float] = None,
1061 | sizes: typing.Optional[typing.Tuple[int, int]] = None,
1062 | kernel: typing.Union[str, torch.Tensor] = "cubic",
1063 | sigma: float = 2,
1064 | padding_type: str = "reflect",
1065 | antialiasing: bool = True) -> torch.Tensor:
1066 | """
1067 | Args:
1068 |         tensor (torch.Tensor): input image tensor (2D, 3D, or 4D)
1069 |         scale_factor (float): resizing scale factor (mutually exclusive with ``sizes``)
1070 |         sizes (tuple(int, int)): output (height, width) (mutually exclusive with ``scale_factor``)
1071 |         kernel (str, default='cubic'): 'cubic' or 'gaussian', or a predefined kernel tensor
1072 |         sigma (float, default=2): sigma of the Gaussian kernel (used when kernel='gaussian')
1073 |         padding_type (str, default='reflect'): how to pad the image borders
1074 |         antialiasing (bool, default=True): apply antialiasing when downsampling
1075 |     Return:
1076 |         torch.Tensor: the resized tensor
1077 | """
1078 | scales = (scale_factor, scale_factor)
1079 |
1080 | if scale_factor is None and sizes is None:
1081 | raise ValueError('One of scale or sizes must be specified!')
1082 | if scale_factor is not None and sizes is not None:
1083 | raise ValueError('Please specify scale or sizes to avoid conflict!')
1084 |
1085 | tensor, b, c, h, w = _reshape_input_torch(tensor)
1086 |
1087 | if sizes is None and scale_factor is not None:
1088 | '''
1089 | # Check if we can apply the convolution algorithm
1090 | scale_inv = 1 / scale
1091 | if isinstance(kernel, str) and scale_inv.is_integer():
1092 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
1093 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
1094 | raise ValueError(
1095 | 'An integer downsampling factor '
1096 | 'should be used with a predefined kernel!'
1097 | )
1098 | '''
1099 | # Determine output size
1100 | sizes = (math.ceil(h * scale_factor), math.ceil(w * scale_factor))
1101 | scales = (scale_factor, scale_factor)
1102 |
1103 | if scale_factor is None and sizes is not None:
1104 | scales = (sizes[0] / h, sizes[1] / w)
1105 |
1106 | tensor, dtype = _cast_input_torch(tensor)
1107 |
1108 | if isinstance(kernel, str) and sizes is not None:
1109 | # Core resizing module
1110 | tensor = _resize_1d_torch(
1111 | tensor,
1112 | -2,
1113 | size=sizes[0],
1114 | scale=scales[0],
1115 | kernel=kernel,
1116 | sigma=sigma,
1117 | padding_type=padding_type,
1118 | antialiasing=antialiasing)
1119 | tensor = _resize_1d_torch(
1120 | tensor,
1121 | -1,
1122 | size=sizes[1],
1123 | scale=scales[1],
1124 | kernel=kernel,
1125 | sigma=sigma,
1126 | padding_type=padding_type,
1127 | antialiasing=antialiasing)
1128 | elif isinstance(kernel, torch.Tensor) and scale_factor is not None:
1129 | tensor = _downsampling_2d_torch(tensor, kernel, scale=int(1 / scale_factor))
1130 |
1131 | tensor = _reshape_output_torch(tensor, b, c)
1132 | tensor = _cast_output_torch(tensor, dtype)
1133 | return tensor
1134 |
1135 |
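# Illustrative usage (sketch): MATLAB-imresize-like bicubic downscaling with
# antialiasing, as performed inside _fit_mscn_ipac_torch below:
#   x = torch.rand(1, 3, 64, 64)
#   y = _image_resize_torch(x, scale_factor=0.5)          # -> (1, 3, 32, 32)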
1136 | def _estimate_aggd_parameters_torch(tensor: torch.Tensor,
1137 |                                     get_sigma: bool) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1138 | """PyTorch implements the BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) function
1139 | This function is used to estimate an asymmetric generalized Gaussian distribution
1140 |
1141 | Reference papers:
1142 | `No-Reference Image Quality Assessment in the Spatial Domain`
1143 | `Referenceless Image Spatial Quality Evaluation Engine`
1144 |
1145 | Args:
1146 | tensor (torch.Tensor): data vector
1147 |         get_sigma (bool): whether to return the left/right scale parameters (beta) instead of the standard deviations
1148 | 
1149 |     Returns:
1150 |         aggd_parameters (torch.Tensor): estimated shape parameter (alpha) of the asymmetric generalized Gaussian distribution
1151 |         left_std (torch.Tensor): left-side standard deviation (or scale parameter when ``get_sigma`` is True)
1152 |         right_std (torch.Tensor): right-side standard deviation (or scale parameter when ``get_sigma`` is True)
1153 |
1154 | """
1155 |     # The following is obtained according to the formula and method given in the reference papers and on Wikipedia
1156 | aggd = torch.arange(0.2, 10 + 0.001, 0.001).to(tensor)
1157 | r_gam = (2 * torch.lgamma(2. / aggd) - (torch.lgamma(1. / aggd) + torch.lgamma(3. / aggd))).exp()
1158 | r_gam = r_gam.repeat(tensor.size(0), 1)
1159 |
1160 | mask_left = tensor < 0
1161 | mask_right = tensor > 0
1162 | count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32)
1163 | count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32)
1164 |
1165 | left_std = torch.sqrt_((tensor * mask_left).pow(2).sum(dim=(-1, -2)) / (count_left + 1e-8))
1166 | right_std = torch.sqrt_((tensor * mask_right).pow(2).sum(dim=(-1, -2)) / (count_right + 1e-8))
1167 | gamma_hat = left_std / right_std
1168 | rhat = tensor.abs().mean(dim=(-1, -2)).pow(2) / tensor.pow(2).mean(dim=(-1, -2))
1169 | rhat_norm = (rhat * (gamma_hat.pow(3) + 1) * (gamma_hat + 1)) / (gamma_hat.pow(2) + 1).pow(2)
1170 |
1171 | array_position = (r_gam - rhat_norm).abs().argmin(dim=-1)
1172 | aggd_parameters = aggd[array_position]
1173 |
1174 | if get_sigma:
1175 | left_beta = left_std.squeeze(-1) * (
1176 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt()
1177 | right_beta = right_std.squeeze(-1) * (
1178 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt()
1179 | return aggd_parameters, left_beta, right_beta
1180 |
1181 | else:
1182 | left_std = left_std.squeeze_(-1)
1183 | right_std = right_std.squeeze_(-1)
1184 | return aggd_parameters, left_std, right_std
1185 |
1186 |
1187 | def _get_mscn_feature_torch(tensor: torch.Tensor) -> Tensor:
1188 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function,
1189 | This function is used to calculate the feature map
1190 |
1191 | Reference papers:
1192 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
1193 |
1194 | Args:
1195 | tensor (torch.Tensor): The image to be evaluated for NIQE sharpness
1196 |
1197 | Returns:
1198 | feature (torch.Tensor): image feature map
1199 |
1200 | """
1201 | batch_size = tensor.shape[0]
1202 | aggd_block = tensor[:, [0]]
1203 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block, True)
1204 | feature = [aggd_parameters, (left_beta + right_beta) / 2]
1205 |
1206 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
1207 | for i in range(len(shifts)):
1208 | shifted_block = torch.roll(aggd_block, shifts[i], dims=(2, 3))
1209 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block * shifted_block, True)
1210 | mean = (right_beta - left_beta) * (torch.lgamma(2 / aggd_parameters) - torch.lgamma(1 / aggd_parameters)).exp()
1211 | feature.extend((aggd_parameters, mean, left_beta, right_beta))
1212 |
1213 | feature = [x.reshape(batch_size, 1) for x in feature]
1214 | feature = torch.cat(feature, dim=-1)
1215 |
1216 | return feature
1217 |
1218 |
1219 | def _fit_mscn_ipac_torch(tensor: torch.Tensor,
1220 | mu_pris_param: torch.Tensor,
1221 | cov_pris_param: torch.Tensor,
1222 | block_size_height: int,
1223 | block_size_width: int,
1224 | kernel_size: int = 7,
1225 | kernel_sigma: float = 7. / 6,
1226 | padding: str = "replicate") -> Tensor:
1227 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function,
1228 | This function is used to fit the inner product of adjacent coefficients of MSCN
1229 |
1230 | Reference papers:
1231 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
1232 |
1233 | Args:
1234 | tensor (torch.Tensor): The image to be evaluated for NIQE sharpness
1235 |         mu_pris_param (torch.Tensor): mean of the predefined multivariate Gaussian model, computed on the original dataset
1236 |         cov_pris_param (torch.Tensor): covariance of the predefined multivariate Gaussian model, computed on the original dataset
1237 | block_size_height (int): the height of the block into which the image is divided
1238 | block_size_width (int): The width of the block into which the image is divided
1239 | kernel_size (int): Gaussian filter size
1240 |         kernel_sigma (float): sigma value in the Gaussian filter
1241 | padding (str): how to pad pixels. Default: ``replicate``
1242 |
1243 | Returns:
1244 | niqe_metric (torch.Tensor): NIQE score
1245 |
1246 | """
1247 | # crop image
1248 | b, c, h, w = tensor.shape
1249 | num_block_height = math.floor(h / block_size_height)
1250 | num_block_width = math.floor(w / block_size_width)
1251 | tensor = tensor[..., 0:num_block_height * block_size_height, 0:num_block_width * block_size_width]
1252 |
1253 | distparam = []
1254 | for scale in (1, 2):
1255 | kernel = _fspecial_gaussian_torch(kernel_size, kernel_sigma, 1).to(tensor)
1256 | mu = _image_filter(tensor, kernel, padding=padding)
1257 | std = _image_filter(tensor ** 2, kernel, padding=padding)
1258 | sigma = torch.sqrt_((std - mu ** 2).abs() + 1e-8)
1259 | structdis = (tensor - mu) / (sigma + 1)
1260 |
1261 | distparam.append(_blockproc_torch(structdis,
1262 | [block_size_height // scale, block_size_width // scale],
1263 | fun=_get_mscn_feature_torch))
1264 |
1265 | if scale == 1:
1266 | tensor = _image_resize_torch(tensor / 255., scale_factor=0.5, antialiasing=True)
1267 | tensor = tensor * 255.
1268 |
1269 | distparam = torch.cat(distparam, -1)
1270 |
1271 | # Fit MVG (Multivariate Gaussian) model to distorted patch features
1272 | mu_distparam = _nanmean_torch(distparam, dim=1)
1273 | cov_distparam = _nancov_torch(distparam)
1274 |
1275 | invcov_param = torch.linalg.pinv((cov_pris_param + cov_distparam) / 2)
1276 | diff = (mu_pris_param - mu_distparam).unsqueeze(1)
1277 | niqe_metric = torch.bmm(torch.bmm(diff, invcov_param), diff.transpose(1, 2)).squeeze()
1278 | niqe_metric = torch.sqrt(niqe_metric)
1279 |
1280 | return niqe_metric
1281 |
1282 |
1283 | def _niqe_torch(tensor: torch.Tensor,
1284 | crop_border: int,
1285 | niqe_model_path: str,
1286 | block_size_height: int = 96,
1287 | block_size_width: int = 96
1288 | ) -> Tensor:
1289 |     """PyTorch implements the NIQE (Natural Image Quality Evaluator) function
1290 | 
1291 |     Args:
1292 |         tensor (torch.Tensor): The image to be evaluated for NIQE sharpness
1293 |         crop_border (int): number of border pixels to crop
1294 |         niqe_model_path (str): path to the pretrained NIQE model weights
1295 | block_size_height (int): The height of the block the image is divided into. Default: 96
1296 | block_size_width (int): The width of the block the image is divided into. Default: 96
1297 |
1298 | Returns:
1299 | niqe_metrics (torch.Tensor): NIQE metrics
1300 |
1301 | """
1302 | # crop border pixels
1303 | if crop_border > 0:
1304 | tensor = tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
1305 |
1306 | # Load the NIQE feature extraction model
1307 | niqe_model = loadmat(niqe_model_path)
1308 |
1309 | mu_pris_param = np.ravel(niqe_model["mu_prisparam"])
1310 | cov_pris_param = niqe_model["cov_prisparam"]
1311 | mu_pris_param = torch.from_numpy(mu_pris_param).to(tensor)
1312 | cov_pris_param = torch.from_numpy(cov_pris_param).to(tensor)
1313 |
1314 | mu_pris_param = mu_pris_param.repeat(tensor.size(0), 1)
1315 | cov_pris_param = cov_pris_param.repeat(tensor.size(0), 1, 1)
1316 |
1317 | # NIQE only tests on Y channel images and needs to convert the images
1318 | y_tensor = rgb_to_ycbcr_torch(tensor, only_use_y_channel=True)
1319 | y_tensor *= 255.0
1320 | y_tensor = y_tensor.round()
1321 |
1322 | # Convert data type to torch.float64 bit
1323 | y_tensor = y_tensor.to(torch.float64)
1324 |
1325 | niqe_metric = _fit_mscn_ipac_torch(y_tensor,
1326 | mu_pris_param,
1327 | cov_pris_param,
1328 | block_size_height,
1329 | block_size_width)
1330 |
1331 | return niqe_metric
1332 |
1333 |
1334 | class NIQE(nn.Module):
1335 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function,
1336 |
1337 | Attributes:
1338 |         crop_border (int): number of border pixels to crop
1339 |         niqe_model_path (str): path to the pretrained NIQE model weights
1340 | block_size_height (int): The height of the block the image is divided into. Default: 96
1341 | block_size_width (int): The width of the block the image is divided into. Default: 96
1342 |
1343 | Returns:
1344 | niqe_metrics (torch.Tensor): NIQE metrics
1345 |
1346 | """
1347 |
1348 | def __init__(self, crop_border: int,
1349 | niqe_model_path: str,
1350 | block_size_height: int = 96,
1351 | block_size_width: int = 96) -> None:
1352 | super().__init__()
1353 | self.crop_border = crop_border
1354 | self.niqe_model_path = niqe_model_path
1355 | self.block_size_height = block_size_height
1356 | self.block_size_width = block_size_width
1357 |
1358 | def forward(self, raw_tensor: torch.Tensor) -> Tensor:
1359 | niqe_metrics = _niqe_torch(raw_tensor,
1360 | self.crop_border,
1361 | self.niqe_model_path,
1362 | self.block_size_height,
1363 | self.block_size_width)
1364 |
1365 | return niqe_metrics
1366 |
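# Illustrative usage (sketch; the .mat path below is a placeholder for the actual
# pretrained NIQE parameter file used by this repository):
#   niqe = NIQE(crop_border=4, niqe_model_path="./niqe_model.mat")
#   sr_image = torch.rand(1, 3, 256, 256)                 # RGB image in [0, 1]
#   score = niqe(sr_image)                                 # lower is better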
--------------------------------------------------------------------------------