├── results └── .gitkeep ├── samples └── .gitkeep ├── assets └── result.png ├── requirements.txt ├── scripts ├── run.py └── prepare_dataset.py ├── data └── README.md ├── .gitignore ├── config.py ├── test.py ├── model.py ├── README.md ├── utils.py ├── dataset.py ├── LICENSE ├── train.py ├── imgproc.py └── image_quality_assessment.py /results/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /samples/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lornatang/RCAN-PyTorch/HEAD/assets/result.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python>=4.6.0.66 2 | numpy>=1.23.5 3 | tqdm>=4.63.1 4 | torch>=1.13.0+cu117 5 | natsort>=8.1.0 6 | typing>=3.7.4.3 7 | scipy>=1.9.3 -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Prepare dataset 4 | os.system("python ./prepare_dataset.py --images_dir ../data/DIV2K/original/train --output_dir ../data/DIV2K/RCAN/train --image_size 450 --step 225 --num_workers 10") 5 | os.system("python ./prepare_dataset.py --images_dir ../data/DIV2K/original/valid --output_dir ../data/DIV2K/RCAN/valid --image_size 450 --step 225 --num_workers 10") 6 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Download datasets 4 | 5 | ### Download train dataset 6 | 7 | #### DIV2K 8 | 9 | - Image format 10 | - [Baidu Driver](https://pan.baidu.com/s/1EXXbhxxRDtqPosT2WL8NkA) access: `llot` 11 | 12 | ### Download valid dataset 13 | 14 | #### Set5 15 | 16 | - Image format 17 | - [Google Driver](https://drive.google.com/file/d/1GtQuoEN78q3AIP8vkh-17X90thYp_FfU/view?usp=sharing) 18 | - [Baidu Driver](https://pan.baidu.com/s/1dlPcpwRPUBOnxlfW5--S5g) access:`llot` 19 | 20 | #### Set14 21 | 22 | - Image format 23 | - [Google Driver](https://drive.google.com/file/d/1CzwwAtLSW9sog3acXj8s7Hg3S7kr2HiZ/view?usp=sharing) 24 | - [Baidu Driver](https://pan.baidu.com/s/1KBS38UAjM7bJ_e6a54eHaA) access:`llot` 25 | 26 | #### BSD100 27 | 28 | - Image format 29 | - [Google Driver](https://drive.google.com/file/d/1xkjWJGZgwWjDZZFN6KWlNMvHXmRORvdG/view?usp=sharing) 30 | - [Baidu Driver](https://pan.baidu.com/s/1EBVulUpsQrDmZfqnm4jOZw) access:`llot` 31 | 32 | #### BSD200 33 | 34 | - Image format 35 | - [Google Driver](https://drive.google.com/file/d/1cdMYTPr77RdOgyAvJPMQqaJHWrD5ma5n/view?usp=sharing) 36 | - [Baidu Driver](https://pan.baidu.com/s/1xahPw4dNNc3XspMMOuw1Bw) access:`llot` 37 | 38 | ## Train dataset struct information 39 | 40 | ### Image format 41 | 42 | ```text 43 | - DIV2K 44 | - RCAN 45 | - train 46 | - valid 47 | ``` 48 | 49 | ## Test dataset struct information 50 | 51 | ### Image format 52 | 53 | ```text 54 | - Set5 55 | - GTmod12 56 | - baby.png 57 | - bird.png 58 | - ... 59 | - LRbicx4 60 | - baby.png 61 | - bird.png 62 | - ... 
63 | - Set14 64 | - GTmod12 65 | - baboon.png 66 | - barbara.png 67 | - ... 68 | - LRbicx4 69 | - baboon.png 70 | - barbara.png 71 | - ... 72 | ``` 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # custom 126 | .idea 127 | .vscode 128 | 129 | # Mac configure file. 130 | .DS_Store 131 | 132 | # Program run create directory. 133 | data 134 | results 135 | samples 136 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import random 15 | 16 | import numpy as np 17 | import torch 18 | from torch.backends import cudnn 19 | 20 | # Random seed to maintain reproducible results 21 | random.seed(0) 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | # Use GPU for training by default 25 | device = torch.device("cuda", 0) 26 | # Turning on when the image size does not change during training can speed up training 27 | cudnn.benchmark = True 28 | # When evaluating the performance of the SR model, whether to verify only the Y channel image data 29 | only_test_y_channel = True 30 | # Model architecture name 31 | model_arch_name = "rcan_x4" 32 | # Upscale factor 33 | upscale_factor = 4 34 | # Current configuration parameter method 35 | mode = "train" 36 | # Experiment name, easy to save weights and log files 37 | exp_name = "RCAN_x4-DIV2K" 38 | 39 | if mode == "train": 40 | train_gt_images_dir = f"./data/DIV2K/RCAN/train" 41 | 42 | test_gt_images_dir = f"./data/Set5/GTmod12" 43 | test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}" 44 | 45 | train_gt_image_size = int(upscale_factor * 48) 46 | batch_size = 16 47 | num_workers = 4 48 | 49 | # Load the address of the pretrained model 50 | pretrained_model_weights_path = f"" 51 | 52 | # Incremental training and migration training 53 | resume_model_weights_path = f"" 54 | 55 | # Total num epochs 56 | epochs = 1000 57 | 58 | # Loss function weight 59 | loss_weight = [1.0] 60 | 61 | # Optimizer parameter 62 | model_lr = 1e-4 63 | model_betas = (0.9, 0.99) 64 | model_eps = 1e-4 # Keep no nan 65 | model_weight_decay = 0.0 66 | 67 | # EMA parameter 68 | model_ema_decay = 0.999 69 | 70 | # StepLR scheduler parameter 71 | lr_scheduler_step_size = epochs // 5 72 | lr_scheduler_gamma = 0.5 73 | 74 | # How many iterations to print the training result 75 | train_print_frequency = 100 76 | test_print_frequency = 1 77 | 78 | if mode == "test": 79 | test_gt_images_dir = f"./data/Set5/GTmod12" 80 | test_sr_images_dir = f"./results/test/{exp_name}" 81 | test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}" 82 | 83 | model_weights_path = f"./results/pretrained_models/RCAN_x4-DIV2K-2dfffdd2.pth.tar" 84 | -------------------------------------------------------------------------------- /scripts/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import argparse 15 | import multiprocessing 16 | import os 17 | import shutil 18 | 19 | import cv2 20 | import numpy as np 21 | from tqdm import tqdm 22 | 23 | 24 | def main(args) -> None: 25 | if os.path.exists(args.output_dir): 26 | shutil.rmtree(args.output_dir) 27 | os.makedirs(args.output_dir) 28 | 29 | # Get all image paths 30 | image_file_names = os.listdir(args.images_dir) 31 | 32 | # Splitting images with multiple threads 33 | progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Prepare split image") 34 | workers_pool = multiprocessing.Pool(args.num_workers) 35 | for image_file_name in image_file_names: 36 | workers_pool.apply_async(worker, args=(image_file_name, args), callback=lambda arg: progress_bar.update(1)) 37 | workers_pool.close() 38 | workers_pool.join() 39 | progress_bar.close() 40 | 41 | 42 | def worker(image_file_name, args) -> None: 43 | image = cv2.imread(f"{args.images_dir}/{image_file_name}", cv2.IMREAD_UNCHANGED) 44 | 45 | image_height, image_width = image.shape[0:2] 46 | 47 | index = 1 48 | if image_height >= args.image_size and image_width >= args.image_size: 49 | for pos_y in range(0, image_height - args.image_size + 1, args.step): 50 | for pos_x in range(0, image_width - args.image_size + 1, args.step): 51 | # Crop 52 | crop_image = image[pos_y: pos_y + args.image_size, pos_x:pos_x + args.image_size, ...] 53 | crop_image = np.ascontiguousarray(crop_image) 54 | # Save image 55 | cv2.imwrite(f"{args.output_dir}/{image_file_name.split('.')[-2]}_{index:04d}.{image_file_name.split('.')[-1]}", crop_image) 56 | 57 | index += 1 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser(description="Prepare database scripts.") 62 | parser.add_argument("--images_dir", type=str, help="Path to input image directory.") 63 | parser.add_argument("--output_dir", type=str, help="Path to generator image directory.") 64 | parser.add_argument("--image_size", type=int, help="Low-resolution image size from raw image.") 65 | parser.add_argument("--step", type=int, help="Crop image similar to sliding window.") 66 | parser.add_argument("--num_workers", type=int, help="How many threads to open at the same time.") 67 | args = parser.parse_args() 68 | 69 | main(args) 70 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import time 15 | 16 | import torch 17 | from torch import nn 18 | from torch.utils.data import DataLoader 19 | 20 | import config 21 | import model 22 | from dataset import TestImageDataset, CUDAPrefetcher 23 | from utils import build_iqa_model, load_state_dict, make_directory, AverageMeter, ProgressMeter 24 | 25 | 26 | def load_dataset(test_gt_images_dir: str, test_lr_images_dir: str, device: torch.device) -> CUDAPrefetcher: 27 | test_datasets = TestImageDataset(test_gt_images_dir, test_lr_images_dir) 28 | test_dataloader = DataLoader(test_datasets, 29 | batch_size=1, 30 | shuffle=False, 31 | num_workers=1, 32 | pin_memory=True, 33 | drop_last=False, 34 | persistent_workers=False) 35 | test_data_prefetcher = CUDAPrefetcher(test_dataloader, device) 36 | 37 | return test_data_prefetcher 38 | 39 | 40 | def build_model(model_arch_name: str, device: torch.device) -> nn.Module: 41 | # Build model 42 | sr_model = model.__dict__[model_arch_name]() 43 | sr_model = sr_model.to(device=device) 44 | # Set the model to evaluation mode 45 | sr_model.eval() 46 | 47 | return sr_model 48 | 49 | 50 | def test( 51 | sr_model: nn.Module, 52 | test_data_prefetcher: CUDAPrefetcher, 53 | psnr_model: nn.Module, 54 | ssim_model: nn.Module, 55 | device: torch.device = torch.device("cpu"), 56 | print_frequency: int = 1, 57 | ) -> [float, float]: 58 | # The information printed by the progress bar 59 | batch_time = AverageMeter("Time", ":6.3f") 60 | psnres = AverageMeter("PSNR", ":4.2f") 61 | ssimes = AverageMeter("SSIM", ":4.4f") 62 | progress = ProgressMeter(len(test_data_prefetcher), [batch_time, psnres, ssimes], prefix=f"Test: ") 63 | 64 | # Set the model as validation model 65 | sr_model.eval() 66 | 67 | # Initialize data batches 68 | batch_index = 0 69 | 70 | # Set the data set iterator pointer to 0 and load the first batch of data 71 | test_data_prefetcher.reset() 72 | batch_data = test_data_prefetcher.next() 73 | 74 | # Record the start time of verifying a batch 75 | end = time.time() 76 | 77 | while batch_data is not None: 78 | # Load batches of data 79 | gt = batch_data["gt"].to(device=device, non_blocking=True) 80 | lr = batch_data["lr"].to(device=device, non_blocking=True) 81 | 82 | # inference 83 | with torch.no_grad(): 84 | sr = sr_model(lr) 85 | 86 | # Calculate the image IQA 87 | psnr = psnr_model(sr, gt) 88 | ssim = ssim_model(sr, gt) 89 | psnres.update(psnr.item(), lr.size(0)) 90 | ssimes.update(ssim.item(), lr.size(0)) 91 | 92 | # Record the total time to verify a batch 93 | batch_time.update(time.time() - end) 94 | end = time.time() 95 | 96 | # Output a verification log information 97 | if batch_index % print_frequency == 0: 98 | progress.display(batch_index + 1) 99 | 100 | # Preload the next batch of data 101 | batch_data = test_data_prefetcher.next() 102 | 103 | # Add 1 to the number of data batches 104 | batch_index += 1 105 | 106 | # Print the performance index of the model at the current epoch 107 | progress.display_summary() 108 | 109 | return psnres.avg, ssimes.avg 110 | 111 | 112 | def main() -> None: 113 | test_data_prefetcher = load_dataset(config.test_gt_images_dir, config.test_lr_images_dir, config.device) 114 | sr_model = build_model(config.model_arch_name, config.device) 115 | psnr_model, ssim_model = build_iqa_model(config.upscale_factor, config.only_test_y_channel, config.device) 116 | 117 | # Load the super-resolution bsrgan_model weights 118 | sr_model = load_state_dict(sr_model, 
config.model_weights_path) 119 | 120 | # Create a folder of super-resolution experiment results 121 | make_directory(config.test_sr_images_dir) 122 | 123 | psnr, ssim = test(sr_model, 124 | test_data_prefetcher, 125 | psnr_model, 126 | ssim_model, 127 | config.device) 128 | 129 | print(f"PSNR: {psnr:.2f} dB" 130 | f"SSIM: {ssim:.4f} [u]") 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import math 15 | 16 | import torch 17 | from torch import nn, Tensor 18 | 19 | __all__ = [ 20 | "RCAN", 21 | "rcan_x2", "rcan_x3", "rcan_x4", "rcan_x8", 22 | ] 23 | 24 | 25 | class RCAN(nn.Module): 26 | def __init__( 27 | self, 28 | in_channels: int, 29 | out_channels: int, 30 | channels: int, 31 | reduction: int, 32 | num_rcab: int, 33 | num_rg: int, 34 | upscale_factor: int, 35 | rgb_mean: tuple = None, 36 | ) -> None: 37 | super(RCAN, self).__init__() 38 | if rgb_mean is None: 39 | rgb_mean = [0.4488, 0.4371, 0.4040] 40 | 41 | # The first layer of convolutional layer 42 | self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1)) 43 | 44 | # Feature extraction backbone 45 | trunk = [] 46 | for _ in range(num_rg): 47 | trunk.append(_ResidualGroup(channels, reduction, num_rcab)) 48 | self.trunk = nn.Sequential(*trunk) 49 | 50 | # After the feature extraction network, reconnect a layer of convolutional blocks 51 | self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)) 52 | 53 | # Upsampling convolutional layer. 54 | upsampling = [] 55 | if upscale_factor == 2 or upscale_factor == 4 or upscale_factor == 8: 56 | for _ in range(int(math.log(upscale_factor, 2))): 57 | upsampling.append(_UpsampleBlock(channels, 2)) 58 | elif upscale_factor == 3: 59 | upsampling.append(_UpsampleBlock(channels, 3)) 60 | self.upsampling = nn.Sequential(*upsampling) 61 | 62 | # Output layer. 63 | self.conv3 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1)) 64 | 65 | self.register_buffer("mean", Tensor(rgb_mean).view(1, 3, 1, 1)) 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | x = x.sub_(self.mean).mul_(1.) 
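# Note: sub_() and mul_() are in-place, so this mean-shift also modifies the input
# tensor passed in by the caller; mul_(1.) leaves the values unchanged. The matching
# x.div_(1.).add_(self.mean) at the end of forward() adds the mean back to the
# network output rather than to the original input.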
69 | 70 | conv1 = self.conv1(x) 71 | x = self.trunk(conv1) 72 | x = self.conv2(x) 73 | x = torch.add(x, conv1) 74 | x = self.upsampling(x) 75 | x = self.conv3(x) 76 | 77 | x = x.div_(1.).add_(self.mean) 78 | 79 | return x 80 | 81 | 82 | class _ChannelAttentionLayer(nn.Module): 83 | def __init__(self, channel: int, reduction: int): 84 | super(_ChannelAttentionLayer, self).__init__() 85 | self.channel_attention_layer = nn.Sequential( 86 | nn.AdaptiveAvgPool2d(1), 87 | nn.Conv2d(channel, channel // reduction, (1, 1), (1, 1), (0, 0)), 88 | nn.ReLU(True), 89 | nn.Conv2d(channel // reduction, channel, (1, 1), (1, 1), (0, 0)), 90 | nn.Sigmoid(), 91 | ) 92 | 93 | def forward(self, x: Tensor) -> Tensor: 94 | out = self.channel_attention_layer(x) 95 | 96 | out = torch.mul(out, x) 97 | 98 | return out 99 | 100 | 101 | class _ResidualChannelAttentionBlock(nn.Module): 102 | def __init__(self, channel: int, reduction: int): 103 | super(_ResidualChannelAttentionBlock, self).__init__() 104 | self.residual_channel_attention_block = nn.Sequential( 105 | nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1)), 106 | nn.ReLU(True), 107 | nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1)), 108 | _ChannelAttentionLayer(channel, reduction), 109 | ) 110 | 111 | def forward(self, x: Tensor) -> Tensor: 112 | identity = x 113 | 114 | out = self.residual_channel_attention_block(x) 115 | 116 | out = torch.add(out, identity) 117 | 118 | return out 119 | 120 | 121 | class _ResidualGroup(nn.Module): 122 | def __init__(self, channel: int, reduction: int, num_rcab: int): 123 | super(_ResidualGroup, self).__init__() 124 | residual_group = [] 125 | 126 | for _ in range(num_rcab): 127 | residual_group.append(_ResidualChannelAttentionBlock(channel, reduction)) 128 | residual_group.append(nn.Conv2d(channel, channel, (3, 3), (1, 1), (1, 1))) 129 | 130 | self.residual_group = nn.Sequential(*residual_group) 131 | 132 | def forward(self, x: Tensor) -> Tensor: 133 | identity = x 134 | 135 | out = self.residual_group(x) 136 | 137 | out = torch.add(out, identity) 138 | 139 | return out 140 | 141 | 142 | class _UpsampleBlock(nn.Module): 143 | def __init__(self, channels: int, upscale_factor: int) -> None: 144 | super(_UpsampleBlock, self).__init__() 145 | self.upsample_block = nn.Sequential( 146 | nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)), 147 | nn.PixelShuffle(upscale_factor), 148 | ) 149 | 150 | def forward(self, x: Tensor) -> Tensor: 151 | x = self.upsample_block(x) 152 | 153 | return x 154 | 155 | 156 | def rcan_x2(**kwargs) -> RCAN: 157 | model = RCAN(3, 3, 64, 16, 20, 10, 2, **kwargs) 158 | 159 | return model 160 | 161 | 162 | def rcan_x3(**kwargs) -> RCAN: 163 | model = RCAN(3, 3, 64, 16, 20, 10, 3, **kwargs) 164 | 165 | return model 166 | 167 | 168 | def rcan_x4(**kwargs) -> RCAN: 169 | model = RCAN(3, 3, 64, 16, 20, 10, 4, **kwargs) 170 | 171 | return model 172 | 173 | 174 | def rcan_x8(**kwargs) -> RCAN: 175 | model = RCAN(3, 3, 64, 16, 20, 10, 8, **kwargs) 176 | 177 | return model 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RCAN-PyTorch 2 | 3 | ### Overview 4 | 5 | This repository contains an op-for-op PyTorch reimplementation of [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758). 
6 | 7 | ### Table of contents 8 | 9 | - [RCAN-PyTorch](#rcan-pytorch) 10 | - [Overview](#overview) 11 | - [Table of contents](#table-of-contents) 12 | - [About Image Super-Resolution Using Very Deep Residual Channel Attention Networks](#about-image-super-resolution-using-very-deep-residual-channel-attention-networks) 13 | - [Download weights](#download-weights) 14 | - [Download datasets](#download-datasets) 15 | - [Test](#test) 16 | - [Train](#train) 17 | - [Result](#result) 18 | - [Credit](#credit) 19 | - [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](#image-super-resolution-using-very-deep-residual-channel-attention-networks) 20 | 21 | ## About Image Super-Resolution Using Very Deep Residual Channel Attention Networks 22 | 23 | If you're new to RCAN, here's an abstract straight from the paper: 24 | 25 | Convolutional neural network (CNN) depth is of crucial importance for image super-resolution (SR). However, we observe that deeper networks for image 26 | SR are more difficult to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across 27 | channels, hence hindering the representational ability of CNNs. To solve these problems, we propose the very deep residual channel attention 28 | networks (RCAN). Specifically, we propose a residual in residual (RIR) structure to form very deep network, which consists of several residual groups 29 | with long skip connections. Each residual group contains some residual blocks with short skip connections. Meanwhile, RIR allows abundant 30 | low-frequency information to be bypassed through multiple skip connections, making the main network focus on learning high-frequency information. 31 | Furthermore, we propose a channel attention mechanism to adaptively rescale channel-wise features by considering interdependencies among channels. 32 | Extensive experiments show that our RCAN achieves better accuracy and visual improvements against state-of-the-art methods. 33 | 34 | ## Download weights 35 | 36 | - [Google Driver](https://drive.google.com/drive/folders/17ju2HN7Y6pyPK2CC_AqnAfTOe9_3hCQ8?usp=sharing) 37 | - [Baidu Driver](https://pan.baidu.com/s/1yNs4rqIb004-NKEdKBJtYg?pwd=llot) 38 | 39 | ## Download datasets 40 | 41 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc. 42 | 43 | - [Google Driver](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing) 44 | - [Baidu Driver](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot) 45 | 46 | ## Test 47 | 48 | Modify the contents of the `config.py` file as follows (a short configuration sketch is given below, after the resume instructions), then run `python test.py`. 49 | 50 | - line 33: `upscale_factor` change to the magnification you need (2, 3, 4 or 8), keeping it consistent with `model_arch_name` on line 31. 51 | - line 35: `mode` change to "test" mode. 52 | - line 83: `model_weights_path` change to the weight address after training. 53 | 54 | ## Train 55 | 56 | Modify the contents of the `config.py` file as follows, then run `python train.py`. 57 | 58 | - line 33: `upscale_factor` change to the magnification you need (2, 3, 4 or 8), keeping it consistent with `model_arch_name` on line 31. 59 | - line 35: `mode` change to "train" mode. 60 | 61 | If you want to resume from weights that you've trained before, modify the contents of `config.py` as follows. 62 | 63 | ### Resume model 64 | 65 | - line 53: `resume_model_weights_path` change to the RCAN checkpoint address that needs to be loaded. 66 | - The start epoch, best PSNR/SSIM, optimizer state and scheduler state are restored from that checkpoint automatically.
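For reference, here is a minimal sketch of the `config.py` fields involved in a x4 test run. It is not a separate file in this repository; the variable names, dataset paths and checkpoint name are taken from the `config.py` shipped with the project, so adapt the paths to your own setup.

```python
# Sketch of the config.py settings typically edited before running `python test.py`.
upscale_factor = 4            # magnification of the model to evaluate
mode = "test"                 # switch from "train" to "test"
exp_name = "RCAN_x4-DIV2K"    # used to name the ./results/test/{exp_name} results folder

# Fields read inside the `if mode == "test":` block
test_gt_images_dir = f"./data/Set5/GTmod12"
test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}"
model_weights_path = f"./results/pretrained_models/RCAN_x4-DIV2K-2dfffdd2.pth.tar"
```

For training, the analogous fields are `train_gt_images_dir`, `train_gt_image_size` and `batch_size`, plus the optional `pretrained_model_weights_path` (warm start) and `resume_model_weights_path` (resume from a checkpoint).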
67 | 68 | ## Result 69 | 70 | Source of original paper results: https://arxiv.org/pdf/1807.02758.pdf 71 | 72 | In the following table, the value in `()` indicates the result of the project, and `-` indicates no test. 73 | 74 | | Dataset | Scale | PSNR (dB) | 75 | |:-------:|:-----:|:----------------:| 76 | | Set5 | 2 | 38.27(**38.09**) | 77 | | Set5 | 3 | 34.74(**34.56**) | 78 | | Set5 | 4 | 32.63(**32.41**) | 79 | | Set5 | 8 | 27.31(**26.97**) | 80 | 81 | Low Resolution / Super Resolution / High Resolution 82 | 83 | 84 | ### Credit 85 | 86 | #### Image Super-Resolution Using Very Deep Residual Channel Attention Networks 87 | 88 | _Yulun Zhang, Kunpeng Li, Kai Li, Lichen Wang, Bineng Zhong, Yun Fu_<br>
89 | 90 | **Abstract**
91 | Convolutional neural network (CNN) depth is of crucial importance for image super-resolution (SR). However, we observe that deeper networks for image 92 | SR are more difficult to train. The low-resolution inputs and features contain abundant low-frequency information, which is treated equally across 93 | channels, hence hindering the representational ability of CNNs. To solve these problems, we propose the very deep residual channel attention 94 | networks (RCAN). Specifically, we propose a residual in residual (RIR) structure to form very deep network, which consists of several residual groups 95 | with long skip connections. Each residual group contains some residual blocks with short skip connections. Meanwhile, RIR allows abundant 96 | low-frequency information to be bypassed through multiple skip connections, making the main network focus on learning high-frequency information. 97 | Furthermore, we propose a channel attention mechanism to adaptively rescale channel-wise features by considering interdependencies among channels. 98 | Extensive experiments show that our RCAN achieves better accuracy and visual improvements against state-of-the-art methods. 99 | 100 | [[Code]](https://github.com/yulunzhang/RCAN) [[Paper]](https://arxiv.org/pdf/1807.02758) 101 | 102 | ``` 103 | @article{DBLP:journals/corr/abs-1807-02758, 104 | author = {Yulun Zhang and 105 | Kunpeng Li and 106 | Kai Li and 107 | Lichen Wang and 108 | Bineng Zhong and 109 | Yun Fu}, 110 | title = {Image Super-Resolution Using Very Deep Residual Channel Attention 111 | Networks}, 112 | journal = {CoRR}, 113 | volume = {abs/1807.02758}, 114 | year = {2018}, 115 | url = {http://arxiv.org/abs/1807.02758}, 116 | eprinttype = {arXiv}, 117 | eprint = {1807.02758}, 118 | timestamp = {Tue, 20 Nov 2018 12:24:39 +0100}, 119 | biburl = {https://dblp.org/rec/journals/corr/abs-1807-02758.bib}, 120 | bibsource = {dblp computer science bibliography, https://dblp.org} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import os 15 | import shutil 16 | from enum import Enum 17 | from typing import Any 18 | 19 | import torch 20 | from torch import nn 21 | from torch.nn import Module 22 | from torch.optim import Optimizer 23 | 24 | from image_quality_assessment import PSNR, SSIM 25 | 26 | __all__ = [ 27 | "build_iqa_model", "load_state_dict", "make_directory", "save_checkpoint", 28 | "Summary", "AverageMeter", "ProgressMeter" 29 | ] 30 | 31 | 32 | def build_iqa_model(upscale_factor: int, only_test_y_channel: bool, device: torch.device) -> tuple[PSNR, SSIM]: 33 | psnr_model = PSNR(upscale_factor, only_test_y_channel) 34 | ssim_model = SSIM(upscale_factor, only_test_y_channel) 35 | psnr_model = psnr_model.to(device=device) 36 | ssim_model = ssim_model.to(device=device) 37 | 38 | return psnr_model, ssim_model 39 | 40 | 41 | def load_state_dict( 42 | model: nn.Module, 43 | model_weights_path: str, 44 | ema_model: nn.Module = None, 45 | optimizer: torch.optim.Optimizer = None, 46 | scheduler: torch.optim.lr_scheduler = None, 47 | load_mode: str = None, 48 | ) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module: 49 | # Load model weights 50 | checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage) 51 | 52 | if load_mode == "resume": 53 | # Restore the parameters in the training node to this point 54 | start_epoch = checkpoint["epoch"] 55 | best_psnr = checkpoint["best_psnr"] 56 | best_ssim = checkpoint["best_ssim"] 57 | # Load model state dict. Extract the fitted model weights 58 | model_state_dict = model.state_dict() 59 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()} 60 | # Overwrite the model weights to the current model (base model) 61 | model_state_dict.update(state_dict) 62 | model.load_state_dict(model_state_dict) 63 | # Load the optimizer model 64 | optimizer.load_state_dict(checkpoint["optimizer"]) 65 | 66 | if scheduler is not None: 67 | # Load the scheduler model 68 | scheduler.load_state_dict(checkpoint["scheduler"]) 69 | 70 | if ema_model is not None: 71 | # Load ema model state dict. Extract the fitted model weights 72 | ema_model_state_dict = ema_model.state_dict() 73 | ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()} 74 | # Overwrite the model weights to the current model (ema model) 75 | ema_model_state_dict.update(ema_state_dict) 76 | ema_model.load_state_dict(ema_model_state_dict) 77 | 78 | return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler 79 | else: 80 | # Load model state dict. 
Extract the fitted model weights 81 | model_state_dict = model.state_dict() 82 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if 83 | k in model_state_dict.keys() and v.size() == model_state_dict[k].size()} 84 | # Overwrite the model weights to the current model 85 | model_state_dict.update(state_dict) 86 | model.load_state_dict(model_state_dict) 87 | 88 | return model 89 | 90 | 91 | def make_directory(dir_path: str) -> None: 92 | if not os.path.exists(dir_path): 93 | os.makedirs(dir_path) 94 | 95 | 96 | def save_checkpoint( 97 | state_dict: dict, 98 | file_name: str, 99 | samples_dir: str, 100 | results_dir: str, 101 | best_file_name: str, 102 | last_file_name: str, 103 | is_best: bool = False, 104 | is_last: bool = False, 105 | ) -> None: 106 | checkpoint_path = os.path.join(samples_dir, file_name) 107 | torch.save(state_dict, checkpoint_path) 108 | 109 | if is_best: 110 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name)) 111 | if is_last: 112 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name)) 113 | 114 | 115 | class Summary(Enum): 116 | NONE = 0 117 | AVERAGE = 1 118 | SUM = 2 119 | COUNT = 3 120 | 121 | 122 | class AverageMeter(object): 123 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 124 | self.name = name 125 | self.fmt = fmt 126 | self.summary_type = summary_type 127 | self.reset() 128 | 129 | def reset(self): 130 | self.val = 0 131 | self.avg = 0 132 | self.sum = 0 133 | self.count = 0 134 | 135 | def update(self, val, n=1): 136 | self.val = val 137 | self.sum += val * n 138 | self.count += n 139 | self.avg = self.sum / self.count 140 | 141 | def __str__(self): 142 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 143 | return fmtstr.format(**self.__dict__) 144 | 145 | def summary(self): 146 | if self.summary_type is Summary.NONE: 147 | fmtstr = "" 148 | elif self.summary_type is Summary.AVERAGE: 149 | fmtstr = "{name} {avg:.2f}" 150 | elif self.summary_type is Summary.SUM: 151 | fmtstr = "{name} {sum:.2f}" 152 | elif self.summary_type is Summary.COUNT: 153 | fmtstr = "{name} {count:.2f}" 154 | else: 155 | raise ValueError(f"Invalid summary type {self.summary_type}") 156 | 157 | return fmtstr.format(**self.__dict__) 158 | 159 | 160 | class ProgressMeter(object): 161 | def __init__(self, num_batches, meters, prefix=""): 162 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 163 | self.meters = meters 164 | self.prefix = prefix 165 | 166 | def display(self, batch): 167 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 168 | entries += [str(meter) for meter in self.meters] 169 | print("\t".join(entries)) 170 | 171 | def display_summary(self): 172 | entries = [" *"] 173 | entries += [meter.summary() for meter in self.meters] 174 | print(" ".join(entries)) 175 | 176 | def _get_batch_fmtstr(self, num_batches): 177 | num_digits = len(str(num_batches // 1)) 178 | fmt = "{:" + str(num_digits) + "d}" 179 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 180 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import os 15 | import queue 16 | import threading 17 | 18 | import cv2 19 | import numpy as np 20 | import torch 21 | from torch import Tensor 22 | from torch.utils.data import Dataset, DataLoader 23 | 24 | import imgproc 25 | 26 | __all__ = [ 27 | "TrainImageDataset", "TestImageDataset", 28 | "PrefetchGenerator", "PrefetchDataLoader", "CPUPrefetcher", "CUDAPrefetcher", 29 | ] 30 | 31 | 32 | class TrainImageDataset(Dataset): 33 | """Define train dataset loading methods. 34 | 35 | Args: 36 | train_gt_image_dir (str): Train ground-truth dataset address. 37 | train_gt_image_size (int): Train ground-truth resolution image size. 38 | upscale_factor (int): Image up scale factor. 39 | 40 | """ 41 | 42 | def __init__( 43 | self, 44 | train_gt_image_dir: str, 45 | train_gt_image_size: int, 46 | upscale_factor: int, 47 | ) -> None: 48 | super(TrainImageDataset, self).__init__() 49 | self.image_file_names = [os.path.join(train_gt_image_dir, image_file_name) for image_file_name in 50 | os.listdir(train_gt_image_dir)] 51 | self.train_gt_image_size = train_gt_image_size 52 | self.upscale_factor = upscale_factor 53 | 54 | def __getitem__(self, batch_index: int) -> [dict[str, Tensor], dict[str, Tensor]]: 55 | # Read a batch of image data 56 | gt_image = cv2.imread(self.image_file_names[batch_index]).astype(np.float32) / 255. 57 | 58 | # Image processing operations 59 | gt_crop_image = imgproc.random_crop(gt_image, self.train_gt_image_size) 60 | gt_crop_image = imgproc.random_rotate(gt_crop_image, [90, 180, 270]) 61 | gt_crop_image = imgproc.random_horizontally_flip(gt_crop_image, 0.5) 62 | gt_crop_image = imgproc.random_vertically_flip(gt_crop_image, 0.5) 63 | 64 | lr_crop_image = imgproc.image_resize(gt_crop_image, 1 / self.upscale_factor) 65 | 66 | # BGR convert RGB 67 | gt_crop_image = cv2.cvtColor(gt_crop_image, cv2.COLOR_BGR2RGB) 68 | lr_crop_image = cv2.cvtColor(lr_crop_image, cv2.COLOR_BGR2RGB) 69 | 70 | # Convert image data into Tensor stream format (PyTorch). 71 | # Note: The range of input and output is between [0, 1] 72 | gt_crop_tensor = imgproc.image_to_tensor(gt_crop_image, False, False) 73 | lr_crop_tensor = imgproc.image_to_tensor(lr_crop_image, False, False) 74 | 75 | return {"gt": gt_crop_tensor, "lr": lr_crop_tensor} 76 | 77 | def __len__(self) -> int: 78 | return len(self.image_file_names) 79 | 80 | 81 | class TestImageDataset(Dataset): 82 | """Define Test dataset loading methods. 
83 | 84 | Args: 85 | test_gt_images_dir (str): ground truth image in test image 86 | test_lr_images_dir (str): low-resolution image in test image 87 | """ 88 | 89 | def __init__(self, test_gt_images_dir: str, test_lr_images_dir: str) -> None: 90 | super(TestImageDataset, self).__init__() 91 | # Get all image file names in folder 92 | self.gt_image_file_names = [os.path.join(test_gt_images_dir, x) for x in os.listdir(test_gt_images_dir)] 93 | self.lr_image_file_names = [os.path.join(test_lr_images_dir, x) for x in os.listdir(test_lr_images_dir)] 94 | 95 | def __getitem__(self, batch_index: int) -> [torch.Tensor, torch.Tensor]: 96 | # Read a batch of image data 97 | gt_image = cv2.imread(self.gt_image_file_names[batch_index]).astype(np.float32) / 255. 98 | lr_image = cv2.imread(self.lr_image_file_names[batch_index]).astype(np.float32) / 255. 99 | 100 | # BGR convert RGB 101 | gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB) 102 | lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB) 103 | 104 | # Convert image data into Tensor stream format (PyTorch). 105 | # Note: The range of input and output is between [0, 1] 106 | gt_tensor = imgproc.image_to_tensor(gt_image, False, False) 107 | lr_tensor = imgproc.image_to_tensor(lr_image, False, False) 108 | 109 | return {"gt": gt_tensor, "lr": lr_tensor} 110 | 111 | def __len__(self) -> int: 112 | return len(self.gt_image_file_names) 113 | 114 | 115 | class PrefetchGenerator(threading.Thread): 116 | """A fast data prefetch generator. 117 | 118 | Args: 119 | generator: Data generator. 120 | num_data_prefetch_queue (int): How many early data load queues. 121 | """ 122 | 123 | def __init__(self, generator, num_data_prefetch_queue: int) -> None: 124 | threading.Thread.__init__(self) 125 | self.queue = queue.Queue(num_data_prefetch_queue) 126 | self.generator = generator 127 | self.daemon = True 128 | self.start() 129 | 130 | def run(self) -> None: 131 | for item in self.generator: 132 | self.queue.put(item) 133 | self.queue.put(None) 134 | 135 | def __next__(self): 136 | next_item = self.queue.get() 137 | if next_item is None: 138 | raise StopIteration 139 | return next_item 140 | 141 | def __iter__(self): 142 | return self 143 | 144 | 145 | class PrefetchDataLoader(DataLoader): 146 | """A fast data prefetch dataloader. 147 | 148 | Args: 149 | num_data_prefetch_queue (int): How many early data load queues. 150 | kwargs (dict): Other extended parameters. 151 | """ 152 | 153 | def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None: 154 | self.num_data_prefetch_queue = num_data_prefetch_queue 155 | super(PrefetchDataLoader, self).__init__(**kwargs) 156 | 157 | def __iter__(self): 158 | return PrefetchGenerator(super().__iter__(), self.num_data_prefetch_queue) 159 | 160 | 161 | class CPUPrefetcher: 162 | """Use the CPU side to accelerate data reading. 163 | 164 | Args: 165 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 166 | """ 167 | 168 | def __init__(self, dataloader: DataLoader) -> None: 169 | self.original_dataloader = dataloader 170 | self.data = iter(dataloader) 171 | 172 | def next(self): 173 | try: 174 | return next(self.data) 175 | except StopIteration: 176 | return None 177 | 178 | def reset(self): 179 | self.data = iter(self.original_dataloader) 180 | 181 | def __len__(self) -> int: 182 | return len(self.original_dataloader) 183 | 184 | 185 | class CUDAPrefetcher: 186 | """Use the CUDA side to accelerate data reading. 
187 | 188 | Args: 189 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 190 | device (torch.device): Specify running device. 191 | """ 192 | 193 | def __init__(self, dataloader: DataLoader, device: torch.device): 194 | self.batch_data = None 195 | self.original_dataloader = dataloader 196 | self.device = device 197 | 198 | self.data = iter(dataloader) 199 | self.stream = torch.cuda.Stream() 200 | self.preload() 201 | 202 | def preload(self): 203 | try: 204 | self.batch_data = next(self.data) 205 | except StopIteration: 206 | self.batch_data = None 207 | return None 208 | 209 | with torch.cuda.stream(self.stream): 210 | for k, v in self.batch_data.items(): 211 | if torch.is_tensor(v): 212 | self.batch_data[k] = self.batch_data[k].to(self.device, non_blocking=True) 213 | 214 | def next(self): 215 | torch.cuda.current_stream().wait_stream(self.stream) 216 | batch_data = self.batch_data 217 | self.preload() 218 | return batch_data 219 | 220 | def reset(self): 221 | self.data = iter(self.original_dataloader) 222 | self.preload() 223 | 224 | def __len__(self) -> int: 225 | return len(self.original_dataloader) 226 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================ 14 | import os 15 | import time 16 | 17 | import torch 18 | from torch import nn, optim 19 | from torch.cuda import amp 20 | from torch.optim import lr_scheduler 21 | from torch.optim.swa_utils import AveragedModel 22 | from torch.utils.data import DataLoader 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import config 26 | import model 27 | from dataset import CUDAPrefetcher, TrainImageDataset, TestImageDataset 28 | from test import test 29 | from utils import build_iqa_model, load_state_dict, make_directory, save_checkpoint, AverageMeter, ProgressMeter 30 | 31 | 32 | def main(): 33 | # Initialize the gradient scaler 34 | scaler = amp.GradScaler() 35 | 36 | # Initialize the number of training epochs 37 | start_epoch = 0 38 | 39 | # Initialize training to generate network evaluation indicators 40 | best_psnr = 0.0 41 | best_ssim = 0.0 42 | 43 | train_data_prefetcher, test_data_prefetcher = load_dataset(config.train_gt_images_dir, 44 | config.train_gt_image_size, 45 | config.test_gt_images_dir, 46 | config.test_lr_images_dir, 47 | config.upscale_factor, 48 | config.batch_size, 49 | config.num_workers, 50 | config.device) 51 | print("Load all datasets successfully.") 52 | 53 | sr_model, ema_sr_model = build_model(config.model_arch_name, config.device) 54 | print(f"Build `{config.model_arch_name}` model successfully.") 55 | 56 | criterion = define_loss(config.device) 57 | print("Define all loss functions successfully.") 58 | 59 | optimizer = define_optimizer(sr_model) 60 | print("Define all optimizer functions successfully.") 61 | 62 | scheduler = define_scheduler(optimizer) 63 | print("Define all optimizer scheduler functions successfully.") 64 | 65 | # Create an IQA evaluation model 66 | psnr_model, ssim_model = build_iqa_model(config.upscale_factor, config.only_test_y_channel, config.device) 67 | 68 | print("Check whether to load pretrained model weights...") 69 | if config.pretrained_model_weights_path: 70 | sr_model = load_state_dict(sr_model, config.pretrained_model_weights_path) 71 | print(f"Loaded `{config.pretrained_model_weights_path}` pretrained model weights successfully.") 72 | else: 73 | print("Pretrained model weights not found.") 74 | 75 | print("Check whether the resume model is restored...") 76 | if config.resume_model_weights_path: 77 | sr_model, ema_sr_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler = load_state_dict( 78 | sr_model, 79 | config.resume_model_weights_path, 80 | ema_sr_model, 81 | optimizer, 82 | scheduler, 83 | "resume") 84 | print("Loaded resume model weights.") 85 | else: 86 | print("Resume training model not found. 
Start training from scratch.") 87 | 88 | # Create a experiment results 89 | samples_dir = os.path.join("samples", config.exp_name) 90 | results_dir = os.path.join("results", config.exp_name) 91 | make_directory(samples_dir) 92 | make_directory(results_dir) 93 | 94 | # Create training process log file 95 | writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name)) 96 | 97 | for epoch in range(start_epoch, config.epochs): 98 | train(sr_model, 99 | ema_sr_model, 100 | train_data_prefetcher, 101 | criterion, 102 | optimizer, 103 | epoch, 104 | scaler, 105 | writer, 106 | config.device, 107 | config.train_print_frequency) 108 | psnr, ssim = test(sr_model, 109 | test_data_prefetcher, 110 | psnr_model, 111 | ssim_model, 112 | config.device, 113 | config.test_print_frequency) 114 | 115 | # Write the evaluation results to the tensorboard 116 | writer.add_scalar(f"Test/PSNR", psnr, epoch + 1) 117 | writer.add_scalar(f"Test/SSIM", ssim, epoch + 1) 118 | 119 | print("\n") 120 | 121 | # Update LR 122 | scheduler.step() 123 | 124 | # Automatically save the model with the highest index 125 | is_best = psnr > best_psnr and ssim > best_ssim 126 | is_last = (epoch + 1) == config.epochs 127 | best_psnr = max(psnr, best_psnr) 128 | best_ssim = max(ssim, best_ssim) 129 | save_checkpoint({"epoch": epoch + 1, 130 | "best_psnr": best_psnr, 131 | "best_ssim": best_ssim, 132 | "state_dict": sr_model.state_dict(), 133 | "ema_state_dict": ema_sr_model.state_dict(), 134 | "optimizer": optimizer.state_dict(), 135 | "scheduler": scheduler.state_dict()}, 136 | f"epoch_{epoch + 1}.pth.tar", 137 | samples_dir, 138 | results_dir, 139 | "best.pth.tar", 140 | "last.pth.tar", 141 | is_best, 142 | is_last) 143 | 144 | 145 | def load_dataset( 146 | train_gt_images_dir: str, 147 | train_gt_image_size: int, 148 | test_gt_images_dir: str, 149 | test_lr_images_dir: str, 150 | upscale_factor: int, 151 | batch_size: int, 152 | num_workers: int, 153 | device: torch.device, 154 | ) -> [CUDAPrefetcher, CUDAPrefetcher]: 155 | # Load train, test and valid datasets 156 | train_datasets = TrainImageDataset(train_gt_images_dir, train_gt_image_size, upscale_factor) 157 | test_datasets = TestImageDataset(test_gt_images_dir, test_lr_images_dir) 158 | 159 | # Generator all dataloader 160 | train_dataloader = DataLoader(train_datasets, 161 | batch_size=batch_size, 162 | shuffle=True, 163 | num_workers=num_workers, 164 | pin_memory=True, 165 | drop_last=True, 166 | persistent_workers=True) 167 | test_dataloader = DataLoader(test_datasets, 168 | batch_size=1, 169 | shuffle=False, 170 | num_workers=1, 171 | pin_memory=True, 172 | drop_last=False, 173 | persistent_workers=False) 174 | 175 | # Place all data on the preprocessing data loader 176 | train_prefetcher = CUDAPrefetcher(train_dataloader, device) 177 | test_prefetcher = CUDAPrefetcher(test_dataloader, device) 178 | 179 | return train_prefetcher, test_prefetcher 180 | 181 | 182 | def build_model(model_arch_name: str, device: torch.device) -> [nn.Module, nn.Module]: 183 | # Build model 184 | sr_model = model.__dict__[model_arch_name]() 185 | # Generate exponential average model, stabilize model training 186 | ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: (1 - config.model_ema_decay) * averaged_model_parameter + config.model_ema_decay * model_parameter 187 | ema_sr_model = AveragedModel(sr_model, avg_fn=ema_avg_fn) 188 | 189 | sr_model = sr_model.to(device=device) 190 | ema_sr_model = ema_sr_model.to(device=device) 191 | 192 | return sr_model, 
ema_sr_model 193 | 194 | 195 | def define_loss(device) -> nn.L1Loss: 196 | criterion = nn.L1Loss().to(device=device) 197 | 198 | return criterion 199 | 200 | 201 | def define_optimizer(sr_model: nn.Module) -> optim.Adam: 202 | optimizer = optim.Adam(sr_model.parameters(), 203 | config.model_lr, 204 | config.model_betas, 205 | config.model_eps) 206 | 207 | return optimizer 208 | 209 | 210 | def define_scheduler(optimizer) -> lr_scheduler.StepLR: 211 | scheduler = lr_scheduler.StepLR(optimizer, 212 | config.lr_scheduler_step_size, 213 | config.lr_scheduler_gamma) 214 | 215 | return scheduler 216 | 217 | 218 | def train( 219 | sr_model: nn.Module, 220 | ema_sr_model: nn.Module, 221 | train_data_prefetcher: CUDAPrefetcher, 222 | criterion: nn.L1Loss, 223 | optimizer: optim.Adam, 224 | epoch: int, 225 | scaler: amp.GradScaler, 226 | writer: SummaryWriter, 227 | device: torch.device = torch.device("cpu"), 228 | print_frequency: int = 1, 229 | ) -> None: 230 | # Calculate how many iterations there are under epoch 231 | batches = len(train_data_prefetcher) 232 | # Progress bar print information 233 | batch_time = AverageMeter("Time", ":6.3f") 234 | data_time = AverageMeter("Data", ":6.3f") 235 | losses = AverageMeter("Loss", ":6.6f") 236 | progress = ProgressMeter(batches, [batch_time, data_time, losses], prefix=f"Epoch: [{epoch}]") 237 | 238 | # Put the generator in training mode 239 | sr_model.train() 240 | 241 | # Define loss function weights 242 | loss_weight = torch.Tensor(config.loss_weight).to(device=device) 243 | 244 | # Initialize data batches 245 | batch_index = 0 246 | # Set the dataset iterator pointer to 0 247 | train_data_prefetcher.reset() 248 | # Record the start time of training a batch 249 | end = time.time() 250 | # load the first batch of data 251 | batch_data = train_data_prefetcher.next() 252 | 253 | while batch_data is not None: 254 | gt = batch_data["gt"].to(config.device, non_blocking=True) 255 | lr = batch_data["lr"].to(config.device, non_blocking=True) 256 | 257 | # Record the data loading time for training a batch 258 | data_time.update(time.time() - end) 259 | 260 | # Initialize the generator gradient 261 | sr_model.zero_grad(set_to_none=True) 262 | 263 | # Mixed precision training 264 | with amp.autocast(): 265 | sr = sr_model(lr) 266 | loss = criterion(sr, gt) 267 | loss = torch.sum(torch.mul(loss_weight, loss)) 268 | 269 | # Gradient zoom 270 | scaler.scale(loss).backward() 271 | scaler.unscale_(optimizer) 272 | # Update generator weight 273 | scaler.step(optimizer) 274 | scaler.update() 275 | 276 | # update exponentially averaged model weights 277 | ema_sr_model.update_parameters(sr_model) 278 | 279 | # record the loss value 280 | losses.update(loss.item(), lr.size(0)) 281 | 282 | # Record the total time of training a batch 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | # Output training log information once 287 | if batch_index % print_frequency == 0: 288 | # Write training log information to tensorboard 289 | writer.add_scalar("Train/Loss", loss.item(), batch_index + epoch * batches + 1) 290 | progress.display(batch_index) 291 | 292 | # Preload the next batch of data 293 | batch_data = train_data_prefetcher.next() 294 | 295 | # Add 1 to the number of data batches 296 | batch_index += 1 297 | 298 | 299 | if __name__ == "__main__": 300 | main() 301 | -------------------------------------------------------------------------------- /imgproc.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import math 15 | import random 16 | from typing import Any 17 | 18 | import cv2 19 | import numpy as np 20 | import torch 21 | from numpy import ndarray 22 | from torch import Tensor 23 | 24 | __all__ = [ 25 | "image_to_tensor", "tensor_to_image", 26 | "image_resize", "preprocess_one_image", 27 | "expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb", 28 | "rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch", 29 | "center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip", 30 | ] 31 | 32 | 33 | def _cubic(x: Any) -> Any: 34 | """Implementation of `cubic` function in Matlab under Python language. 35 | 36 | Args: 37 | x: Element vector. 38 | 39 | Returns: 40 | Bicubic interpolation 41 | 42 | """ 43 | absx = torch.abs(x) 44 | absx2 = absx ** 2 45 | absx3 = absx ** 3 46 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( 47 | -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ( 48 | ((absx > 1) * (absx <= 2)).type_as(absx)) 49 | 50 | 51 | def _calculate_weights_indices(in_length: int, 52 | out_length: int, 53 | scale: float, 54 | kernel_width: int, 55 | antialiasing: bool) -> [np.ndarray, np.ndarray, int, int]: 56 | """Implementation of `calculate_weights_indices` function in Matlab under Python language. 57 | 58 | Args: 59 | in_length (int): Input length. 60 | out_length (int): Output length. 61 | scale (float): Scale factor. 62 | kernel_width (int): Kernel width. 63 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 64 | Caution: Bicubic down-sampling in PIL uses antialiasing by default. 65 | 66 | Returns: 67 | weights, indices, sym_len_s, sym_len_e 68 | 69 | """ 70 | if (scale < 1) and antialiasing: 71 | # Use a modified kernel (larger kernel width) to simultaneously 72 | # interpolate and antialiasing 73 | kernel_width = kernel_width / scale 74 | 75 | # Output-space coordinates 76 | x = torch.linspace(1, out_length, out_length) 77 | 78 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 79 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 80 | # space maps to 1.5 in input space. 81 | u = x / scale + 0.5 * (1 - 1 / scale) 82 | 83 | # What is the left-most pixel that can be involved in the computation? 84 | left = torch.floor(u - kernel_width / 2) 85 | 86 | # What is the maximum number of pixels that can be involved in the 87 | # computation? Note: it's OK to use an extra pixel here; if the 88 | # corresponding weights are all zero, it will be eliminated at the end 89 | # of this function. 90 | p = math.ceil(kernel_width) + 2 91 | 92 | # The indices of the input pixels involved in computing the k-th output 93 | # pixel are in row k of the indices matrix. 
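# A quick worked example of the mapping above (illustrative numbers only): for a
# 2x downscale (scale = 0.5) the centre of output pixel x lands at u = 2 * x - 0.5,
# so output pixel 1 sits at input coordinate 1.5, halfway between input pixels 1
# and 2 in MATLAB-style 1-based coordinates, and with antialiasing the effective
# kernel width grows from 4 to 4 / 0.5 = 8 input pixels.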
94 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 95 | out_length, p) 96 | 97 | # The weights used to compute the k-th output pixel are in row k of the 98 | # weights matrix. 99 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 100 | 101 | # apply cubic kernel 102 | if (scale < 1) and antialiasing: 103 | weights = scale * _cubic(distance_to_center * scale) 104 | else: 105 | weights = _cubic(distance_to_center) 106 | 107 | # Normalize the weights matrix so that each row sums to 1. 108 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 109 | weights = weights / weights_sum.expand(out_length, p) 110 | 111 | # If a column in weights is all zero, get rid of it. only consider the 112 | # first and last column. 113 | weights_zero_tmp = torch.sum((weights == 0), 0) 114 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 115 | indices = indices.narrow(1, 1, p - 2) 116 | weights = weights.narrow(1, 1, p - 2) 117 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 118 | indices = indices.narrow(1, 0, p - 2) 119 | weights = weights.narrow(1, 0, p - 2) 120 | weights = weights.contiguous() 121 | indices = indices.contiguous() 122 | sym_len_s = -indices.min() + 1 123 | sym_len_e = indices.max() - in_length 124 | indices = indices + sym_len_s - 1 125 | return weights, indices, int(sym_len_s), int(sym_len_e) 126 | 127 | 128 | def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor: 129 | """Convert the image data type to the Tensor (NCWH) data type supported by PyTorch 130 | 131 | Args: 132 | image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0,255] or [0, 1] 133 | range_norm (bool): Scale [0, 1] data to between [-1, 1] 134 | half (bool): Whether to convert torch.float32 similarly to torch.half type 135 | 136 | Returns: 137 | tensor (Tensor): Data types supported by PyTorch 138 | 139 | Examples: 140 | >>> example_image = cv2.imread("lr_image.bmp") 141 | >>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False) 142 | 143 | """ 144 | # Convert image data type to Tensor data type 145 | tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() 146 | 147 | # Scale the image data from [0, 1] to [-1, 1] 148 | if range_norm: 149 | tensor = tensor.mul(2.0).sub(1.0) 150 | 151 | # Convert torch.float32 image data type to torch.half image data type 152 | if half: 153 | tensor = tensor.half() 154 | 155 | return tensor 156 | 157 | 158 | def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any: 159 | """Convert the Tensor(NCWH) data type supported by PyTorch to the np.ndarray(WHC) image data type 160 | 161 | Args: 162 | tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1] 163 | range_norm (bool): Scale [-1, 1] data to between [0, 1] 164 | half (bool): Whether to convert torch.float32 similarly to torch.half type. 
165 | 166 | Returns: 167 | image (np.ndarray): Data types supported by PIL or OpenCV 168 | 169 | Examples: 170 | >>> example_image = cv2.imread("lr_image.bmp") 171 | >>> example_tensor = image_to_tensor(example_image, range_norm=False, half=False) 172 | 173 | """ 174 | if range_norm: 175 | tensor = tensor.add(1.0).div(2.0) 176 | if half: 177 | tensor = tensor.half() 178 | 179 | image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8") 180 | 181 | return image 182 | 183 | 184 | def preprocess_one_image(image_path: str, device: torch.device) -> Tensor: 185 | image = cv2.imread(image_path).astype(np.float32) / 255.0 186 | 187 | # BGR to RGB 188 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 189 | 190 | # Convert image data to pytorch format data 191 | tensor = image_to_tensor(image, False, False).unsqueeze_(0) 192 | 193 | # Transfer tensor channel image format data to CUDA device 194 | tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True) 195 | 196 | return tensor 197 | 198 | 199 | def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any: 200 | """Implementation of `imresize` function in Matlab under Python language. 201 | 202 | Args: 203 | image: The input image. 204 | scale_factor (float): Scale factor. The same scale applies for both height and width. 205 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 206 | Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``. 207 | 208 | Returns: 209 | out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round 210 | 211 | """ 212 | squeeze_flag = False 213 | if type(image).__module__ == np.__name__: # numpy type 214 | numpy_type = True 215 | if image.ndim == 2: 216 | image = image[:, :, None] 217 | squeeze_flag = True 218 | image = torch.from_numpy(image.transpose(2, 0, 1)).float() 219 | else: 220 | numpy_type = False 221 | if image.ndim == 2: 222 | image = image.unsqueeze(0) 223 | squeeze_flag = True 224 | 225 | in_c, in_h, in_w = image.size() 226 | out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor) 227 | kernel_width = 4 228 | 229 | # get weights and indices 230 | weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, 231 | antialiasing) 232 | weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, 233 | antialiasing) 234 | # process H dimension 235 | # symmetric copying 236 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 237 | img_aug.narrow(1, sym_len_hs, in_h).copy_(image) 238 | 239 | sym_patch = image[:, :sym_len_hs, :] 240 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 241 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 242 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 243 | 244 | sym_patch = image[:, -sym_len_he:, :] 245 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 246 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 247 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 248 | 249 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 250 | kernel_width = weights_h.size(1) 251 | for i in range(out_h): 252 | idx = int(indices_h[i][0]) 253 | for j in range(in_c): 254 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 255 | 256 | # process W dimension 257 | # symmetric copying 258 | 
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 259 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 260 | 261 | sym_patch = out_1[:, :, :sym_len_ws] 262 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 263 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 264 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 265 | 266 | sym_patch = out_1[:, :, -sym_len_we:] 267 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 268 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 269 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 270 | 271 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 272 | kernel_width = weights_w.size(1) 273 | for i in range(out_w): 274 | idx = int(indices_w[i][0]) 275 | for j in range(in_c): 276 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 277 | 278 | if squeeze_flag: 279 | out_2 = out_2.squeeze(0) 280 | if numpy_type: 281 | out_2 = out_2.numpy() 282 | if not squeeze_flag: 283 | out_2 = out_2.transpose(1, 2, 0) 284 | 285 | return out_2 286 | 287 | 288 | def expand_y(image: np.ndarray) -> np.ndarray: 289 | """Convert BGR channel to YCbCr format, 290 | and expand Y channel data in YCbCr, from HW to HWC 291 | 292 | Args: 293 | image (np.ndarray): Y channel image data 294 | 295 | Returns: 296 | y_image (np.ndarray): Y-channel image data in HWC form 297 | 298 | """ 299 | # Normalize image data to [0, 1] 300 | image = image.astype(np.float32) / 255. 301 | 302 | # Convert BGR to YCbCr, and extract only Y channel 303 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True) 304 | 305 | # Expand Y channel 306 | y_image = y_image[..., None] 307 | 308 | # Normalize the image data to [0, 255] 309 | y_image = y_image.astype(np.float64) * 255.0 310 | 311 | return y_image 312 | 313 | 314 | def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray: 315 | """Implementation of rgb2ycbcr function in Matlab under Python language 316 | 317 | Args: 318 | image (np.ndarray): Image input in RGB format. 319 | only_use_y_channel (bool): Extract Y channel separately 320 | 321 | Returns: 322 | image (np.ndarray): YCbCr image array data 323 | 324 | """ 325 | if only_use_y_channel: 326 | image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0 327 | else: 328 | image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [ 329 | 16, 128, 128] 330 | 331 | image /= 255. 332 | image = image.astype(np.float32) 333 | 334 | return image 335 | 336 | 337 | def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray: 338 | """Implementation of bgr2ycbcr function in Matlab under Python language. 339 | 340 | Args: 341 | image (np.ndarray): Image input in BGR format 342 | only_use_y_channel (bool): Extract Y channel separately 343 | 344 | Returns: 345 | image (np.ndarray): YCbCr image array data 346 | 347 | """ 348 | if only_use_y_channel: 349 | image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0 350 | else: 351 | image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [ 352 | 16, 128, 128] 353 | 354 | image /= 255. 355 | image = image.astype(np.float32) 356 | 357 | return image 358 | 359 | 360 | def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray: 361 | """Implementation of ycbcr2rgb function in Matlab under Python language. 362 | 363 | Args: 364 | image (np.ndarray): Image input in YCbCr format. 
365 | 366 | Returns: 367 | image (np.ndarray): RGB image array data 368 | 369 | """ 370 | image_dtype = image.dtype 371 | image *= 255. 372 | 373 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 374 | [0, -0.00153632, 0.00791071], 375 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 376 | 377 | image /= 255. 378 | image = image.astype(image_dtype) 379 | 380 | return image 381 | 382 | 383 | def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray: 384 | """Implementation of ycbcr2bgr function in Matlab under Python language. 385 | 386 | Args: 387 | image (np.ndarray): Image input in YCbCr format. 388 | 389 | Returns: 390 | image (np.ndarray): BGR image array data 391 | 392 | """ 393 | image_dtype = image.dtype 394 | image *= 255. 395 | 396 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 397 | [0.00791071, -0.00153632, 0], 398 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] 399 | 400 | image /= 255. 401 | image = image.astype(image_dtype) 402 | 403 | return image 404 | 405 | 406 | def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor: 407 | """Implementation of rgb2ycbcr function in Matlab under PyTorch 408 | 409 | References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion` 410 | 411 | Args: 412 | tensor (Tensor): Image data in PyTorch format 413 | only_use_y_channel (bool): Extract only Y channel 414 | 415 | Returns: 416 | tensor (Tensor): YCbCr image data in PyTorch format 417 | 418 | """ 419 | if only_use_y_channel: 420 | weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor) 421 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 422 | else: 423 | weight = Tensor([[65.481, -37.797, 112.0], 424 | [128.553, -74.203, -93.786], 425 | [24.966, 112.0, -18.214]]).to(tensor) 426 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor) 427 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 428 | 429 | tensor /= 255. 430 | 431 | return tensor 432 | 433 | 434 | def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor: 435 | """Implementation of bgr2ycbcr function in Matlab under PyTorch 436 | 437 | References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion` 438 | 439 | Args: 440 | tensor (Tensor): Image data in PyTorch format 441 | only_use_y_channel (bool): Extract only Y channel 442 | 443 | Returns: 444 | tensor (Tensor): YCbCr image data in PyTorch format 445 | 446 | """ 447 | if only_use_y_channel: 448 | weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor) 449 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 450 | else: 451 | weight = Tensor([[24.966, 112.0, -18.214], 452 | [128.553, -74.203, -93.786], 453 | [65.481, -37.797, 112.0]]).to(tensor) 454 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor) 455 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 456 | 457 | tensor /= 255. 458 | 459 | return tensor 460 | 461 | 462 | def center_crop(image: np.ndarray, image_size: int) -> np.ndarray: 463 | """Crop small image patches from one image center area. 464 | 465 | Args: 466 | image (np.ndarray): The input image for `OpenCV.imread`. 467 | image_size (int): The size of the captured image area. 
468 | 469 | Returns: 470 | patch_image (np.ndarray): Small patch image 471 | 472 | """ 473 | image_height, image_width = image.shape[:2] 474 | 475 | # Just need to find the top and left coordinates of the image 476 | top = (image_height - image_size) // 2 477 | left = (image_width - image_size) // 2 478 | 479 | # Crop image patch 480 | patch_image = image[top:top + image_size, left:left + image_size, ...] 481 | 482 | return patch_image 483 | 484 | 485 | def random_crop(image: np.ndarray, image_size: int) -> np.ndarray: 486 | """Crop small image patches from one image. 487 | 488 | Args: 489 | image (np.ndarray): The input image for `OpenCV.imread`. 490 | image_size (int): The size of the captured image area. 491 | 492 | Returns: 493 | patch_image (np.ndarray): Small patch image 494 | 495 | """ 496 | image_height, image_width = image.shape[:2] 497 | 498 | # Just need to find the top and left coordinates of the image 499 | top = random.randint(0, image_height - image_size) 500 | left = random.randint(0, image_width - image_size) 501 | 502 | # Crop image patch 503 | patch_image = image[top:top + image_size, left:left + image_size, ...] 504 | 505 | return patch_image 506 | 507 | 508 | def random_rotate(image, 509 | angles: list, 510 | center: tuple[int, int] = None, 511 | scale_factor: float = 1.0) -> np.ndarray: 512 | """Rotate an image by a random angle 513 | 514 | Args: 515 | image (np.ndarray): Image read with OpenCV 516 | angles (list): Rotation angle range 517 | center (optional, tuple[int, int]): High resolution image selection center point. Default: ``None`` 518 | scale_factor (optional, float): scaling factor. Default: 1.0 519 | 520 | Returns: 521 | rotated_image (np.ndarray): image after rotation 522 | 523 | """ 524 | image_height, image_width = image.shape[:2] 525 | 526 | if center is None: 527 | center = (image_width // 2, image_height // 2) 528 | 529 | # Random select specific angle 530 | angle = random.choice(angles) 531 | matrix = cv2.getRotationMatrix2D(center, angle, scale_factor) 532 | rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height)) 533 | 534 | return rotated_image 535 | 536 | 537 | def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray: 538 | """Flip the image upside down randomly 539 | 540 | Args: 541 | image (np.ndarray): Image read with OpenCV 542 | p (optional, float): Horizontally flip probability. Default: 0.5 543 | 544 | Returns: 545 | horizontally_flip_image (np.ndarray): image after horizontally flip 546 | 547 | """ 548 | if random.random() < p: 549 | horizontally_flip_image = cv2.flip(image, 1) 550 | else: 551 | horizontally_flip_image = image 552 | 553 | return horizontally_flip_image 554 | 555 | 556 | def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray: 557 | """Flip an image horizontally randomly 558 | 559 | Args: 560 | image (np.ndarray): Image read with OpenCV 561 | p (optional, float): Vertically flip probability. Default: 0.5 562 | 563 | Returns: 564 | vertically_flip_image (np.ndarray): image after vertically flip 565 | 566 | """ 567 | if random.random() < p: 568 | vertically_flip_image = cv2.flip(image, 0) 569 | else: 570 | vertically_flip_image = image 571 | 572 | return vertically_flip_image 573 | -------------------------------------------------------------------------------- /image_quality_assessment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 
2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import collections.abc 15 | import math 16 | import typing 17 | import warnings 18 | from itertools import repeat 19 | from typing import Any 20 | 21 | import cv2 22 | import numpy as np 23 | import torch 24 | from numpy import ndarray 25 | from scipy.io import loadmat 26 | from scipy.ndimage import convolve 27 | from scipy.special import gamma 28 | from torch import nn, Tensor 29 | from torch.nn import functional as F 30 | 31 | from imgproc import image_resize, expand_y, bgr_to_ycbcr, rgb_to_ycbcr_torch 32 | 33 | __all__ = [ 34 | "psnr", "ssim", "niqe", 35 | "PSNR", "SSIM", "NIQE", 36 | ] 37 | 38 | _I = typing.Optional[int] 39 | _D = typing.Optional[torch.dtype] 40 | 41 | 42 | # The following is the implementation of IQA method in Python, using CPU as processing device 43 | def _check_image(raw_image: np.ndarray, dst_image: np.ndarray): 44 | """Check whether the size and type of the two images are the same 45 | 46 | Args: 47 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255] 48 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255] 49 | 50 | """ 51 | # check image scale 52 | assert raw_image.shape == dst_image.shape, \ 53 | f"Supplied images have different sizes {str(raw_image.shape)} and {str(dst_image.shape)}" 54 | 55 | # check image type 56 | if raw_image.dtype != dst_image.dtype: 57 | warnings.warn(f"Supplied images have different dtypes {str(raw_image.dtype)} and {str(dst_image.dtype)}") 58 | 59 | 60 | def psnr(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float: 61 | """Python implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 62 | 63 | Args: 64 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255] 65 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255] 66 | crop_border (int): crop border a few pixels 67 | only_test_y_channel (bool): Whether to test only the Y channel of the image. 68 | 69 | Returns: 70 | psnr_metrics (np.float64): PSNR metrics 71 | 72 | """ 73 | # Check if two images are similar in scale and type 74 | _check_image(raw_image, dst_image) 75 | 76 | # crop border pixels 77 | if crop_border > 0: 78 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...] 79 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...]
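# PSNR follows the textbook definition 10 * log10(MAX^2 / MSE) with MAX = 255 for
# 8-bit images. As a rough illustration (numbers not taken from this repo): an MSE
# of 100 corresponds to 10 * log10(65025 / 100) ~= 28.1 dB, while identical images
# would be unbounded, which is why a small epsilon guards the denominator below.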
80 | 81 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately 82 | if only_test_y_channel: 83 | raw_image = expand_y(raw_image) 84 | dst_image = expand_y(dst_image) 85 | 86 | # Convert data type to numpy.float64 bit 87 | raw_image = raw_image.astype(np.float64) 88 | dst_image = dst_image.astype(np.float64) 89 | 90 | psnr_metrics = 10 * np.log10((255.0 ** 2) / (np.mean((raw_image - dst_image) ** 2) + 1e-8)) 91 | 92 | return psnr_metrics 93 | 94 | 95 | def _ssim(raw_image: np.ndarray, dst_image: np.ndarray) -> float: 96 | """Python implements the SSIM (Structural Similarity) function, which only calculates single-channel data 97 | 98 | Args: 99 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 100 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255] 101 | 102 | Returns: 103 | ssim_metrics (float): SSIM metrics for single channel 104 | 105 | """ 106 | c1 = (0.01 * 255.0) ** 2 107 | c2 = (0.03 * 255.0) ** 2 108 | 109 | kernel = cv2.getGaussianKernel(11, 1.5) 110 | kernel_window = np.outer(kernel, kernel.transpose()) 111 | 112 | raw_mean = cv2.filter2D(raw_image, -1, kernel_window)[5:-5, 5:-5] 113 | dst_mean = cv2.filter2D(dst_image, -1, kernel_window)[5:-5, 5:-5] 114 | raw_mean_square = raw_mean ** 2 115 | dst_mean_square = dst_mean ** 2 116 | raw_dst_mean = raw_mean * dst_mean 117 | raw_variance = cv2.filter2D(raw_image ** 2, -1, kernel_window)[5:-5, 5:-5] - raw_mean_square 118 | dst_variance = cv2.filter2D(dst_image ** 2, -1, kernel_window)[5:-5, 5:-5] - dst_mean_square 119 | raw_dst_covariance = cv2.filter2D(raw_image * dst_image, -1, kernel_window)[5:-5, 5:-5] - raw_dst_mean 120 | 121 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2) 122 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2) 123 | 124 | ssim_metrics = ssim_molecular / ssim_denominator 125 | ssim_metrics = float(np.mean(ssim_metrics)) 126 | 127 | return ssim_metrics 128 | 129 | 130 | def ssim(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float: 131 | """Python implements the SSIM (Structural Similarity) function, which calculates single/multi-channel data 132 | 133 | Args: 134 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 135 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255] 136 | crop_border (int): crop border a few pixels 137 | only_test_y_channel (bool): Whether to test only the Y channel of the image 138 | 139 | Returns: 140 | ssim_metrics (float): SSIM metrics averaged over channels 141 | 142 | """ 143 | # Check if two images are similar in scale and type 144 | _check_image(raw_image, dst_image) 145 | 146 | # crop border pixels 147 | if crop_border > 0: 148 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...] 149 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...]
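# The constants used above follow the original SSIM paper for 8-bit data:
# c1 = (0.01 * 255) ** 2 = 6.5025 and c2 = (0.03 * 255) ** 2 = 58.5225, with an
# 11x11 Gaussian window of sigma 1.5. Below, SSIM is evaluated per channel and the
# scores are averaged, so a 3-channel BGR input yields the mean of three
# single-channel results.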
150 | 151 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately 152 | if only_test_y_channel: 153 | raw_image = expand_y(raw_image) 154 | dst_image = expand_y(dst_image) 155 | 156 | # Convert data type to numpy.float64 bit 157 | raw_image = raw_image.astype(np.float64) 158 | dst_image = dst_image.astype(np.float64) 159 | 160 | channels_ssim_metrics = [] 161 | for channel in range(raw_image.shape[2]): 162 | ssim_metrics = _ssim(raw_image[..., channel], dst_image[..., channel]) 163 | channels_ssim_metrics.append(ssim_metrics) 164 | ssim_metrics = np.mean(np.asarray(channels_ssim_metrics)) 165 | 166 | return float(ssim_metrics) 167 | 168 | 169 | def _estimate_aggd_parameters(vector: np.ndarray) -> [np.ndarray, float, float]: 170 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 171 | This function is used to estimate an asymmetric generalized Gaussian distribution 172 | 173 | Reference papers: 174 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 175 | 176 | Args: 177 | vector (np.ndarray): data vector 178 | 179 | Returns: 180 | aggd_parameters (np.ndarray): asymmetric generalized Gaussian distribution 181 | left_beta (float): symmetric left data vector variance mean product 182 | right_beta (float): symmetric right side data vector variance mean product 183 | 184 | """ 185 | # The following is obtained according to the formula and the method provided in the paper on WIki encyclopedia 186 | vector = vector.flatten() 187 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801 188 | gam_reciprocal = np.reciprocal(gam) 189 | r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) 190 | 191 | left_std = np.sqrt(np.mean(vector[vector < 0] ** 2)) 192 | right_std = np.sqrt(np.mean(vector[vector > 0] ** 2)) 193 | gamma_hat = left_std / right_std 194 | rhat = (np.mean(np.abs(vector))) ** 2 / np.mean(vector ** 2) 195 | rhat_norm = (rhat * (gamma_hat ** 3 + 1) * (gamma_hat + 1)) / ((gamma_hat ** 2 + 1) ** 2) 196 | array_position = np.argmin((r_gam - rhat_norm) ** 2) 197 | 198 | aggd_parameters = gam[array_position] 199 | left_beta = left_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters)) 200 | right_beta = right_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters)) 201 | 202 | return aggd_parameters, left_beta, right_beta 203 | 204 | 205 | def _get_mscn_feature(image: np.ndarray) -> list[float | Any]: 206 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 207 | This function is used to calculate the MSCN feature map 208 | 209 | Reference papers: 210 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 211 | 212 | Args: 213 | image (np.ndarray): Grayscale image of MSCN feature to be calculated, BGR format, data range is [0, 255] 214 | 215 | Returns: 216 | mscn_feature (np.ndarray): MSCN feature map of the image 217 | 218 | """ 219 | mscn_feature = [] 220 | # Calculate the asymmetric generalized Gaussian distribution 221 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters(image) 222 | mscn_feature.extend([aggd_parameters, (left_beta + right_beta) / 2]) 223 | 224 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] 225 | for i in range(len(shifts)): 226 | shifted_block = np.roll(image, shifts[i], axis=(0, 1)) 227 | # Calculate the asymmetric generalized Gaussian distribution 228 | aggd_parameters, 
left_beta, right_beta = _estimate_aggd_parameters(image * shifted_block) 229 | mean = (right_beta - left_beta) * (gamma(2 / aggd_parameters) / gamma(1 / aggd_parameters)) 230 | mscn_feature.extend([aggd_parameters, mean, left_beta, right_beta]) 231 | 232 | return mscn_feature 233 | 234 | 235 | def _fit_mscn_ipac(image: np.ndarray, 236 | mu_pris_param: np.ndarray, 237 | cov_pris_param: np.ndarray, 238 | gaussian_window: np.ndarray, 239 | block_size_height: int, 240 | block_size_width: int) -> float: 241 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 242 | This function is used to fit the inner product of adjacent coefficients of MSCN 243 | 244 | Reference papers: 245 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 246 | 247 | Args: 248 | image (np.ndarray): The image data of the NIQE to be tested, in BGR format, the data range is [0, 255] 249 | mu_pris_param (np.ndarray): Mean of predefined multivariate Gaussians, model computed on original dataset. 250 | cov_pris_param (np.ndarray): Covariance of predefined multivariate Gaussian model computed on original dataset. 251 | gaussian_window (np.ndarray): 7x7 Gaussian window for smoothing the image 252 | block_size_height (int): the height of the block into which the image is divided 253 | block_size_width (int): The width of the block into which the image is divided 254 | 255 | Returns: 256 | niqe_metric (np.ndarray): NIQE score 257 | 258 | """ 259 | image_height, image_width = image.shape 260 | num_block_height = math.floor(image_height / block_size_height) 261 | num_block_width = math.floor(image_width / block_size_width) 262 | image = image[0:num_block_height * block_size_height, 0:num_block_width * block_size_width] 263 | 264 | features_parameters = [] 265 | for scale in (1, 2): 266 | mu = convolve(image, gaussian_window, mode="nearest") 267 | sigma = np.sqrt(np.abs(convolve(np.square(image), gaussian_window, mode="nearest") - np.square(mu))) 268 | image_norm = (image - mu) / (sigma + 1) 269 | 270 | feature = [] 271 | for idx_w in range(num_block_width): 272 | for idx_h in range(num_block_height): 273 | vector = image_norm[ 274 | idx_h * block_size_height // scale:(idx_h + 1) * block_size_height // scale, 275 | idx_w * block_size_width // scale:(idx_w + 1) * block_size_width // scale] 276 | feature.append(_get_mscn_feature(vector)) 277 | 278 | features_parameters.append(np.array(feature)) 279 | 280 | if scale == 1: 281 | image = image_resize(image / 255., scale_factor=0.5, antialiasing=True) 282 | image = image * 255. 
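# Each block contributes 18 features per scale: 2 from the plain MSCN map plus 4
# for each of the 4 paired-product orientations in _get_mscn_feature. With the two
# scales above, every block is therefore described by a 36-dimensional feature
# vector before the multivariate Gaussian is fitted below.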
283 | 284 | features_parameters = np.concatenate(features_parameters, axis=1) 285 | 286 | # Fitting a multivariate Gaussian kernel model to distorted patch features 287 | mu_distparam = np.nanmean(features_parameters, axis=0) 288 | distparam_no_nan = features_parameters[~np.isnan(features_parameters).any(axis=1)] 289 | cov_distparam = np.cov(distparam_no_nan, rowvar=False) 290 | 291 | invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) 292 | niqe_metric = np.matmul(np.matmul((mu_pris_param - mu_distparam), invcov_param), 293 | np.transpose((mu_pris_param - mu_distparam))) 294 | 295 | niqe_metric = np.sqrt(niqe_metric) 296 | niqe_metric = float(np.squeeze(niqe_metric)) 297 | 298 | return niqe_metric 299 | 300 | 301 | def niqe(image: np.ndarray, 302 | crop_border: int, 303 | niqe_model_path: str, 304 | block_size_height: int = 96, 305 | block_size_width: int = 96) -> float: 306 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 307 | This function computes single/multi-channel data 308 | 309 | Args: 310 | image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 311 | crop_border (int): crop border a few pixels 312 | niqe_model_path: NIQE estimator model address 313 | block_size_height (int): The height of the block the image is divided into. Default: 96 314 | block_size_width (int): The width of the block the image is divided into. Default: 96 315 | 316 | Returns: 317 | niqe_metrics (float): NIQE indicator under single channel 318 | 319 | """ 320 | # crop border pixels 321 | if crop_border > 0: 322 | image = image[crop_border:-crop_border, crop_border:-crop_border, ...] 323 | 324 | # Defining the NIQE Feature Extraction Model 325 | niqe_model = np.load(niqe_model_path) 326 | 327 | mu_pris_param = niqe_model["mu_pris_param"] 328 | cov_pris_param = niqe_model["cov_pris_param"] 329 | gaussian_window = niqe_model["gaussian_window"] 330 | 331 | # NIQE only tests on Y channel images and needs to convert the images 332 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True) 333 | 334 | # Convert data type to numpy.float64 bit 335 | y_image = y_image.astype(np.float64) 336 | 337 | niqe_metric = _fit_mscn_ipac(y_image, 338 | mu_pris_param, 339 | cov_pris_param, 340 | gaussian_window, 341 | block_size_height, 342 | block_size_width) 343 | 344 | return niqe_metric 345 | 346 | 347 | # The following is the IQA method implemented by PyTorch, using CUDA as the processing device 348 | def _check_tensor_shape(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): 349 | """Check if the dimensions of the two tensors are the same 350 | 351 | Args: 352 | raw_tensor (np.ndarray or torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 353 | dst_tensor (np.ndarray or torch.Tensor): reference image tensorflow, RGB format, data range [0, 1] 354 | 355 | """ 356 | # Check if tensor scales are consistent 357 | assert raw_tensor.shape == dst_tensor.shape, \ 358 | f"Supplied images have different sizes {str(raw_tensor.shape)} and {str(dst_tensor.shape)}" 359 | 360 | 361 | def _psnr_torch(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor, crop_border: int, 362 | only_test_y_channel: bool) -> float: 363 | """PyTorch implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 364 | 365 | Args: 366 | raw_tensor (torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 367 | dst_tensor (torch.Tensor): reference image tensorflow, RGB format, data range [0, 1] 368 | crop_border (int): 
crop border a few pixels 369 | only_test_y_channel (bool): Whether to test only the Y channel of the image 370 | 371 | Returns: 372 | psnr_metrics (torch.Tensor): PSNR metrics 373 | 374 | """ 375 | # Check if two tensor scales are similar 376 | _check_tensor_shape(raw_tensor, dst_tensor) 377 | 378 | # crop border pixels 379 | if crop_border > 0: 380 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 381 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 382 | 383 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data 384 | if only_test_y_channel: 385 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True) 386 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True) 387 | 388 | # Convert data type to torch.float64 bit 389 | raw_tensor = raw_tensor.to(torch.float64) 390 | dst_tensor = dst_tensor.to(torch.float64) 391 | 392 | mse_value = torch.mean((raw_tensor * 255.0 - dst_tensor * 255.0) ** 2 + 1e-8, dim=[1, 2, 3]) 393 | psnr_metrics = 10 * torch.log10_(255.0 ** 2 / mse_value) 394 | 395 | return psnr_metrics 396 | 397 | 398 | class PSNR(nn.Module): 399 | """PyTorch implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 400 | 401 | Attributes: 402 | crop_border (int): crop border a few pixels 403 | only_test_y_channel (bool): Whether to test only the Y channel of the image 404 | 405 | Returns: 406 | psnr_metrics (torch.Tensor): PSNR metrics 407 | 408 | """ 409 | 410 | def __init__(self, crop_border: int, only_test_y_channel: bool) -> None: 411 | super().__init__() 412 | self.crop_border = crop_border 413 | self.only_test_y_channel = only_test_y_channel 414 | 415 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float: 416 | psnr_metrics = _psnr_torch(raw_tensor, dst_tensor, self.crop_border, self.only_test_y_channel) 417 | 418 | return psnr_metrics 419 | 420 | 421 | def _ssim_torch(raw_tensor: torch.Tensor, 422 | dst_tensor: torch.Tensor, 423 | window_size: int, 424 | gaussian_kernel_window: np.ndarray) -> float: 425 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 426 | 427 | Args: 428 | raw_tensor (torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 255] 429 | dst_tensor (torch.Tensor): reference image tensorflow, RGB format, data range [0, 255] 430 | window_size (int): Gaussian filter size 431 | gaussian_kernel_window (np.ndarray): Gaussian filter 432 | 433 | Returns: 434 | ssim_metrics (torch.Tensor): SSIM metrics 435 | 436 | """ 437 | c1 = (0.01 * 255.0) ** 2 438 | c2 = (0.03 * 255.0) ** 2 439 | 440 | gaussian_kernel_window = torch.from_numpy(gaussian_kernel_window).view(1, 1, window_size, window_size) 441 | gaussian_kernel_window = gaussian_kernel_window.expand(raw_tensor.size(1), 1, window_size, window_size) 442 | gaussian_kernel_window = gaussian_kernel_window.to(device=raw_tensor.device, dtype=raw_tensor.dtype) 443 | 444 | raw_mean = F.conv2d(raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=raw_tensor.shape[1]) 445 | dst_mean = F.conv2d(dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=dst_tensor.shape[1]) 446 | raw_mean_square = raw_mean ** 2 447 | dst_mean_square = dst_mean ** 2 448 | raw_dst_mean = raw_mean * dst_mean 449 | raw_variance = F.conv2d(raw_tensor * raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), 450 | groups=raw_tensor.shape[1]) - raw_mean_square 451 | 
dst_variance = F.conv2d(dst_tensor * dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), 452 | groups=raw_tensor.shape[1]) - dst_mean_square 453 | raw_dst_covariance = F.conv2d(raw_tensor * dst_tensor, gaussian_kernel_window, stride=1, padding=(0, 0), 454 | groups=raw_tensor.shape[1]) - raw_dst_mean 455 | 456 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2) 457 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2) 458 | 459 | ssim_metrics = ssim_molecular / ssim_denominator 460 | ssim_metrics = torch.mean(ssim_metrics, [1, 2, 3]).float() 461 | 462 | return ssim_metrics 463 | 464 | 465 | def _ssim_single_torch(raw_tensor: torch.Tensor, 466 | dst_tensor: torch.Tensor, 467 | crop_border: int, 468 | only_test_y_channel: bool, 469 | window_size: int, 470 | gaussian_kernel_window: ndarray) -> float: 471 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 472 | 473 | Args: 474 | raw_tensor (Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 475 | dst_tensor (Tensor): reference image tensorflow, RGB format, data range [0, 1] 476 | crop_border (int): crop border a few pixels 477 | only_test_y_channel (bool): Whether to test only the Y channel of the image 478 | window_size (int): Gaussian filter size 479 | gaussian_kernel_window (ndarray): Gaussian filter 480 | 481 | Returns: 482 | ssim_metrics (torch.Tensor): SSIM metrics 483 | 484 | """ 485 | # Check if two tensor scales are similar 486 | _check_tensor_shape(raw_tensor, dst_tensor) 487 | 488 | # crop border pixels 489 | if crop_border > 0: 490 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 491 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 492 | 493 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data 494 | if only_test_y_channel: 495 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True) 496 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True) 497 | 498 | # Convert data type to torch.float64 bit 499 | raw_tensor = raw_tensor.to(torch.float64) 500 | dst_tensor = dst_tensor.to(torch.float64) 501 | 502 | ssim_metrics = _ssim_torch(raw_tensor * 255.0, dst_tensor * 255.0, window_size, gaussian_kernel_window) 503 | 504 | return ssim_metrics 505 | 506 | 507 | class SSIM(nn.Module): 508 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 509 | 510 | Args: 511 | crop_border (int): crop border a few pixels 512 | only_only_test_y_channel (bool): Whether to test only the Y channel of the image 513 | window_size (int): Gaussian filter size 514 | gaussian_sigma (float): sigma parameter in Gaussian filter 515 | 516 | Returns: 517 | ssim_metrics (torch.Tensor): SSIM metrics 518 | 519 | """ 520 | 521 | def __init__(self, crop_border: int, 522 | only_only_test_y_channel: bool, 523 | window_size: int = 11, 524 | gaussian_sigma: float = 1.5) -> None: 525 | super().__init__() 526 | self.crop_border = crop_border 527 | self.only_test_y_channel = only_only_test_y_channel 528 | self.window_size = window_size 529 | 530 | gaussian_kernel = cv2.getGaussianKernel(window_size, gaussian_sigma) 531 | self.gaussian_kernel_window = np.outer(gaussian_kernel, gaussian_kernel.transpose()) 532 | 533 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float: 534 | ssim_metrics = _ssim_single_torch(raw_tensor, 535 
| dst_tensor, 536 | self.crop_border, 537 | self.only_test_y_channel, 538 | self.window_size, 539 | self.gaussian_kernel_window) 540 | 541 | return ssim_metrics 542 | 543 | 544 | def _fspecial_gaussian_torch(window_size: int, sigma: float, channels: int): 545 | """PyTorch implements the fspecial_gaussian() function in MATLAB 546 | 547 | Args: 548 | window_size (int): Gaussian filter size 549 | sigma (float): sigma parameter in Gaussian filter 550 | channels (int): number of input image channels 551 | 552 | Returns: 553 | gaussian_kernel_window (torch.Tensor): Gaussian filter in Tensor format 554 | 555 | """ 556 | if type(window_size) is int: 557 | shape = (window_size, window_size) 558 | else: 559 | shape = window_size 560 | m, n = [(ss - 1.) / 2. for ss in shape] 561 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 562 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 563 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 564 | sumh = h.sum() 565 | 566 | if sumh != 0: 567 | h /= sumh 568 | 569 | gaussian_kernel_window = torch.from_numpy(h).float().repeat(channels, 1, 1, 1) 570 | 571 | return gaussian_kernel_window 572 | 573 | 574 | def _to_tuple(n): 575 | def parse(x): 576 | if isinstance(x, collections.abc.Iterable): 577 | return x 578 | return tuple(repeat(x, n)) 579 | 580 | return parse 581 | 582 | 583 | def _excact_padding_2d(tensor: torch.Tensor, 584 | kernel: torch.Tensor | tuple, 585 | stride: int = 1, 586 | dilation: int = 1, 587 | mode: str = "same") -> torch.Tensor: 588 | assert len(tensor.shape) == 4, f"Only support 4D tensor input, but got {tensor.shape}" 589 | kernel = _to_tuple(2)(kernel) 590 | stride = _to_tuple(2)(stride) 591 | dilation = _to_tuple(2)(dilation) 592 | b, c, h, w = tensor.shape 593 | h2 = math.ceil(h / stride[0]) 594 | w2 = math.ceil(w / stride[1]) 595 | pad_row = (h2 - 1) * stride[0] + (kernel[0] - 1) * dilation[0] + 1 - h 596 | pad_col = (w2 - 1) * stride[1] + (kernel[1] - 1) * dilation[1] + 1 - w 597 | pad_l, pad_r, pad_t, pad_b = (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2) 598 | 599 | mode = mode if mode != "same" else "constant" 600 | if mode != "symmetric": 601 | tensor = F.pad(tensor, (pad_l, pad_r, pad_t, pad_b), mode=mode) 602 | elif mode == "symmetric": 603 | sym_h = torch.flip(tensor, [2]) 604 | sym_w = torch.flip(tensor, [3]) 605 | sym_hw = torch.flip(tensor, [2, 3]) 606 | 607 | row1 = torch.cat((sym_hw, sym_h, sym_hw), dim=3) 608 | row2 = torch.cat((sym_w, tensor, sym_w), dim=3) 609 | row3 = torch.cat((sym_hw, sym_h, sym_hw), dim=3) 610 | 611 | whole_map = torch.cat((row1, row2, row3), dim=2) 612 | 613 | tensor = whole_map[:, :, h - pad_t:2 * h + pad_b, w - pad_l:2 * w + pad_r, ] 614 | 615 | return tensor 616 | 617 | 618 | class ExactPadding2d(nn.Module): 619 | r"""This function calculate exact padding values for 4D tensor inputs, 620 | and support the same padding mode as tensorflow. 621 | 622 | Args: 623 | kernel (int or tuple): kernel size. 624 | stride (int or tuple): stride size. 625 | dilation (int or tuple): dilation size, default with 1. 
626 | mode (srt): padding mode can be ('same', 'symmetric', 'replicate', 'circular') 627 | 628 | """ 629 | 630 | def __init__(self, kernel, stride=1, dilation=1, mode="same") -> None: 631 | super().__init__() 632 | self.kernel = _to_tuple(2)(kernel) 633 | self.stride = _to_tuple(2)(stride) 634 | self.dilation = _to_tuple(2)(dilation) 635 | self.mode = mode 636 | 637 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 638 | return _excact_padding_2d(tensor, self.kernel, self.stride, self.dilation, self.mode) 639 | 640 | 641 | def _image_filter(tensor: torch.Tensor, 642 | weight: torch.Tensor, 643 | bias=None, 644 | stride: int = 1, 645 | padding: str = "same", 646 | dilation: int = 1, 647 | groups: int = 1): 648 | """PyTorch implements the imfilter() function in MATLAB 649 | 650 | Args: 651 | tensor (torch.Tensor): Tensor image data 652 | weight (torch.Tensor): filter weight 653 | padding (str): how to pad pixels. Default: ``same`` 654 | dilation (int): convolution dilation scale 655 | groups (int): number of grouped convolutions 656 | 657 | """ 658 | kernel_size = weight.shape[2:] 659 | exact_padding_2d = ExactPadding2d(kernel_size, stride, dilation, mode=padding) 660 | 661 | return F.conv2d(exact_padding_2d(tensor), weight, bias, stride, dilation=dilation, groups=groups) 662 | 663 | 664 | def _reshape_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]: 665 | if tensor.dim() == 4: 666 | b, c, h, w = tensor.size() 667 | elif tensor.dim() == 3: 668 | c, h, w = tensor.size() 669 | b = None 670 | elif tensor.dim() == 2: 671 | h, w = tensor.size() 672 | b = c = None 673 | else: 674 | raise ValueError('{}-dim Tensor is not supported!'.format(tensor.dim())) 675 | 676 | tensor = tensor.view(-1, 1, h, w) 677 | return tensor, b, c, h, w 678 | 679 | 680 | def _reshape_output_torch(tensor: torch.Tensor, b: _I, c: _I) -> torch.Tensor: 681 | rh = tensor.size(-2) 682 | rw = tensor.size(-1) 683 | # Back to the original dimension 684 | if b is not None: 685 | tensor = tensor.view(b, c, rh, rw) # 4-dim 686 | else: 687 | if c is not None: 688 | tensor = tensor.view(c, rh, rw) # 3-dim 689 | else: 690 | tensor = tensor.view(rh, rw) # 2-dim 691 | 692 | return tensor 693 | 694 | 695 | def _cast_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: 696 | if tensor.dtype != torch.float32 or tensor.dtype != torch.float64: 697 | dtype = tensor.dtype 698 | tensor = tensor.float() 699 | else: 700 | dtype = None 701 | 702 | return tensor, dtype 703 | 704 | 705 | def _cast_output_torch(tensor: torch.Tensor, dtype: _D) -> torch.Tensor: 706 | if dtype is not None: 707 | if not dtype.is_floating_point: 708 | tensor = tensor.round() 709 | # To prevent over/underflow when converting types 710 | if dtype is torch.uint8: 711 | tensor = tensor.clamp(0, 255) 712 | 713 | tensor = tensor.to(dtype=dtype) 714 | 715 | return tensor 716 | 717 | 718 | def _cubic_contribution_torch(tensor: torch.Tensor, a: float = -0.5) -> torch.Tensor: 719 | ax = tensor.abs() 720 | ax2 = ax * ax 721 | ax3 = ax * ax2 722 | 723 | range_01 = ax.le(1) 724 | range_12 = torch.logical_and(ax.gt(1), ax.le(2)) 725 | 726 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 727 | cont_01 = cont_01 * range_01.to(dtype=tensor.dtype) 728 | 729 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 730 | cont_12 = cont_12 * range_12.to(dtype=tensor.dtype) 731 | 732 | cont = cont_01 + cont_12 733 | return cont 734 | 735 | 736 | def _gaussian_contribution_torch(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: 
737 | range_3sigma = (x.abs() <= 3 * sigma + 1) 738 | # Normalization will be done after 739 | cont = torch.exp(-x.pow(2) / (2 * sigma ** 2)) 740 | cont = cont * range_3sigma.to(dtype=x.dtype) 741 | return cont 742 | 743 | 744 | def _reflect_padding_torch(tensor: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor: 745 | """ 746 | Apply reflect padding to the given Tensor. 747 | Note that it is slightly different from the PyTorch functional.pad, 748 | where boundary elements are used only once. 749 | Instead, we follow the MATLAB implementation 750 | which uses boundary elements twice. 751 | For example, 752 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 753 | while our implementation yields [a, a, b, c, d, d]. 754 | """ 755 | b, c, h, w = tensor.size() 756 | if dim == 2 or dim == -2: 757 | padding_buffer = tensor.new_zeros(b, c, h + pad_pre + pad_post, w) 758 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(tensor) 759 | for p in range(pad_pre): 760 | padding_buffer[..., pad_pre - p - 1, :].copy_(tensor[..., p, :]) 761 | for p in range(pad_post): 762 | padding_buffer[..., h + pad_pre + p, :].copy_(tensor[..., -(p + 1), :]) 763 | else: 764 | padding_buffer = tensor.new_zeros(b, c, h, w + pad_pre + pad_post) 765 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(tensor) 766 | for p in range(pad_pre): 767 | padding_buffer[..., pad_pre - p - 1].copy_(tensor[..., p]) 768 | for p in range(pad_post): 769 | padding_buffer[..., w + pad_pre + p].copy_(tensor[..., -(p + 1)]) 770 | 771 | return padding_buffer 772 | 773 | 774 | def _padding_torch(tensor: torch.Tensor, 775 | dim: int, 776 | pad_pre: int, 777 | pad_post: int, 778 | padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: 779 | if padding_type is None: 780 | return tensor 781 | elif padding_type == 'reflect': 782 | x_pad = _reflect_padding_torch(tensor, dim, pad_pre, pad_post) 783 | else: 784 | raise ValueError('{} padding is not supported!'.format(padding_type)) 785 | 786 | return x_pad 787 | 788 | 789 | def _get_padding_torch(tensor: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 790 | tensor = tensor.long() 791 | r_min = tensor.min() 792 | r_max = tensor.max() + kernel_size - 1 793 | 794 | if r_min <= 0: 795 | pad_pre = -r_min 796 | pad_pre = pad_pre.item() 797 | tensor += pad_pre 798 | else: 799 | pad_pre = 0 800 | 801 | if r_max >= x_size: 802 | pad_post = r_max - x_size + 1 803 | pad_post = pad_post.item() 804 | else: 805 | pad_post = 0 806 | 807 | return pad_pre, pad_post, tensor 808 | 809 | 810 | def _get_weight_torch(tensor: torch.Tensor, 811 | kernel_size: int, 812 | kernel: str = "cubic", 813 | sigma: float = 2.0, 814 | antialiasing_factor: float = 1) -> torch.Tensor: 815 | buffer_pos = tensor.new_zeros(kernel_size, len(tensor)) 816 | for idx, buffer_sub in enumerate(buffer_pos): 817 | buffer_sub.copy_(tensor - idx) 818 | 819 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 
820 | buffer_pos *= antialiasing_factor 821 | if kernel == 'cubic': 822 | weight = _cubic_contribution_torch(buffer_pos) 823 | elif kernel == 'gaussian': 824 | weight = _gaussian_contribution_torch(buffer_pos, sigma=sigma) 825 | else: 826 | raise ValueError('{} kernel is not supported!'.format(kernel)) 827 | 828 | weight /= weight.sum(dim=0, keepdim=True) 829 | return weight 830 | 831 | 832 | def _reshape_tensor_torch(tensor: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: 833 | # Resize height 834 | if dim == 2 or dim == -2: 835 | k = (kernel_size, 1) 836 | h_out = tensor.size(-2) - kernel_size + 1 837 | w_out = tensor.size(-1) 838 | # Resize width 839 | else: 840 | k = (1, kernel_size) 841 | h_out = tensor.size(-2) 842 | w_out = tensor.size(-1) - kernel_size + 1 843 | 844 | unfold = F.unfold(tensor, k) 845 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out) 846 | return unfold 847 | 848 | 849 | def _resize_1d_torch(tensor: torch.Tensor, 850 | dim: int, 851 | size: int, 852 | scale: float, 853 | kernel: str = 'cubic', 854 | sigma: float = 2.0, 855 | padding_type: str = 'reflect', 856 | antialiasing: bool = True) -> torch.Tensor: 857 | """ 858 | Args: 859 | tensor (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). 860 | dim (int): 861 | scale (float): 862 | size (int): 863 | Return: 864 | """ 865 | # Identity case 866 | if scale == 1: 867 | return tensor 868 | 869 | # Default bicubic kernel with antialiasing (only when downsampling) 870 | if kernel == 'cubic': 871 | kernel_size = 4 872 | else: 873 | kernel_size = math.floor(6 * sigma) 874 | 875 | if antialiasing and (scale < 1): 876 | antialiasing_factor = scale 877 | kernel_size = math.ceil(kernel_size / antialiasing_factor) 878 | else: 879 | antialiasing_factor = 1 880 | 881 | # We allow margin to both sizes 882 | kernel_size += 2 883 | 884 | # Weights only depend on the shape of input and output, 885 | # so we do not calculate gradients here. 
886 | with torch.no_grad(): 887 | pos = torch.linspace( 888 | 0, 889 | size - 1, 890 | steps=size, 891 | dtype=tensor.dtype, 892 | device=tensor.device, 893 | ) 894 | pos = (pos + 0.5) / scale - 0.5 895 | base = pos.floor() - (kernel_size // 2) + 1 896 | dist = pos - base 897 | weight = _get_weight_torch( 898 | dist, 899 | kernel_size, 900 | kernel=kernel, 901 | sigma=sigma, 902 | antialiasing_factor=antialiasing_factor, 903 | ) 904 | pad_pre, pad_post, base = _get_padding_torch(base, kernel_size, tensor.size(dim)) 905 | 906 | # To back-propagate through x 907 | x_pad = _padding_torch(tensor, dim, pad_pre, pad_post, padding_type=padding_type) 908 | unfold = _reshape_tensor_torch(x_pad, dim, kernel_size) 909 | # Subsampling first 910 | if dim == 2 or dim == -2: 911 | sample = unfold[..., base, :] 912 | weight = weight.view(1, kernel_size, sample.size(2), 1) 913 | else: 914 | sample = unfold[..., base] 915 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 916 | 917 | # Apply the kernel 918 | tensor = sample * weight 919 | tensor = tensor.sum(dim=1, keepdim=True) 920 | return tensor 921 | 922 | 923 | def _downsampling_2d_torch(tensor: torch.Tensor, k: torch.Tensor, scale: int, 924 | padding_type: str = 'reflect') -> torch.Tensor: 925 | c = tensor.size(1) 926 | k_h = k.size(-2) 927 | k_w = k.size(-1) 928 | 929 | k = k.to(dtype=tensor.dtype, device=tensor.device) 930 | k = k.view(1, 1, k_h, k_w) 931 | k = k.repeat(c, c, 1, 1) 932 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 933 | e = e.view(c, c, 1, 1) 934 | k = k * e 935 | 936 | pad_h = (k_h - scale) // 2 937 | pad_w = (k_w - scale) // 2 938 | tensor = _padding_torch(tensor, -2, pad_h, pad_h, padding_type=padding_type) 939 | tensor = _padding_torch(tensor, -1, pad_w, pad_w, padding_type=padding_type) 940 | y = F.conv2d(tensor, k, padding=0, stride=scale) 941 | return y 942 | 943 | 944 | def _cov_torch(tensor, rowvar=True, bias=False): 945 | r"""Estimate a covariance matrix (np.cov) 946 | Ref: https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110 947 | """ 948 | tensor = tensor if rowvar else tensor.transpose(-1, -2) 949 | tensor = tensor - tensor.mean(dim=-1, keepdim=True) 950 | factor = 1 / (tensor.shape[-1] - int(not bool(bias))) 951 | return factor * tensor @ tensor.transpose(-1, -2) 952 | 953 | 954 | def _nancov_torch(x): 955 | r"""Calculate nancov for batched tensor, rows that contains nan value 956 | will be removed. 957 | Args: 958 | x (tensor): (B, row_num, feat_dim) 959 | Return: 960 | cov (tensor): (B, feat_dim, feat_dim) 961 | """ 962 | assert len(x.shape) == 3, f'Shape of input should be (batch_size, row_num, feat_dim), but got {x.shape}' 963 | b, rownum, feat_dim = x.shape 964 | nan_mask = torch.isnan(x).any(dim=2, keepdim=True) 965 | x_no_nan = x.masked_select(~nan_mask).reshape(b, -1, feat_dim) 966 | cov_x = _cov_torch(x_no_nan, rowvar=False) 967 | return cov_x 968 | 969 | 970 | def _nanmean_torch(v, *args, inplace=False, **kwargs): 971 | r"""nanmean same as matlab function: calculate mean values by removing all nan. 972 | """ 973 | if not inplace: 974 | v = v.clone() 975 | is_nan = torch.isnan(v) 976 | v[is_nan] = 0 977 | return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs) 978 | 979 | 980 | def _symm_pad_torch(im: torch.Tensor, padding: [int, int, int, int]): 981 | """Symmetric padding same as tensorflow. 
982 | Ref: https://discuss.pytorch.org/t/symmetric-padding/19866/3 983 | """ 984 | h, w = im.shape[-2:] 985 | left, right, top, bottom = padding 986 | 987 | x_idx = np.arange(-left, w + right) 988 | y_idx = np.arange(-top, h + bottom) 989 | 990 | def reflect(x, minx, maxx): 991 | """ Reflects an array around two points making a triangular waveform that ramps up 992 | and down, allowing for pad lengths greater than the input length """ 993 | rng = maxx - minx 994 | double_rng = 2 * rng 995 | mod = np.fmod(x - minx, double_rng) 996 | normed_mod = np.where(mod < 0, mod + double_rng, mod) 997 | out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx 998 | return np.array(out, dtype=x.dtype) 999 | 1000 | x_pad = reflect(x_idx, -0.5, w - 0.5) 1001 | y_pad = reflect(y_idx, -0.5, h - 0.5) 1002 | xx, yy = np.meshgrid(x_pad, y_pad) 1003 | return im[..., yy, xx] 1004 | 1005 | 1006 | def _blockproc_torch(x, kernel: int or tuple or list, fun, border_size=None, pad_partial=False, pad_method='zero'): 1007 | r"""blockproc function like matlab 1008 | 1009 | Difference: 1010 | - Partial blocks is discarded (if exist) for fast GPU process. 1011 | 1012 | Args: 1013 | x (tensor): shape (b, c, h, w) 1014 | kernel (int or tuple): block size 1015 | func: function to process each block 1016 | border_size (int or tuple): border pixels to each block 1017 | pad_partial: pad partial blocks to make them full-sized, default False 1018 | pad_method: [zero, replicate, symmetric] how to pad partial block when pad_partial is set True 1019 | 1020 | Return: 1021 | results (tensor): concatenated results of each block 1022 | 1023 | """ 1024 | assert len(x.shape) == 4, f'Shape of input has to be (b, c, h, w) but got {x.shape}' 1025 | kernel = _to_tuple(2)(kernel) 1026 | if pad_partial: 1027 | b, c, h, w = x.shape 1028 | stride = kernel 1029 | h2 = math.ceil(h / stride[0]) 1030 | w2 = math.ceil(w / stride[1]) 1031 | pad_row = (h2 - 1) * stride[0] + kernel[0] - h 1032 | pad_col = (w2 - 1) * stride[1] + kernel[1] - w 1033 | padding = (0, pad_col, 0, pad_row) 1034 | if pad_method == 'zero': 1035 | x = F.pad(x, padding, mode='constant') 1036 | elif pad_method == 'symmetric': 1037 | x = _symm_pad_torch(x, padding) 1038 | else: 1039 | x = F.pad(x, padding, mode=pad_method) 1040 | 1041 | if border_size is not None: 1042 | raise NotImplementedError('Blockproc with border is not implemented yet') 1043 | else: 1044 | b, c, h, w = x.shape 1045 | block_size_h, block_size_w = kernel 1046 | num_block_h = math.floor(h / block_size_h) 1047 | num_block_w = math.floor(w / block_size_w) 1048 | 1049 | # extract blocks in (row, column) manner, i.e., stored with column first 1050 | blocks = F.unfold(x, kernel, stride=kernel) 1051 | blocks = blocks.reshape(b, c, *kernel, num_block_h, num_block_w) 1052 | blocks = blocks.permute(5, 4, 0, 1, 2, 3).reshape(num_block_h * num_block_w * b, c, *kernel) 1053 | 1054 | results = fun(blocks) 1055 | results = results.reshape(num_block_h * num_block_w, b, *results.shape[1:]).transpose(0, 1) 1056 | return results 1057 | 1058 | 1059 | def _image_resize_torch(tensor: torch.Tensor, 1060 | scale_factor: typing.Optional[float] = None, 1061 | sizes: typing.Optional[typing.Tuple[int, int]] = None, 1062 | kernel: typing.Union[str, torch.Tensor] = "cubic", 1063 | sigma: float = 2, 1064 | padding_type: str = "reflect", 1065 | antialiasing: bool = True) -> torch.Tensor: 1066 | """ 1067 | Args: 1068 | tensor (torch.Tensor): 1069 | scale_factor (float): 1070 | sizes (tuple(int, int)): 1071 | kernel (str, 
default='cubic'): 1072 | sigma (float, default=2): 1073 | padding_type (str, default='reflect'): 1074 | antialiasing (bool, default=True): 1075 | Return: 1076 | torch.Tensor: 1077 | """ 1078 | scales = (scale_factor, scale_factor) 1079 | 1080 | if scale_factor is None and sizes is None: 1081 | raise ValueError('One of scale or sizes must be specified!') 1082 | if scale_factor is not None and sizes is not None: 1083 | raise ValueError('Please specify scale or sizes to avoid conflict!') 1084 | 1085 | tensor, b, c, h, w = _reshape_input_torch(tensor) 1086 | 1087 | if sizes is None and scale_factor is not None: 1088 | ''' 1089 | # Check if we can apply the convolution algorithm 1090 | scale_inv = 1 / scale 1091 | if isinstance(kernel, str) and scale_inv.is_integer(): 1092 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 1093 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 1094 | raise ValueError( 1095 | 'An integer downsampling factor ' 1096 | 'should be used with a predefined kernel!' 1097 | ) 1098 | ''' 1099 | # Determine output size 1100 | sizes = (math.ceil(h * scale_factor), math.ceil(w * scale_factor)) 1101 | scales = (scale_factor, scale_factor) 1102 | 1103 | if scale_factor is None and sizes is not None: 1104 | scales = (sizes[0] / h, sizes[1] / w) 1105 | 1106 | tensor, dtype = _cast_input_torch(tensor) 1107 | 1108 | if isinstance(kernel, str) and sizes is not None: 1109 | # Core resizing module 1110 | tensor = _resize_1d_torch( 1111 | tensor, 1112 | -2, 1113 | size=sizes[0], 1114 | scale=scales[0], 1115 | kernel=kernel, 1116 | sigma=sigma, 1117 | padding_type=padding_type, 1118 | antialiasing=antialiasing) 1119 | tensor = _resize_1d_torch( 1120 | tensor, 1121 | -1, 1122 | size=sizes[1], 1123 | scale=scales[1], 1124 | kernel=kernel, 1125 | sigma=sigma, 1126 | padding_type=padding_type, 1127 | antialiasing=antialiasing) 1128 | elif isinstance(kernel, torch.Tensor) and scale_factor is not None: 1129 | tensor = _downsampling_2d_torch(tensor, kernel, scale=int(1 / scale_factor)) 1130 | 1131 | tensor = _reshape_output_torch(tensor, b, c) 1132 | tensor = _cast_output_torch(tensor, dtype) 1133 | return tensor 1134 | 1135 | 1136 | def _estimate_aggd_parameters_torch(tensor: torch.Tensor, 1137 | get_sigma: bool) -> [torch.Tensor, torch.Tensor, torch.Tensor]: 1138 | """PyTorch implements the BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) function 1139 | This function is used to estimate an asymmetric generalized Gaussian distribution 1140 | 1141 | Reference papers: 1142 | `No-Reference Image Quality Assessment in the Spatial Domain` 1143 | `Referenceless Image Spatial Quality Evaluation Engine` 1144 | 1145 | Args: 1146 | tensor (torch.Tensor): data vector 1147 | get_sigma (bool): whether to return the covariance mean 1148 | 1149 | Returns: 1150 | aggd_parameters (torch.Tensor): asymmetric generalized Gaussian distribution 1151 | left_std (torch.Tensor): symmetric left data vector variance mean 1152 | right_std (torch.Tensor): Symmetric right side data vector variance mean 1153 | 1154 | """ 1155 | # The following is obtained according to the formula and the method provided in the paper on WIki encyclopedia 1156 | aggd = torch.arange(0.2, 10 + 0.001, 0.001).to(tensor) 1157 | r_gam = (2 * torch.lgamma(2. / aggd) - (torch.lgamma(1. / aggd) + torch.lgamma(3. 
/ aggd))).exp() 1158 | r_gam = r_gam.repeat(tensor.size(0), 1) 1159 | 1160 | mask_left = tensor < 0 1161 | mask_right = tensor > 0 1162 | count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32) 1163 | count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32) 1164 | 1165 | left_std = torch.sqrt_((tensor * mask_left).pow(2).sum(dim=(-1, -2)) / (count_left + 1e-8)) 1166 | right_std = torch.sqrt_((tensor * mask_right).pow(2).sum(dim=(-1, -2)) / (count_right + 1e-8)) 1167 | gamma_hat = left_std / right_std 1168 | rhat = tensor.abs().mean(dim=(-1, -2)).pow(2) / tensor.pow(2).mean(dim=(-1, -2)) 1169 | rhat_norm = (rhat * (gamma_hat.pow(3) + 1) * (gamma_hat + 1)) / (gamma_hat.pow(2) + 1).pow(2) 1170 | 1171 | array_position = (r_gam - rhat_norm).abs().argmin(dim=-1) 1172 | aggd_parameters = aggd[array_position] 1173 | 1174 | if get_sigma: 1175 | left_beta = left_std.squeeze(-1) * ( 1176 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt() 1177 | right_beta = right_std.squeeze(-1) * ( 1178 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt() 1179 | return aggd_parameters, left_beta, right_beta 1180 | 1181 | else: 1182 | left_std = left_std.squeeze_(-1) 1183 | right_std = right_std.squeeze_(-1) 1184 | return aggd_parameters, left_std, right_std 1185 | 1186 | 1187 | def _get_mscn_feature_torch(tensor: torch.Tensor) -> Tensor: 1188 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function, 1189 | This function is used to calculate the feature map 1190 | 1191 | Reference papers: 1192 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 1193 | 1194 | Args: 1195 | tensor (torch.Tensor): The image to be evaluated for NIQE sharpness 1196 | 1197 | Returns: 1198 | feature (torch.Tensor): image feature map 1199 | 1200 | """ 1201 | batch_size = tensor.shape[0] 1202 | aggd_block = tensor[:, [0]] 1203 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block, True) 1204 | feature = [aggd_parameters, (left_beta + right_beta) / 2] 1205 | 1206 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] 1207 | for i in range(len(shifts)): 1208 | shifted_block = torch.roll(aggd_block, shifts[i], dims=(2, 3)) 1209 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block * shifted_block, True) 1210 | mean = (right_beta - left_beta) * (torch.lgamma(2 / aggd_parameters) - torch.lgamma(1 / aggd_parameters)).exp() 1211 | feature.extend((aggd_parameters, mean, left_beta, right_beta)) 1212 | 1213 | feature = [x.reshape(batch_size, 1) for x in feature] 1214 | feature = torch.cat(feature, dim=-1) 1215 | 1216 | return feature 1217 | 1218 | 1219 | def _fit_mscn_ipac_torch(tensor: torch.Tensor, 1220 | mu_pris_param: torch.Tensor, 1221 | cov_pris_param: torch.Tensor, 1222 | block_size_height: int, 1223 | block_size_width: int, 1224 | kernel_size: int = 7, 1225 | kernel_sigma: float = 7. 
/ 6,
1226 |                          padding: str = "replicate") -> Tensor:
1227 |     """PyTorch implements part of the NIQE (Natural Image Quality Evaluator) metric.
1228 |     This function fits the inner products of adjacent MSCN coefficients.
1229 | 
1230 |     Reference papers:
1231 |         `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video`
1232 | 
1233 |     Args:
1234 |         tensor (torch.Tensor): The image whose NIQE score is to be computed
1235 |         mu_pris_param (torch.Tensor): mean of the predefined multivariate Gaussian model computed on the pristine dataset
1236 |         cov_pris_param (torch.Tensor): covariance of the predefined multivariate Gaussian model computed on the pristine dataset
1237 |         block_size_height (int): the height of the blocks into which the image is divided
1238 |         block_size_width (int): the width of the blocks into which the image is divided
1239 |         kernel_size (int): Gaussian filter size
1240 |         kernel_sigma (float): sigma value in the Gaussian filter
1241 |         padding (str): how to pad pixels. Default: ``replicate``
1242 | 
1243 |     Returns:
1244 |         niqe_metric (torch.Tensor): NIQE score
1245 | 
1246 |     """
1247 |     # crop image
1248 |     b, c, h, w = tensor.shape
1249 |     num_block_height = math.floor(h / block_size_height)
1250 |     num_block_width = math.floor(w / block_size_width)
1251 |     tensor = tensor[..., 0:num_block_height * block_size_height, 0:num_block_width * block_size_width]
1252 | 
1253 |     distparam = []
1254 |     for scale in (1, 2):
1255 |         kernel = _fspecial_gaussian_torch(kernel_size, kernel_sigma, 1).to(tensor)
1256 |         mu = _image_filter(tensor, kernel, padding=padding)
1257 |         std = _image_filter(tensor ** 2, kernel, padding=padding)
1258 |         sigma = torch.sqrt_((std - mu ** 2).abs() + 1e-8)
1259 |         structdis = (tensor - mu) / (sigma + 1)
1260 | 
1261 |         distparam.append(_blockproc_torch(structdis,
1262 |                                           [block_size_height // scale, block_size_width // scale],
1263 |                                           fun=_get_mscn_feature_torch))
1264 | 
1265 |         if scale == 1:
1266 |             tensor = _image_resize_torch(tensor / 255., scale_factor=0.5, antialiasing=True)
1267 |             tensor = tensor * 255.
1268 | 
1269 |     distparam = torch.cat(distparam, -1)
1270 | 
1271 |     # Fit MVG (Multivariate Gaussian) model to distorted patch features
1272 |     mu_distparam = _nanmean_torch(distparam, dim=1)
1273 |     cov_distparam = _nancov_torch(distparam)
1274 | 
1275 |     invcov_param = torch.linalg.pinv((cov_pris_param + cov_distparam) / 2)
1276 |     diff = (mu_pris_param - mu_distparam).unsqueeze(1)
1277 |     niqe_metric = torch.bmm(torch.bmm(diff, invcov_param), diff.transpose(1, 2)).squeeze()
1278 |     niqe_metric = torch.sqrt(niqe_metric)
1279 | 
1280 |     return niqe_metric
1281 | 
1282 | 
1283 | def _niqe_torch(tensor: torch.Tensor,
1284 |                 crop_border: int,
1285 |                 niqe_model_path: str,
1286 |                 block_size_height: int = 96,
1287 |                 block_size_width: int = 96
1288 |                 ) -> Tensor:
1289 |     """PyTorch implements the NIQE (Natural Image Quality Evaluator) metric.
1290 | 
1291 |     Args:
1292 |         tensor (torch.Tensor): The image whose NIQE score is to be computed
1293 |         crop_border (int): number of border pixels to crop
1294 |         niqe_model_path (str): path to the pretrained NIQE model weights
1295 |         block_size_height (int): The height of the blocks the image is divided into. Default: 96
1296 |         block_size_width (int): The width of the blocks the image is divided into. Default: 96
1297 | 
1298 |     Returns:
1299 |         niqe_metrics (torch.Tensor): NIQE metrics
1300 | 
1301 |     """
1302 |     # crop border pixels
1303 |     if crop_border > 0:
1304 |         tensor = tensor[:, :, crop_border:-crop_border, crop_border:-crop_border]
1305 | 
1306 |     # Load the NIQE feature extraction model
1307 |     niqe_model = loadmat(niqe_model_path)
1308 | 
1309 |     mu_pris_param = np.ravel(niqe_model["mu_prisparam"])
1310 |     cov_pris_param = niqe_model["cov_prisparam"]
1311 |     mu_pris_param = torch.from_numpy(mu_pris_param).to(tensor)
1312 |     cov_pris_param = torch.from_numpy(cov_pris_param).to(tensor)
1313 | 
1314 |     mu_pris_param = mu_pris_param.repeat(tensor.size(0), 1)
1315 |     cov_pris_param = cov_pris_param.repeat(tensor.size(0), 1, 1)
1316 | 
1317 |     # NIQE is computed on the Y channel only, so convert the image first
1318 |     y_tensor = rgb_to_ycbcr_torch(tensor, only_use_y_channel=True)
1319 |     y_tensor *= 255.0
1320 |     y_tensor = y_tensor.round()
1321 | 
1322 |     # Convert data type to torch.float64
1323 |     y_tensor = y_tensor.to(torch.float64)
1324 | 
1325 |     niqe_metric = _fit_mscn_ipac_torch(y_tensor,
1326 |                                        mu_pris_param,
1327 |                                        cov_pris_param,
1328 |                                        block_size_height,
1329 |                                        block_size_width)
1330 | 
1331 |     return niqe_metric
1332 | 
1333 | 
1334 | class NIQE(nn.Module):
1335 |     """PyTorch implements the NIQE (Natural Image Quality Evaluator) metric.
1336 | 
1337 |     Attributes:
1338 |         crop_border (int): number of border pixels to crop
1339 |         niqe_model_path (str): path to the pretrained NIQE model weights
1340 |         block_size_height (int): The height of the blocks the image is divided into. Default: 96
1341 |         block_size_width (int): The width of the blocks the image is divided into. Default: 96
1342 | 
1343 |     Returns:
1344 |         niqe_metrics (torch.Tensor): NIQE metrics
1345 | 
1346 |     """
1347 | 
1348 |     def __init__(self, crop_border: int,
1349 |                  niqe_model_path: str,
1350 |                  block_size_height: int = 96,
1351 |                  block_size_width: int = 96) -> None:
1352 |         super().__init__()
1353 |         self.crop_border = crop_border
1354 |         self.niqe_model_path = niqe_model_path
1355 |         self.block_size_height = block_size_height
1356 |         self.block_size_width = block_size_width
1357 | 
1358 |     def forward(self, raw_tensor: torch.Tensor) -> Tensor:
1359 |         niqe_metrics = _niqe_torch(raw_tensor,
1360 |                                    self.crop_border,
1361 |                                    self.niqe_model_path,
1362 |                                    self.block_size_height,
1363 |                                    self.block_size_width)
1364 | 
1365 |         return niqe_metrics
1366 | 
--------------------------------------------------------------------------------
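The snippet below is not one of the repository files; it is a minimal usage sketch of the `NIQE` module defined in `image_quality_assessment.py` above. The weight path, the `crop_border` value, and the random input are illustrative assumptions only; in practice you would pass a real super-resolved image tensor and point `niqe_model_path` at the pretrained `.mat` parameter file you are using.

```python
import torch

from image_quality_assessment import NIQE

# Hypothetical location of the pretrained NIQE parameters (.mat file); adjust to your setup.
NIQE_MODEL_PATH = "./results/pretrained_models/niqe_model.mat"

# NIQE is a no-reference metric, so only the super-resolved image is needed.
niqe_model = NIQE(crop_border=4, niqe_model_path=NIQE_MODEL_PATH)

# The module expects a 4D RGB float tensor (B, C, H, W) with values in [0, 1];
# internally it keeps only the Y channel and rescales it to [0, 255].
sr_tensor = torch.rand(1, 3, 256, 256)

with torch.no_grad():
    niqe_score = niqe_model(sr_tensor)

print(f"NIQE: {niqe_score.item():.4f}")  # lower is better
```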