├── results └── .gitkeep ├── samples └── .gitkeep ├── figure ├── comic.png └── sr_comic.png ├── requirements.txt ├── scripts ├── run.py ├── split_train_valid_dataset.py └── prepare_dataset.py ├── data └── README.md ├── .gitignore ├── config.py ├── model.py ├── setup.py ├── test.py ├── inference.py ├── utils.py ├── dataset.py ├── README.md ├── LICENSE ├── train.py ├── imgproc.py └── image_quality_assessment.py /results/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /samples/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figure/comic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lornatang/ESPCN-PyTorch/HEAD/figure/comic.png -------------------------------------------------------------------------------- /figure/sr_comic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lornatang/ESPCN-PyTorch/HEAD/figure/sr_comic.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | tqdm 4 | torch 5 | setuptools 6 | torchvision 7 | natsort -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Prepare dataset 4 | os.system("python3 ./prepare_dataset.py --images_dir ../data/T91/original --output_dir ../data/T91/ESPCN/train --image_size 70 --step 35 --num_workers 16") 5 | 6 | # Split train and valid 7 | os.system("python3 ./split_train_valid_dataset.py --train_images_dir ../data/T91/ESPCN/train --valid_images_dir ../data/T91/ESPCN/valid --valid_samples_ratio 0.1") 8 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Step1: Download datasets 4 | 5 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc. 6 | 7 | - [Google Driver](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing) 8 | - [Baidu Driver](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot) 9 | 10 | ## Step2: Prepare the dataset in the following format 11 | 12 | ```text 13 | # Train dataset struct 14 | - T91 15 | - original 16 | - t1.png 17 | - t2.png 18 | ... 19 | 20 | # Test dataset struct 21 | - Set5 22 | - GTmod12 23 | - baby.png 24 | - bird.png 25 | - ... 26 | - LRbicx4 27 | - baby.png 28 | - bird.png 29 | - ... 30 | ``` 31 | 32 | ## Step3: Preprocess the train dataset 33 | 34 | ```bash 35 | cd /scripts 36 | python3 run.py 37 | ``` 38 | 39 | ## Step4: Check that the final dataset directory schema is completely correct 40 | 41 | ```text 42 | # Train dataset 43 | - T91 44 | - original 45 | - ESPCN 46 | - train 47 | - valid 48 | 49 | # Test dataset 50 | - Set5 51 | - GTmod12 52 | - baby.png 53 | - bird.png 54 | - ... 55 | - LRbicx4 56 | - baby.png 57 | - bird.png 58 | - ... 
59 | 60 | ``` 61 | -------------------------------------------------------------------------------- /scripts/split_train_valid_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import argparse 15 | import os 16 | import random 17 | import shutil 18 | 19 | from tqdm import tqdm 20 | 21 | 22 | def main(args) -> None: 23 | if not os.path.exists(args.train_images_dir): 24 | os.makedirs(args.train_images_dir) 25 | if not os.path.exists(args.valid_images_dir): 26 | os.makedirs(args.valid_images_dir) 27 | 28 | train_files = os.listdir(args.train_images_dir) 29 | valid_files = random.sample(train_files, int(len(train_files) * args.valid_samples_ratio)) 30 | 31 | process_bar = tqdm(valid_files, total=len(valid_files), unit="image", desc="Split train/valid dataset") 32 | 33 | for image_file_name in process_bar: 34 | shutil.copyfile(f"{args.train_images_dir}/{image_file_name}", f"{args.valid_images_dir}/{image_file_name}") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser(description="Split train and valid dataset scripts.") 39 | parser.add_argument("--train_images_dir", type=str, help="Path to train image directory.") 40 | parser.add_argument("--valid_images_dir", type=str, help="Path to valid image directory.") 41 | parser.add_argument("--valid_samples_ratio", type=float, help="What percentage of the data is extracted from the training set into the validation set.") 42 | args = parser.parse_args() 43 | 44 | main(args) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # custom 126 | .idea 127 | .vscode 128 | 129 | # Mac configure file. 130 | .DS_Store 131 | 132 | # Program run create directory. 133 | data 134 | results 135 | samples 136 | 137 | # Program run create file. 138 | *.mdb 139 | *.bmp 140 | *.png 141 | *.mp4 142 | *.zip 143 | *.csv 144 | *.pth 145 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import random 15 | 16 | import numpy as np 17 | import torch 18 | from torch.backends import cudnn 19 | 20 | # Random seed to maintain reproducible results 21 | random.seed(0) 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | # Use GPU for training by default 25 | device = torch.device("cuda", 0) 26 | # Turning on when the image size does not change during training can speed up training 27 | cudnn.benchmark = True 28 | # When evaluating the performance of the SR model, whether to verify only the Y channel image data 29 | only_test_y_channel = False 30 | # Model architecture name 31 | model_arch_name = "espcn_x4" 32 | # Model arch config 33 | in_channels = 1 34 | out_channels = 1 35 | channels = 64 36 | upscale_factor = 4 37 | # Current configuration parameter method 38 | mode = "train" 39 | # Experiment name, easy to save weights and log files 40 | exp_name = "ESPCN_x4-T91" 41 | 42 | if mode == "train": 43 | # Dataset address 44 | train_gt_images_dir = f"./data/T91/ESPCN/train" 45 | 46 | test_gt_images_dir = f"./data/Set5/GTmod12" 47 | test_lr_images_dir = f"./data/Set5/LRbicx{upscale_factor}" 48 | 49 | gt_image_size = int(17 * upscale_factor) 50 | batch_size = 16 51 | num_workers = 4 52 | 53 | # The address to load the pretrained model 54 | pretrained_model_weights_path = f"" 55 | 56 | # Incremental training and migration training 57 | resume_model_weights_path = f"" 58 | 59 | # Total num epochs 60 | epochs = 3000 61 | 62 | # loss function weights 63 | loss_weights = 1.0 64 | 65 | # Optimizer parameter 66 | model_lr = 1e-2 67 | model_momentum = 0.9 68 | model_weight_decay = 1e-4 69 | model_nesterov = False 70 | 71 | # EMA parameter 72 | model_ema_decay = 0.999 73 | 74 | # Dynamically adjust the learning rate policy 75 | lr_scheduler_milestones = [int(epochs * 0.1), int(epochs * 0.8)] 76 | lr_scheduler_gamma = 0.1 77 | 78 | # How many iterations to print the training result 79 | train_print_frequency = 100 80 | test_print_frequency = 1 81 | 82 | if mode == "test": 83 | # Test data address 84 | lr_dir = f"./data/Set5/LRbicx{upscale_factor}" 85 | sr_dir = f"./results/test/{exp_name}" 86 | gt_dir = "./data/Set5/GTmod12" 87 | 88 | model_weights_path = "./results/pretrained_models/ESPCN_x4-T91-64bf5ee4.pth.tar" 89 | -------------------------------------------------------------------------------- /scripts/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ==============================================================================
14 | import argparse
15 | import multiprocessing
16 | import os
17 | import shutil
18 | 
19 | import cv2
20 | import numpy as np
21 | from tqdm import tqdm
22 | 
23 | 
24 | def main(args) -> None:
25 |     if os.path.exists(args.output_dir):
26 |         shutil.rmtree(args.output_dir)
27 |     os.makedirs(args.output_dir)
28 | 
29 |     # Get all image paths
30 |     image_file_names = os.listdir(args.images_dir)
31 | 
32 |     # Split images using multiple processes
33 |     progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Prepare split image")
34 |     workers_pool = multiprocessing.Pool(args.num_workers)
35 |     for image_file_name in image_file_names:
36 |         workers_pool.apply_async(worker, args=(image_file_name, args), callback=lambda arg: progress_bar.update(1))
37 |     workers_pool.close()
38 |     workers_pool.join()
39 |     progress_bar.close()
40 | 
41 | 
42 | def worker(image_file_name, args) -> None:
43 |     image = cv2.imread(f"{args.images_dir}/{image_file_name}", cv2.IMREAD_UNCHANGED)
44 | 
45 |     image_height, image_width = image.shape[0:2]
46 | 
47 |     index = 1
48 |     if image_height >= args.image_size and image_width >= args.image_size:
49 |         for pos_y in range(0, image_height - args.image_size + 1, args.step):
50 |             for pos_x in range(0, image_width - args.image_size + 1, args.step):
51 |                 # Crop
52 |                 crop_image = image[pos_y: pos_y + args.image_size, pos_x:pos_x + args.image_size, ...]
53 |                 crop_image = np.ascontiguousarray(crop_image)
54 |                 # Save image
55 |                 cv2.imwrite(f"{args.output_dir}/{image_file_name.split('.')[-2]}_{index:04d}.{image_file_name.split('.')[-1]}", crop_image)
56 | 
57 |                 index += 1
58 | 
59 | 
60 | if __name__ == "__main__":
61 |     parser = argparse.ArgumentParser(description="Prepare dataset script.")
62 |     parser.add_argument("--images_dir", type=str, help="Path to input image directory.")
63 |     parser.add_argument("--output_dir", type=str, help="Path to output image directory.")
64 |     parser.add_argument("--image_size", type=int, help="Crop size taken from the raw (ground-truth) image.")
65 |     parser.add_argument("--step", type=int, help="Sliding-window step used when cropping.")
66 |     parser.add_argument("--num_workers", type=int, help="How many worker processes to use.")
67 |     args = parser.parse_args()
68 | 
69 |     main(args)
70 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ============================================================================== 14 | import math 15 | 16 | import torch 17 | from torch import nn, Tensor 18 | 19 | __all__ = [ 20 | "ESPCN", 21 | "espcn_x2", "espcn_x3", "espcn_x4", "espcn_x8", 22 | ] 23 | 24 | 25 | class ESPCN(nn.Module): 26 | def __init__( 27 | self, 28 | in_channels: int, 29 | out_channels: int, 30 | channels: int, 31 | upscale_factor: int, 32 | ) -> None: 33 | super(ESPCN, self).__init__() 34 | hidden_channels = channels // 2 35 | out_channels = int(out_channels * (upscale_factor ** 2)) 36 | 37 | # Feature mapping 38 | self.feature_maps = nn.Sequential( 39 | nn.Conv2d(in_channels, channels, (5, 5), (1, 1), (2, 2)), 40 | nn.Tanh(), 41 | nn.Conv2d(channels, hidden_channels, (3, 3), (1, 1), (1, 1)), 42 | nn.Tanh(), 43 | ) 44 | 45 | # Sub-pixel convolution layer 46 | self.sub_pixel = nn.Sequential( 47 | nn.Conv2d(hidden_channels, out_channels, (3, 3), (1, 1), (1, 1)), 48 | nn.PixelShuffle(upscale_factor), 49 | ) 50 | 51 | # Initial model weights 52 | for module in self.modules(): 53 | if isinstance(module, nn.Conv2d): 54 | if module.in_channels == 32: 55 | nn.init.normal_(module.weight.data, 56 | 0.0, 57 | 0.001) 58 | nn.init.zeros_(module.bias.data) 59 | else: 60 | nn.init.normal_(module.weight.data, 61 | 0.0, 62 | math.sqrt(2 / (module.out_channels * module.weight.data[0][0].numel()))) 63 | nn.init.zeros_(module.bias.data) 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | return self._forward_impl(x) 67 | 68 | # Support torch.script function. 69 | def _forward_impl(self, x: Tensor) -> Tensor: 70 | x = self.feature_maps(x) 71 | x = self.sub_pixel(x) 72 | 73 | x = torch.clamp_(x, 0.0, 1.0) 74 | 75 | return x 76 | 77 | 78 | def espcn_x2(**kwargs) -> ESPCN: 79 | model = ESPCN(upscale_factor=2, **kwargs) 80 | 81 | return model 82 | 83 | 84 | def espcn_x3(**kwargs) -> ESPCN: 85 | model = ESPCN(upscale_factor=3, **kwargs) 86 | 87 | return model 88 | 89 | 90 | def espcn_x4(**kwargs) -> ESPCN: 91 | model = ESPCN(upscale_factor=4, **kwargs) 92 | 93 | return model 94 | 95 | 96 | def espcn_x8(**kwargs) -> ESPCN: 97 | model = ESPCN(upscale_factor=8, **kwargs) 98 | 99 | return model -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import io 15 | import os 16 | import sys 17 | from shutil import rmtree 18 | 19 | from setuptools import Command 20 | from setuptools import find_packages 21 | from setuptools import setup 22 | 23 | # Configure library params. 24 | NAME = "espcn_pytorch" 25 | DESCRIPTION = "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network." 
26 | URL = "https://github.com/Lornatang/ESPCN-PyTorch" 27 | EMAIL = "liu_changyu@dakewe.com" 28 | AUTHOR = "Liu Goodfellow" 29 | REQUIRES_PYTHON = ">=3.8.0" 30 | VERSION = "1.0.0" 31 | 32 | # Libraries that must be installed. 33 | REQUIRED = ["torch"] 34 | 35 | # The following libraries directory need to be installed if you need to run all scripts. 36 | EXTRAS = {} 37 | 38 | # Find the current running location. 39 | here = os.path.abspath(os.path.dirname(__file__)) 40 | 41 | # About README file description. 42 | try: 43 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 44 | long_description = "\n" + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Set Current Library Version. 49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 53 | exec(f.read(), about) 54 | else: 55 | about["__version__"] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | description = "Build and publish the package." 60 | user_options = [] 61 | 62 | @staticmethod 63 | def status(s): 64 | print("\033[1m{0}\033[0m".format(s)) 65 | 66 | def initialize_options(self): 67 | pass 68 | 69 | def finalize_options(self): 70 | pass 71 | 72 | def run(self): 73 | try: 74 | self.status("Removing previous builds…") 75 | rmtree(os.path.join(here, "dist")) 76 | except OSError: 77 | pass 78 | 79 | self.status("Building Source and Wheel (universal) distribution…") 80 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 81 | 82 | self.status("Uploading the package to PyPI via Twine…") 83 | os.system("twine upload dist/*") 84 | 85 | self.status("Pushing git tags…") 86 | os.system("git tag v{0}".format(about["__version__"])) 87 | os.system("git push --tags") 88 | 89 | sys.exit() 90 | 91 | 92 | setup(name=NAME, 93 | version=about["__version__"], 94 | description=DESCRIPTION, 95 | long_description=long_description, 96 | long_description_content_type="text/markdown", 97 | author=AUTHOR, 98 | author_email=EMAIL, 99 | python_requires=REQUIRES_PYTHON, 100 | url=URL, 101 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 102 | install_requires=REQUIRED, 103 | extras_require=EXTRAS, 104 | include_package_data=True, 105 | license="Apache", 106 | classifiers=[ 107 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 108 | "License :: OSI Approved :: Apache Software License", 109 | "Programming Language :: Python :: 3 :: Only" 110 | ], 111 | cmdclass={ 112 | "upload": UploadCommand, 113 | }, 114 | ) 115 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import os 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | from natsort import natsorted 20 | 21 | import imgproc 22 | import model 23 | import config 24 | from image_quality_assessment import PSNR, SSIM 25 | from utils import make_directory 26 | 27 | model_names = sorted( 28 | name for name in model.__dict__ if 29 | name.islower() and not name.startswith("__") and callable(model.__dict__[name])) 30 | 31 | 32 | def main() -> None: 33 | # Initialize the super-resolution bsrgan_model 34 | g_model = model.__dict__[config.model_arch_name](in_channels=config.in_channels, 35 | out_channels=config.out_channels, 36 | channels=config.channels) 37 | g_model = g_model.to(device=config.device) 38 | print(f"Build `{config.model_arch_name}` model successfully.") 39 | 40 | # Load the super-resolution bsrgan_model weights 41 | checkpoint = torch.load(config.model_weights_path, map_location=lambda storage, loc: storage) 42 | g_model.load_state_dict(checkpoint["state_dict"]) 43 | print(f"Load `{config.model_arch_name}` model weights " 44 | f"`{os.path.abspath(config.model_weights_path)}` successfully.") 45 | 46 | # Create a folder of super-resolution experiment results 47 | make_directory(config.sr_dir) 48 | 49 | # Start the verification mode of the bsrgan_model. 50 | g_model.eval() 51 | 52 | # Initialize the sharpness evaluation function 53 | psnr = PSNR(config.upscale_factor, config.only_test_y_channel) 54 | ssim = SSIM(config.upscale_factor, config.only_test_y_channel) 55 | 56 | # Set the sharpness evaluation function calculation device to the specified model 57 | psnr = psnr.to(device=config.device, non_blocking=True) 58 | ssim = ssim.to(device=config.device, non_blocking=True) 59 | 60 | # Initialize IQA metrics 61 | psnr_metrics = 0.0 62 | ssim_metrics = 0.0 63 | 64 | # Get a list of test image file names. 65 | file_names = natsorted(os.listdir(config.lr_dir)) 66 | # Get the number of test image files. 67 | total_files = len(file_names) 68 | 69 | for index in range(total_files): 70 | lr_image_path = os.path.join(config.lr_dir, file_names[index]) 71 | sr_image_path = os.path.join(config.sr_dir, file_names[index]) 72 | gt_image_path = os.path.join(config.gt_dir, file_names[index]) 73 | 74 | print(f"Processing `{os.path.abspath(lr_image_path)}`...") 75 | gt_y_tensor, gt_cb_image, gt_cr_image = imgproc.preprocess_one_image(gt_image_path, config.device) 76 | lr_y_tensor, lr_cb_image, lr_cr_image = imgproc.preprocess_one_image(lr_image_path, config.device) 77 | 78 | # Only reconstruct the Y channel image data. 
79 |         with torch.no_grad():
80 |             sr_y_tensor = g_model(lr_y_tensor)
81 | 
82 |         # Save image
83 |         sr_y_image = imgproc.tensor_to_image(sr_y_tensor, range_norm=False, half=True)
84 |         sr_y_image = sr_y_image.astype(np.float32) / 255.0
85 |         sr_ycbcr_image = cv2.merge([sr_y_image, gt_cb_image, gt_cr_image])
86 |         sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
87 |         cv2.imwrite(sr_image_path, sr_image * 255.0)
88 | 
89 |         # Calculate IQA metrics
90 |         psnr_metrics += psnr(sr_y_tensor, gt_y_tensor).item()
91 |         ssim_metrics += ssim(sr_y_tensor, gt_y_tensor).item()
92 | 
93 |     # Calculate the average value of the IQA metrics,
94 |     # clipped to the valid range of each metric:
95 |     # PSNR range value is 0~100
96 |     # SSIM range value is 0~1
97 |     avg_psnr = 100 if psnr_metrics / total_files > 100 else psnr_metrics / total_files
98 |     avg_ssim = 1 if ssim_metrics / total_files > 1 else ssim_metrics / total_files
99 | 
100 |     print(f"PSNR: {avg_psnr:4.2f} [dB]\n"
101 |           f"SSIM: {avg_ssim:4.4f} [u]")
102 | 
103 | 
104 | if __name__ == "__main__":
105 |     main()
106 | 
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | #     http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import argparse
15 | import os
16 | 
17 | import cv2
18 | import numpy as np
19 | import torch
20 | from torch import nn
21 | 
22 | import config
23 | import imgproc
24 | import model
25 | from utils import load_state_dict
26 | 
27 | model_names = sorted(
28 |     name for name in model.__dict__ if
29 |     name.islower() and not name.startswith("__") and callable(model.__dict__[name]))
30 | 
31 | 
32 | def choice_device(device_type: str) -> torch.device:
33 |     # Select the device on which the model runs
34 |     if device_type == "cuda":
35 |         device = torch.device("cuda", 0)
36 |     else:
37 |         device = torch.device("cpu")
38 |     return device
39 | 
40 | 
41 | def build_model(model_arch_name: str, device: torch.device) -> nn.Module:
42 |     # Initialize the super-resolution model
43 |     sr_model = model.__dict__[model_arch_name](in_channels=1,
44 |                                                out_channels=1,
45 |                                                channels=64)
46 |     sr_model = sr_model.to(device=device)
47 | 
48 |     return sr_model
49 | 
50 | 
51 | def main(args):
52 |     device = choice_device(args.device_type)
53 | 
54 |     # Initialize the model
55 |     sr_model = build_model(args.model_arch_name, device)
56 |     print(f"Build `{args.model_arch_name}` model successfully.")
57 | 
58 |     # Load model weights
59 |     sr_model = load_state_dict(sr_model, args.model_weights_path)
60 |     print(f"Load `{args.model_arch_name}` model weights `{os.path.abspath(args.model_weights_path)}` successfully.")
61 | 
62 |     # Put the model into evaluation mode
63 | sr_model.eval() 64 | 65 | lr_y_tensor, lr_cb_image, lr_cr_image = imgproc.preprocess_one_image(args.inputs_path, device) 66 | 67 | bic_cb_image = cv2.resize(lr_cb_image, 68 | (int(lr_cb_image.shape[1] * args.upscale_factor), 69 | int(lr_cb_image.shape[0] * args.upscale_factor)), 70 | interpolation=cv2.INTER_CUBIC) 71 | bic_cr_image = cv2.resize(lr_cr_image, 72 | (int(lr_cr_image.shape[1] * args.upscale_factor), 73 | int(lr_cr_image.shape[0] * args.upscale_factor)), 74 | interpolation=cv2.INTER_CUBIC) 75 | # Use the model to generate super-resolved images 76 | with torch.no_grad(): 77 | sr_y_tensor = sr_model(lr_y_tensor) 78 | 79 | # Save image 80 | sr_y_image = imgproc.tensor_to_image(sr_y_tensor, range_norm=False, half=False) 81 | sr_y_image = sr_y_image.astype(np.float32) / 255.0 82 | 83 | sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image]) 84 | sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image) 85 | cv2.imwrite(args.output_path, sr_image * 255.0) 86 | 87 | print(f"SR image save to `{args.output_path}`") 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser(description="Using the model generator super-resolution images.") 92 | parser.add_argument("--model_arch_name", 93 | type=str, 94 | default="espcn_x4") 95 | parser.add_argument("--upscale_factor", 96 | type=int, 97 | default=4) 98 | parser.add_argument("--inputs_path", 99 | type=str, 100 | default="./figure/comic.png", 101 | help="Low-resolution image path.") 102 | parser.add_argument("--output_path", 103 | type=str, 104 | default="./figure/sr_comic.png", 105 | help="Super-resolution image path.") 106 | parser.add_argument("--model_weights_path", 107 | type=str, 108 | default="./results/pretrained_models/ESPCN_x4-T91-64bf5ee4.pth.tar", 109 | help="Model weights file path.") 110 | parser.add_argument("--device_type", 111 | type=str, 112 | default="cpu", 113 | choices=["cpu", "cuda"]) 114 | args = parser.parse_args() 115 | 116 | main(args) 117 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import os 15 | import shutil 16 | from enum import Enum 17 | from typing import Any 18 | 19 | import torch 20 | from torch import nn 21 | from torch.nn import Module 22 | from torch.optim import Optimizer 23 | 24 | __all__ = [ 25 | "load_state_dict", "make_directory", "save_checkpoint", 26 | "Summary", "AverageMeter", "ProgressMeter" 27 | ] 28 | 29 | 30 | def load_state_dict( 31 | model: nn.Module, 32 | model_weights_path: str, 33 | ema_model: nn.Module = None, 34 | optimizer: torch.optim.Optimizer = None, 35 | scheduler: torch.optim.lr_scheduler = None, 36 | load_mode: str = "pretrained", 37 | ) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module: 38 | # Load model weights 39 | checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage) 40 | 41 | if load_mode == "resume": 42 | # Restore the parameters in the training node to this point 43 | start_epoch = checkpoint["epoch"] 44 | best_psnr = checkpoint["best_psnr"] 45 | best_ssim = checkpoint["best_ssim"] 46 | # Load model state dict. Extract the fitted model weights 47 | model_state_dict = model.state_dict() 48 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()} 49 | # Overwrite the model weights to the current model (base model) 50 | model_state_dict.update(state_dict) 51 | model.load_state_dict(model_state_dict) 52 | # Load the optimizer model 53 | optimizer.load_state_dict(checkpoint["optimizer"]) 54 | 55 | if scheduler is not None: 56 | # Load the scheduler model 57 | scheduler.load_state_dict(checkpoint["scheduler"]) 58 | 59 | if ema_model is not None: 60 | # Load ema model state dict. Extract the fitted model weights 61 | ema_model_state_dict = ema_model.state_dict() 62 | ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()} 63 | # Overwrite the model weights to the current model (ema model) 64 | ema_model_state_dict.update(ema_state_dict) 65 | ema_model.load_state_dict(ema_model_state_dict) 66 | 67 | return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler 68 | elif load_mode == "pretrained": 69 | # Load model state dict. Extract the fitted model weights 70 | model_state_dict = model.state_dict() 71 | state_dict = {k: v for k, v in checkpoint["state_dict"].items() if 72 | k in model_state_dict.keys() and v.size() == model_state_dict[k].size()} 73 | # Overwrite the model weights to the current model 74 | model_state_dict.update(state_dict) 75 | model.load_state_dict(model_state_dict) 76 | 77 | return model 78 | 79 | else: 80 | assert f"Unsupported `{load_mode}`, only supported `resume` and `pretrained`." 
81 | 82 | 83 | def make_directory(dir_path: str) -> None: 84 | if not os.path.exists(dir_path): 85 | os.makedirs(dir_path) 86 | 87 | 88 | def save_checkpoint( 89 | state_dict: dict, 90 | file_name: str, 91 | samples_dir: str, 92 | results_dir: str, 93 | best_file_name: str, 94 | last_file_name: str, 95 | is_best: bool = False, 96 | is_last: bool = False, 97 | ) -> None: 98 | checkpoint_path = os.path.join(samples_dir, file_name) 99 | torch.save(state_dict, checkpoint_path) 100 | 101 | if is_best: 102 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name)) 103 | if is_last: 104 | shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name)) 105 | 106 | 107 | class Summary(Enum): 108 | NONE = 0 109 | AVERAGE = 1 110 | SUM = 2 111 | COUNT = 3 112 | 113 | 114 | class AverageMeter(object): 115 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 116 | self.name = name 117 | self.fmt = fmt 118 | self.summary_type = summary_type 119 | self.reset() 120 | 121 | def reset(self): 122 | self.val = 0 123 | self.avg = 0 124 | self.sum = 0 125 | self.count = 0 126 | 127 | def update(self, val, n=1): 128 | self.val = val 129 | self.sum += val * n 130 | self.count += n 131 | self.avg = self.sum / self.count 132 | 133 | def __str__(self): 134 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 135 | return fmtstr.format(**self.__dict__) 136 | 137 | def summary(self): 138 | if self.summary_type is Summary.NONE: 139 | fmtstr = "" 140 | elif self.summary_type is Summary.AVERAGE: 141 | fmtstr = "{name} {avg:.2f}" 142 | elif self.summary_type is Summary.SUM: 143 | fmtstr = "{name} {sum:.2f}" 144 | elif self.summary_type is Summary.COUNT: 145 | fmtstr = "{name} {count:.2f}" 146 | else: 147 | raise ValueError(f"Invalid summary type {self.summary_type}") 148 | 149 | return fmtstr.format(**self.__dict__) 150 | 151 | 152 | class ProgressMeter(object): 153 | def __init__(self, num_batches, meters, prefix=""): 154 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 155 | self.meters = meters 156 | self.prefix = prefix 157 | 158 | def display(self, batch): 159 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 160 | entries += [str(meter) for meter in self.meters] 161 | print("\t".join(entries)) 162 | 163 | def display_summary(self): 164 | entries = [" *"] 165 | entries += [meter.summary() for meter in self.meters] 166 | print(" ".join(entries)) 167 | 168 | def _get_batch_fmtstr(self, num_batches): 169 | num_digits = len(str(num_batches // 1)) 170 | fmt = "{:" + str(num_digits) + "d}" 171 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 172 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import os 15 | import queue 16 | import threading 17 | 18 | import cv2 19 | import numpy as np 20 | import torch 21 | from torch import Tensor 22 | from torch.utils.data import Dataset, DataLoader 23 | 24 | import imgproc 25 | 26 | __all__ = [ 27 | "TrainValidImageDataset", "TestImageDataset", 28 | "PrefetchGenerator", "PrefetchDataLoader", "CPUPrefetcher", "CUDAPrefetcher", 29 | ] 30 | 31 | 32 | class TrainValidImageDataset(Dataset): 33 | """Define training/valid dataset loading methods. 34 | 35 | Args: 36 | gt_image_dir (str): Train/Valid ground-truth dataset address. 37 | gt_image_size (int): Ground-truth resolution image size. 38 | upscale_factor (int): Image up scale factor. 39 | mode (str): Data set loading method, the training data set is for data enhancement, and the 40 | verification dataset is not for data enhancement. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | gt_image_dir: str, 46 | gt_image_size: int, 47 | upscale_factor: int, 48 | mode: str, 49 | ) -> None: 50 | super(TrainValidImageDataset, self).__init__() 51 | self.image_file_names = [os.path.join(gt_image_dir, image_file_name) for image_file_name in 52 | os.listdir(gt_image_dir)] 53 | self.gt_image_size = gt_image_size 54 | self.upscale_factor = upscale_factor 55 | self.mode = mode 56 | 57 | def __getitem__(self, batch_index: int) -> [dict[str, Tensor], dict[str, Tensor]]: 58 | # Read a batch of image data 59 | gt_crop_image = cv2.imread(self.image_file_names[batch_index]).astype(np.float32) / 255. 60 | 61 | # Image processing operations 62 | if self.mode == "Train": 63 | gt_crop_image = imgproc.random_crop(gt_crop_image, self.gt_image_size) 64 | elif self.mode == "Valid": 65 | gt_crop_image = imgproc.center_crop(gt_crop_image, self.gt_image_size) 66 | else: 67 | raise ValueError("Unsupported data processing model, please use `Train` or `Valid`.") 68 | 69 | lr_crop_image = imgproc.image_resize(gt_crop_image, 1 / self.upscale_factor) 70 | 71 | # BGR convert Y channel 72 | gt_crop_y_image = imgproc.bgr_to_ycbcr(gt_crop_image, only_use_y_channel=True) 73 | lr_crop_y_image = imgproc.bgr_to_ycbcr(lr_crop_image, only_use_y_channel=True) 74 | 75 | # Convert image data into Tensor stream format (PyTorch). 76 | # Note: The range of input and output is between [0, 1] 77 | gt_crop_y_tensor = imgproc.image_to_tensor(gt_crop_y_image, False, False) 78 | lr_crop_y_tensor = imgproc.image_to_tensor(lr_crop_y_image, False, False) 79 | 80 | return {"gt": gt_crop_y_tensor, "lr": lr_crop_y_tensor} 81 | 82 | def __len__(self) -> int: 83 | return len(self.image_file_names) 84 | 85 | 86 | class TestImageDataset(Dataset): 87 | """Define Test dataset loading methods. 88 | 89 | Args: 90 | test_gt_images_dir (str): ground truth image in test image 91 | test_lr_images_dir (str): low-resolution image in test image 92 | """ 93 | 94 | def __init__(self, test_gt_images_dir: str, test_lr_images_dir: str) -> None: 95 | super(TestImageDataset, self).__init__() 96 | # Get all image file names in folder 97 | self.gt_image_file_names = [os.path.join(test_gt_images_dir, x) for x in os.listdir(test_gt_images_dir)] 98 | self.lr_image_file_names = [os.path.join(test_lr_images_dir, x) for x in os.listdir(test_lr_images_dir)] 99 | 100 | def __getitem__(self, batch_index: int) -> [torch.Tensor, torch.Tensor]: 101 | # Read a batch of image data 102 | gt_image = cv2.imread(self.gt_image_file_names[batch_index]).astype(np.float32) / 255. 
103 | lr_image = cv2.imread(self.lr_image_file_names[batch_index]).astype(np.float32) / 255. 104 | 105 | # BGR convert Y channel 106 | gt_y_image = imgproc.bgr_to_ycbcr(gt_image, only_use_y_channel=True) 107 | lr_y_image = imgproc.bgr_to_ycbcr(lr_image, only_use_y_channel=True) 108 | 109 | # Convert image data into Tensor stream format (PyTorch). 110 | # Note: The range of input and output is between [0, 1] 111 | gt_y_tensor = imgproc.image_to_tensor(gt_y_image, False, False) 112 | lr_y_tensor = imgproc.image_to_tensor(lr_y_image, False, False) 113 | 114 | return {"gt": gt_y_tensor, "lr": lr_y_tensor} 115 | 116 | def __len__(self) -> int: 117 | return len(self.gt_image_file_names) 118 | 119 | 120 | class PrefetchGenerator(threading.Thread): 121 | """A fast data prefetch generator. 122 | 123 | Args: 124 | generator: Data generator. 125 | num_data_prefetch_queue (int): How many early data load queues. 126 | """ 127 | 128 | def __init__(self, generator, num_data_prefetch_queue: int) -> None: 129 | threading.Thread.__init__(self) 130 | self.queue = queue.Queue(num_data_prefetch_queue) 131 | self.generator = generator 132 | self.daemon = True 133 | self.start() 134 | 135 | def run(self) -> None: 136 | for item in self.generator: 137 | self.queue.put(item) 138 | self.queue.put(None) 139 | 140 | def __next__(self): 141 | next_item = self.queue.get() 142 | if next_item is None: 143 | raise StopIteration 144 | return next_item 145 | 146 | def __iter__(self): 147 | return self 148 | 149 | 150 | class PrefetchDataLoader(DataLoader): 151 | """A fast data prefetch dataloader. 152 | 153 | Args: 154 | num_data_prefetch_queue (int): How many early data load queues. 155 | kwargs (dict): Other extended parameters. 156 | """ 157 | 158 | def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None: 159 | self.num_data_prefetch_queue = num_data_prefetch_queue 160 | super(PrefetchDataLoader, self).__init__(**kwargs) 161 | 162 | def __iter__(self): 163 | return PrefetchGenerator(super().__iter__(), self.num_data_prefetch_queue) 164 | 165 | 166 | class CPUPrefetcher: 167 | """Use the CPU side to accelerate data reading. 168 | 169 | Args: 170 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 171 | """ 172 | 173 | def __init__(self, dataloader: DataLoader) -> None: 174 | self.original_dataloader = dataloader 175 | self.data = iter(dataloader) 176 | 177 | def next(self): 178 | try: 179 | return next(self.data) 180 | except StopIteration: 181 | return None 182 | 183 | def reset(self): 184 | self.data = iter(self.original_dataloader) 185 | 186 | def __len__(self) -> int: 187 | return len(self.original_dataloader) 188 | 189 | 190 | class CUDAPrefetcher: 191 | """Use the CUDA side to accelerate data reading. 192 | 193 | Args: 194 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 195 | device (torch.device): Specify running device. 
196 | """ 197 | 198 | def __init__(self, dataloader: DataLoader, device: torch.device): 199 | self.batch_data = None 200 | self.original_dataloader = dataloader 201 | self.device = device 202 | 203 | self.data = iter(dataloader) 204 | self.stream = torch.cuda.Stream() 205 | self.preload() 206 | 207 | def preload(self): 208 | try: 209 | self.batch_data = next(self.data) 210 | except StopIteration: 211 | self.batch_data = None 212 | return None 213 | 214 | with torch.cuda.stream(self.stream): 215 | for k, v in self.batch_data.items(): 216 | if torch.is_tensor(v): 217 | self.batch_data[k] = self.batch_data[k].to(self.device, non_blocking=True) 218 | 219 | def next(self): 220 | torch.cuda.current_stream().wait_stream(self.stream) 221 | batch_data = self.batch_data 222 | self.preload() 223 | return batch_data 224 | 225 | def reset(self): 226 | self.data = iter(self.original_dataloader) 227 | self.preload() 228 | 229 | def __len__(self) -> int: 230 | return len(self.original_dataloader) 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESPCN-PyTorch 2 | 3 | ## Overview 4 | 5 | This repository contains an op-for-op PyTorch reimplementation of [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158v2). 6 | 7 | ## Table of contents 8 | 9 | - [ESPCN-PyTorch](#espcn-pytorch) 10 | - [Overview](#overview) 11 | - [Table of contents](#table-of-contents) 12 | - [About Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](#about-real-time-single-image-and-video-super-resolution-using-an-efficient-sub-pixel-convolutional-neural-network) 13 | - [Download weights](#download-weights) 14 | - [Download datasets](#download-datasets) 15 | - [How Test and Train](#how-test-and-train) 16 | - [Test ESPCN_x4](#test-espcn_x4) 17 | - [Train ESPCN_x4](#train-espcn_x4) 18 | - [Resume ESPCN_x4](#resume-train-espcn_x4) 19 | - [Result](#result) 20 | - [Credit](#credit) 21 | - [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](#real-time-single-image-and-video-super-resolution-using-an-efficient-sub-pixel-convolutional-neural-network) 22 | 23 | ## About Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network 24 | 25 | If you're new to ESPCN, here's an abstract straight from the paper: 26 | 27 | Recently, several models based on deep neural networks have achieved great success in terms of both reconstruction accuracy and computational 28 | performance for single image super-resolution. In these methods, the low resolution (LR) 29 | input image is upscaled to the high resolution (HR) space using a single filter, commonly bicubic interpolation, before reconstruction. This means 30 | that the super-resolution (SR) operation is performed in HR space. We demonstrate that this is sub-optimal and adds computational complexity. In this 31 | paper, we present the first convolutional neural network (CNN) capable of real-time SR of 1080p videos on a single K2 GPU. To achieve this, we propose 32 | a novel CNN architecture where the feature maps are extracted in the LR space. In addition, we introduce an efficient sub-pixel convolution layer 33 | which learns an array of upscaling filters to upscale the final LR feature maps into the HR output. 
By doing so, we effectively replace the
34 | handcrafted bicubic filter in the SR pipeline with more complex upscaling filters specifically trained for each feature map, whilst also reducing the
35 | computational complexity of the overall SR operation. We evaluate the proposed approach using images and videos from publicly available datasets and
36 | show that it performs significantly better (+0.15dB on Images and +0.39dB on Videos) and is an order of magnitude faster than previous CNN-based
37 | methods.
38 | 
39 | ## Download weights
40 | 
41 | - [Google Driver](https://drive.google.com/drive/folders/17ju2HN7Y6pyPK2CC_AqnAfTOe9_3hCQ8?usp=sharing)
42 | - [Baidu Driver](https://pan.baidu.com/s/1yNs4rqIb004-NKEdKBJtYg?pwd=llot)
43 | 
44 | ## Download datasets
45 | 
46 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc.
47 | 
48 | - [Google Driver](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing)
49 | - [Baidu Driver](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot)
50 | 
51 | Please refer to `README.md` in the `data` directory for instructions on preparing the dataset.
52 | 
53 | ## How Test and Train
54 | 
55 | Both training and testing only require modifying the `config.py` file.
56 | 
57 | ### Test ESPCN_x4
58 | 
59 | Modify the `config.py` file.
60 | 
61 | - line 31: `model_arch_name` change to `espcn_x4`.
62 | - line 36: `upscale_factor` change to `4`.
63 | - line 38: `mode` change to `test`.
64 | - line 40: `exp_name` change to `ESPCN_x4-Set5`.
65 | - line 84: `lr_dir` change to `f"./data/Set5/LRbicx{upscale_factor}"`.
66 | - line 86: `gt_dir` change to `f"./data/Set5/GTmod12"`.
67 | - line 88: `model_weights_path` change to `./results/pretrained_models/ESPCN_x4-T91-64bf5ee4.pth.tar`.
68 | 
69 | ```bash
70 | python3 test.py
71 | ```
72 | 
73 | ### Train ESPCN_x4
74 | 
75 | Modify the `config.py` file.
76 | 
77 | - line 31: `model_arch_name` change to `espcn_x4`.
78 | - line 36: `upscale_factor` change to `4`.
79 | - line 38: `mode` change to `train`.
80 | - line 40: `exp_name` change to `ESPCN_x4-Set5`.
81 | - line 46: `test_gt_images_dir` change to `f"./data/Set5/GTmod12"`.
82 | - line 47: `test_lr_images_dir` change to `f"./data/Set5/LRbicx{upscale_factor}"`.
83 | 
84 | ```bash
85 | python3 train.py
86 | ```
87 | 
88 | ### Resume train ESPCN_x4
89 | 
90 | Modify the `config.py` file.
91 | 
92 | - line 31: `model_arch_name` change to `espcn_x4`.
93 | - line 36: `upscale_factor` change to `4`.
94 | - line 38: `mode` change to `train`.
95 | - line 40: `exp_name` change to `ESPCN_x4-Set5`.
96 | - line 57: `resume_model_weights_path` change to `./samples/ESPCN_x4-Set5/epoch_xxx.pth.tar`.
97 | - line 46: `test_gt_images_dir` change to `f"./data/Set5/GTmod12"`.
98 | - line 47: `test_lr_images_dir` change to `f"./data/Set5/LRbicx{upscale_factor}"`.
99 | 
100 | ```bash
101 | python3 train.py
102 | ```
103 | 
104 | ## Result
105 | 
106 | Source of original paper results: [https://arxiv.org/pdf/1609.05158v2.pdf](https://arxiv.org/pdf/1609.05158v2.pdf)
107 | 
108 | In the following table, the value in `()` indicates the result of this project, and `-` indicates no test.
109 | 
110 | |  Method  | Scale |   Set5 (PSNR)    |   Set14 (PSNR)   |
111 | |:--------:|:-----:|:----------------:|:----------------:|
112 | | ESPCN_x2 |   2   |   -(**36.64**)   |   -(**32.35**)   |
113 | | ESPCN_x3 |   3   | 32.55(**32.55**) | 29.08(**29.20**) |
114 | | ESPCN_x4 |   4   | 30.90(**30.26**) | 27.73(**27.41**) |
115 | 
116 | ```bash
117 | # Download `ESPCN_x4-T91-64bf5ee4.pth.tar` weights to `./results/pretrained_models/ESPCN_x4-T91-64bf5ee4.pth.tar`
118 | # For more detail, see `README.md`
119 | python3 ./inference.py
120 | ```
121 | 
122 | Input:
123 | 
124 | ![comic](figure/comic.png)
125 | 
126 | Output:
127 | 
128 | ![sr_comic](figure/sr_comic.png)
129 | 
130 | ```text
131 | Build `espcn_x4` model successfully.
132 | Load `espcn_x4` model weights `./results/pretrained_models/ESPCN_x4-T91-64bf5ee4.pth.tar` successfully.
133 | SR image save to `./figure/sr_comic.png`
134 | ```
135 | 
136 | ### Credit
137 | 
138 | #### Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
139 | 
140 | _Wenzhe Shi, Jose Caballero, Ferenc Huszár, Johannes Totz, Andrew P. Aitken, Rob Bishop, Daniel Rueckert, Zehan Wang_
141 | 142 | **Abstract**
143 | Recently, several models based on deep neural networks have achieved great success in terms of both reconstruction accuracy and computational 144 | performance for single image super-resolution. In these methods, the low resolution (LR) 145 | input image is upscaled to the high resolution (HR) space using a single filter, commonly bicubic interpolation, before reconstruction. This means 146 | that the super-resolution (SR) operation is performed in HR space. We demonstrate that this is sub-optimal and adds computational complexity. In this 147 | paper, we present the first convolutional neural network (CNN) capable of real-time SR of 1080p videos on a single K2 GPU. To achieve this, we propose 148 | a novel CNN architecture where the feature maps are extracted in the LR space. In addition, we introduce an efficient sub-pixel convolution layer 149 | which learns an array of upscaling filters to upscale the final LR feature maps into the HR output. By doing so, we effectively replace the 150 | handcrafted bicubic filter in the SR pipeline with more complex upscaling filters specifically trained for each feature map, whilst also reducing the 151 | computational complexity of the overall SR operation. We evaluate the proposed approach using images and videos from publicly available datasets and 152 | show that it performs significantly better (+0.15dB on Images and +0.39dB on Videos) and is an order of magnitude faster than previous CNN-based 153 | methods. 154 | 155 | [[Paper]](https://arxiv.org/pdf/1609.05158) 156 | 157 | ``` 158 | @article{DBLP:journals/corr/ShiCHTABRW16, 159 | author = {Wenzhe Shi and 160 | Jose Caballero and 161 | Ferenc Husz{\'{a}}r and 162 | Johannes Totz and 163 | Andrew P. Aitken and 164 | Rob Bishop and 165 | Daniel Rueckert and 166 | Zehan Wang}, 167 | title = {Real-Time Single Image and Video Super-Resolution Using an Efficient 168 | Sub-Pixel Convolutional Neural Network}, 169 | journal = {CoRR}, 170 | volume = {abs/1609.05158}, 171 | year = {2016}, 172 | url = {http://arxiv.org/abs/1609.05158}, 173 | archivePrefix = {arXiv}, 174 | eprint = {1609.05158}, 175 | timestamp = {Mon, 13 Aug 2018 16:47:09 +0200}, 176 | biburl = {https://dblp.org/rec/journals/corr/ShiCHTABRW16.bib}, 177 | bibsource = {dblp computer science bibliography, https://dblp.org} 178 | } 179 | ``` 180 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================ 14 | import os 15 | import time 16 | 17 | import torch 18 | from torch import nn 19 | from torch import optim 20 | from torch.cuda import amp 21 | from torch.optim import lr_scheduler 22 | from torch.utils.data import DataLoader 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import config 26 | import model 27 | from dataset import CUDAPrefetcher, TrainValidImageDataset, TestImageDataset 28 | from image_quality_assessment import PSNR, SSIM 29 | from utils import load_state_dict, make_directory, save_checkpoint, AverageMeter, ProgressMeter 30 | 31 | 32 | def main(): 33 | # Initialize the number of training epochs 34 | start_epoch = 0 35 | 36 | # Initialize training to generate network evaluation indicators 37 | best_psnr = 0.0 38 | best_ssim = 0.0 39 | 40 | train_prefetcher, test_prefetcher = load_dataset() 41 | print("Load all datasets successfully.") 42 | 43 | espcn_model = build_model() 44 | print(f"Build `{config.model_arch_name}` model successfully.") 45 | 46 | criterion = define_loss() 47 | print("Define all loss functions successfully.") 48 | 49 | optimizer = define_optimizer(espcn_model) 50 | print("Define all optimizer functions successfully.") 51 | 52 | scheduler = define_scheduler(optimizer) 53 | print("Define all optimizer scheduler successfully.") 54 | 55 | print("Check whether to load pretrained model weights...") 56 | if config.pretrained_model_weights_path: 57 | espcn_model = load_state_dict(espcn_model, config.pretrained_model_weights_path, load_mode="pretrained") 58 | print(f"Loaded `{config.pretrained_model_weights_path}` pretrained model weights successfully.") 59 | else: 60 | print("Pretrained model weights not found.") 61 | 62 | print("Check whether the pretrained model is restored...") 63 | if config.resume_model_weights_path: 64 | espcn_model, _, start_epoch, best_psnr, best_ssim, optimizer, _ = load_state_dict( 65 | espcn_model, 66 | config.resume_model_weights_path, 67 | optimizer=optimizer, 68 | load_mode="resume") 69 | print("Loaded pretrained model weights.") 70 | else: 71 | print("Resume training model not found. 
Start training from scratch.") 72 | 73 | # Create a experiment results 74 | samples_dir = os.path.join("samples", config.exp_name) 75 | results_dir = os.path.join("results", config.exp_name) 76 | make_directory(samples_dir) 77 | make_directory(results_dir) 78 | 79 | # Create training process log file 80 | writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name)) 81 | 82 | # Initialize the gradient scaler 83 | scaler = amp.GradScaler() 84 | 85 | # Create an IQA evaluation model 86 | psnr_model = PSNR(config.upscale_factor, config.only_test_y_channel) 87 | ssim_model = SSIM(config.upscale_factor, config.only_test_y_channel) 88 | 89 | # Transfer the IQA model to the specified device 90 | psnr_model = psnr_model.to(device=config.device) 91 | ssim_model = ssim_model.to(device=config.device) 92 | 93 | for epoch in range(start_epoch, config.epochs): 94 | train(espcn_model, 95 | train_prefetcher, 96 | criterion, 97 | optimizer, 98 | epoch, 99 | scaler, 100 | writer) 101 | psnr, ssim = validate(espcn_model, 102 | test_prefetcher, 103 | epoch, 104 | writer, 105 | psnr_model, 106 | ssim_model, 107 | "Test") 108 | print("\n") 109 | 110 | # Update lr 111 | scheduler.step() 112 | 113 | # Automatically save the model with the highest index 114 | is_best = psnr > best_psnr and ssim > best_ssim 115 | is_last = (epoch + 1) == config.epochs 116 | best_psnr = max(psnr, best_psnr) 117 | best_ssim = max(ssim, best_ssim) 118 | save_checkpoint({"epoch": epoch + 1, 119 | "best_psnr": best_psnr, 120 | "best_ssim": best_ssim, 121 | "state_dict": espcn_model.state_dict(), 122 | "optimizer": optimizer.state_dict()}, 123 | f"g_epoch_{epoch + 1}.pth.tar", 124 | samples_dir, 125 | results_dir, 126 | "g_best.pth.tar", 127 | "g_last.pth.tar", 128 | is_best, 129 | is_last) 130 | 131 | 132 | def load_dataset() -> [CUDAPrefetcher, CUDAPrefetcher]: 133 | # Load train, test and valid datasets 134 | train_datasets = TrainValidImageDataset(config.train_gt_images_dir, 135 | config.gt_image_size, 136 | config.upscale_factor, 137 | "Train") 138 | test_datasets = TestImageDataset(config.test_gt_images_dir, config.test_lr_images_dir) 139 | 140 | # Generator all dataloader 141 | train_dataloader = DataLoader(train_datasets, 142 | batch_size=config.batch_size, 143 | shuffle=True, 144 | num_workers=config.num_workers, 145 | pin_memory=True, 146 | drop_last=True, 147 | persistent_workers=True) 148 | test_dataloader = DataLoader(test_datasets, 149 | batch_size=1, 150 | shuffle=False, 151 | num_workers=1, 152 | pin_memory=True, 153 | drop_last=False, 154 | persistent_workers=True) 155 | 156 | # Place all data on the preprocessing data loader 157 | train_prefetcher = CUDAPrefetcher(train_dataloader, config.device) 158 | test_prefetcher = CUDAPrefetcher(test_dataloader, config.device) 159 | 160 | return train_prefetcher, test_prefetcher 161 | 162 | 163 | def build_model() -> nn.Module: 164 | espcn_model = model.__dict__[config.model_arch_name](in_channels=config.in_channels, 165 | out_channels=config.out_channels, 166 | channels=config.channels) 167 | espcn_model = espcn_model.to(device=config.device) 168 | 169 | return espcn_model 170 | 171 | 172 | def define_loss() -> nn.MSELoss: 173 | criterion = nn.MSELoss() 174 | criterion = criterion.to(device=config.device) 175 | 176 | return criterion 177 | 178 | 179 | def define_optimizer(espcn_model) -> optim.SGD: 180 | optimizer = optim.SGD(espcn_model.parameters(), 181 | lr=config.model_lr, 182 | momentum=config.model_momentum, 183 | weight_decay=config.model_weight_decay, 184 | 
nesterov=config.model_nesterov) 185 | 186 | return optimizer 187 | 188 | 189 | def define_scheduler(optimizer) -> lr_scheduler.MultiStepLR: 190 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.lr_scheduler_milestones, 191 | gamma=config.lr_scheduler_gamma) 192 | 193 | return scheduler 194 | 195 | 196 | def train( 197 | espcn_model: nn.Module, 198 | train_prefetcher: CUDAPrefetcher, 199 | criterion: nn.MSELoss, 200 | optimizer: optim.Adam, 201 | epoch: int, 202 | scaler: amp.GradScaler, 203 | writer: SummaryWriter 204 | ) -> None: 205 | # Calculate how many batches of data are in each Epoch 206 | batches = len(train_prefetcher) 207 | # Print information of progress bar during training 208 | batch_time = AverageMeter("Time", ":6.3f") 209 | data_time = AverageMeter("Data", ":6.3f") 210 | losses = AverageMeter("Loss", ":6.6f") 211 | progress = ProgressMeter(batches, [batch_time, data_time, losses], prefix=f"Epoch: [{epoch + 1}]") 212 | 213 | # Put the generative network model in training mode 214 | espcn_model.train() 215 | 216 | # Initialize the number of data batches to print logs on the terminal 217 | batch_index = 0 218 | 219 | # Initialize the data loader and load the first batch of data 220 | train_prefetcher.reset() 221 | batch_data = train_prefetcher.next() 222 | 223 | # Get the initialization training time 224 | end = time.time() 225 | 226 | while batch_data is not None: 227 | # Calculate the time it takes to load a batch of data 228 | data_time.update(time.time() - end) 229 | 230 | # Transfer in-memory data to CUDA devices to speed up training 231 | gt = batch_data["gt"].to(device=config.device, non_blocking=True) 232 | lr = batch_data["lr"].to(device=config.device, non_blocking=True) 233 | 234 | # Initialize generator gradients 235 | espcn_model.zero_grad(set_to_none=True) 236 | 237 | # Mixed precision training 238 | with amp.autocast(): 239 | sr = espcn_model(lr) 240 | loss = torch.mul(config.loss_weights, criterion(sr, gt)) 241 | 242 | # Backpropagation 243 | scaler.scale(loss).backward() 244 | # update generator weights 245 | scaler.step(optimizer) 246 | scaler.update() 247 | 248 | # Statistical loss value for terminal data output 249 | losses.update(loss.item(), lr.size(0)) 250 | 251 | # Calculate the time it takes to fully train a batch of data 252 | batch_time.update(time.time() - end) 253 | end = time.time() 254 | 255 | # Write the data during training to the training log file 256 | if batch_index % config.train_print_frequency == 0: 257 | # Record loss during training and output to file 258 | writer.add_scalar("Train/Loss", loss.item(), batch_index + epoch * batches + 1) 259 | progress.display(batch_index + 1) 260 | 261 | # Preload the next batch of data 262 | batch_data = train_prefetcher.next() 263 | 264 | # Add 1 to the number of data batches to ensure that the terminal prints data normally 265 | batch_index += 1 266 | 267 | 268 | def validate( 269 | espcn_model: nn.Module, 270 | data_prefetcher: CUDAPrefetcher, 271 | epoch: int, 272 | writer: SummaryWriter, 273 | psnr_model: nn.Module, 274 | ssim_model: nn.Module, 275 | mode: str 276 | ) -> [float, float]: 277 | # Calculate how many batches of data are in each Epoch 278 | batch_time = AverageMeter("Time", ":6.3f") 279 | psnres = AverageMeter("PSNR", ":4.2f") 280 | ssimes = AverageMeter("SSIM", ":4.4f") 281 | progress = ProgressMeter(len(data_prefetcher), [batch_time, psnres, ssimes], prefix=f"{mode}: ") 282 | 283 | # Put the adversarial network model in validation mode 284 | espcn_model.eval() 285 | 
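    # The evaluation loop below mirrors the training loop: batches arrive through the
    # CUDAPrefetcher, the forward pass runs under torch.no_grad() and amp.autocast(),
    # and per-batch PSNR/SSIM values are accumulated by the AverageMeter objects above.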
286 | # Initialize the number of data batches to print logs on the terminal 287 | batch_index = 0 288 | 289 | # Initialize the data loader and load the first batch of data 290 | data_prefetcher.reset() 291 | batch_data = data_prefetcher.next() 292 | 293 | # Get the initialization test time 294 | end = time.time() 295 | 296 | with torch.no_grad(): 297 | while batch_data is not None: 298 | # Transfer the in-memory data to the CUDA device to speed up the test 299 | gt = batch_data["gt"].to(device=config.device, non_blocking=True) 300 | lr = batch_data["lr"].to(device=config.device, non_blocking=True) 301 | 302 | # Use the generator model to generate a fake sample 303 | with amp.autocast(): 304 | sr = espcn_model(lr) 305 | 306 | # Statistical loss value for terminal data output 307 | psnr = psnr_model(sr, gt) 308 | ssim = ssim_model(sr, gt) 309 | psnres.update(psnr.item(), lr.size(0)) 310 | ssimes.update(ssim.item(), lr.size(0)) 311 | 312 | # Calculate the time it takes to fully test a batch of data 313 | batch_time.update(time.time() - end) 314 | end = time.time() 315 | 316 | # Record training log information 317 | if batch_index % config.test_print_frequency == 0: 318 | progress.display(batch_index + 1) 319 | 320 | # Preload the next batch of data 321 | batch_data = data_prefetcher.next() 322 | 323 | # After training a batch of data, add 1 to the number of data batches to ensure that the 324 | # terminal print data normally 325 | batch_index += 1 326 | 327 | # print metrics 328 | progress.display_summary() 329 | 330 | if mode == "Valid" or mode == "Test": 331 | writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1) 332 | writer.add_scalar(f"{mode}/SSIM", ssimes.avg, epoch + 1) 333 | else: 334 | raise ValueError("Unsupported mode, please use `Valid` or `Test`.") 335 | 336 | return psnres.avg, ssimes.avg 337 | 338 | 339 | if __name__ == "__main__": 340 | main() 341 | -------------------------------------------------------------------------------- /imgproc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import math 15 | import random 16 | from typing import Any 17 | 18 | import cv2 19 | import numpy as np 20 | import torch 21 | from numpy import ndarray 22 | from torch import Tensor 23 | from torchvision.transforms import functional as F_vision 24 | 25 | __all__ = [ 26 | "image_to_tensor", "tensor_to_image", 27 | "image_resize", "preprocess_one_image", 28 | "expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb", 29 | "rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch", 30 | "center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip", 31 | "center_crop_torch", "random_crop_torch", "random_rotate_torch", "random_vertically_flip_torch", 32 | "random_horizontally_flip_torch", 33 | ] 34 | 35 | 36 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 37 | def _cubic(x: Any) -> Any: 38 | """Implementation of `cubic` function in Matlab under Python language. 39 | 40 | Args: 41 | x: Element vector. 42 | 43 | Returns: 44 | Bicubic interpolation 45 | 46 | """ 47 | absx = torch.abs(x) 48 | absx2 = absx ** 2 49 | absx3 = absx ** 3 50 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( 51 | -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ( 52 | ((absx > 1) * (absx <= 2)).type_as(absx)) 53 | 54 | 55 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 56 | def _calculate_weights_indices(in_length: int, 57 | out_length: int, 58 | scale: float, 59 | kernel_width: int, 60 | antialiasing: bool) -> [np.ndarray, np.ndarray, int, int]: 61 | """Implementation of `calculate_weights_indices` function in Matlab under Python language. 62 | 63 | Args: 64 | in_length (int): Input length. 65 | out_length (int): Output length. 66 | scale (float): Scale factor. 67 | kernel_width (int): Kernel width. 68 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 69 | Caution: Bicubic down-sampling in PIL uses antialiasing by default. 70 | 71 | Returns: 72 | weights, indices, sym_len_s, sym_len_e 73 | 74 | """ 75 | if (scale < 1) and antialiasing: 76 | # Use a modified kernel (larger kernel width) to simultaneously 77 | # interpolate and antialiasing 78 | kernel_width = kernel_width / scale 79 | 80 | # Output-space coordinates 81 | x = torch.linspace(1, out_length, out_length) 82 | 83 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 84 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 85 | # space maps to 1.5 in input space. 86 | u = x / scale + 0.5 * (1 - 1 / scale) 87 | 88 | # What is the left-most pixel that can be involved in the computation? 89 | left = torch.floor(u - kernel_width / 2) 90 | 91 | # What is the maximum number of pixels that can be involved in the 92 | # computation? Note: it's OK to use an extra pixel here; if the 93 | # corresponding weights are all zero, it will be eliminated at the end 94 | # of this function. 95 | p = math.ceil(kernel_width) + 2 96 | 97 | # The indices of the input pixels involved in computing the k-th output 98 | # pixel are in row k of the indices matrix. 99 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 100 | out_length, p) 101 | 102 | # The weights used to compute the k-th output pixel are in row k of the 103 | # weights matrix. 
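    # For reference, the cubic kernel applied below (see _cubic above) is the Keys kernel
    # with a = -0.5:
    #   W(x) = 1.5|x|^3 - 2.5|x|^2 + 1,            for |x| <= 1
    #   W(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2,    for 1 < |x| <= 2
    # Each entry of distance_to_center is passed through this kernel to obtain the weights.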
104 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 105 | 106 | # apply cubic kernel 107 | if (scale < 1) and antialiasing: 108 | weights = scale * _cubic(distance_to_center * scale) 109 | else: 110 | weights = _cubic(distance_to_center) 111 | 112 | # Normalize the weights matrix so that each row sums to 1. 113 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 114 | weights = weights / weights_sum.expand(out_length, p) 115 | 116 | # If a column in weights is all zero, get rid of it. only consider the 117 | # first and last column. 118 | weights_zero_tmp = torch.sum((weights == 0), 0) 119 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 120 | indices = indices.narrow(1, 1, p - 2) 121 | weights = weights.narrow(1, 1, p - 2) 122 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 123 | indices = indices.narrow(1, 0, p - 2) 124 | weights = weights.narrow(1, 0, p - 2) 125 | weights = weights.contiguous() 126 | indices = indices.contiguous() 127 | sym_len_s = -indices.min() + 1 128 | sym_len_e = indices.max() - in_length 129 | indices = indices + sym_len_s - 1 130 | return weights, indices, int(sym_len_s), int(sym_len_e) 131 | 132 | 133 | def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor: 134 | """Convert the image data type to the Tensor (NCWH) data type supported by PyTorch 135 | 136 | Args: 137 | image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0,255] or [0, 1] 138 | range_norm (bool): Scale [0, 1] data to between [-1, 1] 139 | half (bool): Whether to convert torch.float32 similarly to torch.half type 140 | 141 | Returns: 142 | tensor (Tensor): Data types supported by PyTorch 143 | 144 | Examples: 145 | >>> example_image = cv2.imread("lr_image.bmp") 146 | >>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False) 147 | 148 | """ 149 | # Convert image data type to Tensor data type 150 | tensor = F_vision.to_tensor(image) 151 | 152 | # Scale the image data from [0, 1] to [-1, 1] 153 | if range_norm: 154 | tensor = tensor.mul(2.0).sub(1.0) 155 | 156 | # Convert torch.float32 image data type to torch.half image data type 157 | if half: 158 | tensor = tensor.half() 159 | 160 | return tensor 161 | 162 | 163 | def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any: 164 | """Convert the Tensor(NCWH) data type supported by PyTorch to the np.ndarray(WHC) image data type 165 | 166 | Args: 167 | tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1] 168 | range_norm (bool): Scale [-1, 1] data to between [0, 1] 169 | half (bool): Whether to convert torch.float32 similarly to torch.half type. 
170 | 171 | Returns: 172 | image (np.ndarray): Data types supported by PIL or OpenCV 173 | 174 | Examples: 175 | >>> example_tensor = torch.randn([1, 3, 256, 256], dtype=torch.float32) 176 | >>> example_image = tensor_to_image(example_tensor, range_norm=False, half=False) 177 | 178 | """ 179 | if range_norm: 180 | tensor = tensor.add(1.0).div(2.0) 181 | if half: 182 | tensor = tensor.half() 183 | 184 | image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8") 185 | 186 | return image 187 | 188 | 189 | def preprocess_one_image(image_path: str, device: torch.device) -> [Tensor, ndarray, ndarray]: 190 | image = cv2.imread(image_path).astype(np.float32) / 255.0 191 | 192 | # BGR to YCbCr 193 | ycbcr_image = bgr_to_ycbcr(image, only_use_y_channel=False) 194 | 195 | # Split YCbCr image data 196 | y_image, cb_image, cr_image = cv2.split(ycbcr_image) 197 | 198 | # Convert image data to pytorch format data 199 | y_tensor = image_to_tensor(y_image, False, False).unsqueeze_(0) 200 | 201 | # Transfer tensor channel image format data to CUDA device 202 | y_tensor = y_tensor.to(device=device, non_blocking=True) 203 | 204 | return y_tensor, cb_image, cr_image 205 | 206 | 207 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 208 | def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any: 209 | """Implementation of `imresize` function in Matlab under Python language. 210 | 211 | Args: 212 | image: The input image. 213 | scale_factor (float): Scale factor. The same scale applies for both height and width. 214 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 215 | Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.
216 | 217 | Returns: 218 | out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round 219 | 220 | """ 221 | squeeze_flag = False 222 | if type(image).__module__ == np.__name__: # numpy type 223 | numpy_type = True 224 | if image.ndim == 2: 225 | image = image[:, :, None] 226 | squeeze_flag = True 227 | image = torch.from_numpy(image.transpose(2, 0, 1)).float() 228 | else: 229 | numpy_type = False 230 | if image.ndim == 2: 231 | image = image.unsqueeze(0) 232 | squeeze_flag = True 233 | 234 | in_c, in_h, in_w = image.size() 235 | out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor) 236 | kernel_width = 4 237 | 238 | # get weights and indices 239 | weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, 240 | antialiasing) 241 | weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, 242 | antialiasing) 243 | # process H dimension 244 | # symmetric copying 245 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 246 | img_aug.narrow(1, sym_len_hs, in_h).copy_(image) 247 | 248 | sym_patch = image[:, :sym_len_hs, :] 249 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 250 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 251 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 252 | 253 | sym_patch = image[:, -sym_len_he:, :] 254 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 255 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 256 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 257 | 258 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 259 | kernel_width = weights_h.size(1) 260 | for i in range(out_h): 261 | idx = int(indices_h[i][0]) 262 | for j in range(in_c): 263 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 264 | 265 | # process W dimension 266 | # symmetric copying 267 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 268 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 269 | 270 | sym_patch = out_1[:, :, :sym_len_ws] 271 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 272 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 273 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 274 | 275 | sym_patch = out_1[:, :, -sym_len_we:] 276 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 277 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 278 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 279 | 280 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 281 | kernel_width = weights_w.size(1) 282 | for i in range(out_w): 283 | idx = int(indices_w[i][0]) 284 | for j in range(in_c): 285 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 286 | 287 | if squeeze_flag: 288 | out_2 = out_2.squeeze(0) 289 | if numpy_type: 290 | out_2 = out_2.numpy() 291 | if not squeeze_flag: 292 | out_2 = out_2.transpose(1, 2, 0) 293 | 294 | return out_2 295 | 296 | 297 | def expand_y(image: np.ndarray) -> np.ndarray: 298 | """Convert BGR channel to YCbCr format, 299 | and expand Y channel data in YCbCr, from HW to HWC 300 | 301 | Args: 302 | image (np.ndarray): Y channel image data 303 | 304 | Returns: 305 | y_image (np.ndarray): Y-channel image data in HWC form 306 | 307 | """ 308 | # Normalize image data to [0, 1] 309 | image = image.astype(np.float32) / 255. 
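    # The Y channel extracted below is the BT.601 luma computed by bgr_to_ycbcr, i.e.
    #   Y = (65.481 * R + 128.553 * G + 24.966 * B + 16) / 255
    # for B, G, R values already normalized to [0, 1].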
310 | 311 | # Convert BGR to YCbCr, and extract only Y channel 312 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True) 313 | 314 | # Expand Y channel 315 | y_image = y_image[..., None] 316 | 317 | # Normalize the image data to [0, 255] 318 | y_image = y_image.astype(np.float64) * 255.0 319 | 320 | return y_image 321 | 322 | 323 | def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray: 324 | """Implementation of rgb2ycbcr function in Matlab under Python language 325 | 326 | Args: 327 | image (np.ndarray): Image input in RGB format. 328 | only_use_y_channel (bool): Extract Y channel separately 329 | 330 | Returns: 331 | image (np.ndarray): YCbCr image array data 332 | 333 | """ 334 | if only_use_y_channel: 335 | image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0 336 | else: 337 | image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [ 338 | 16, 128, 128] 339 | 340 | image /= 255. 341 | image = image.astype(np.float32) 342 | 343 | return image 344 | 345 | 346 | def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray: 347 | """Implementation of bgr2ycbcr function in Matlab under Python language. 348 | 349 | Args: 350 | image (np.ndarray): Image input in BGR format 351 | only_use_y_channel (bool): Extract Y channel separately 352 | 353 | Returns: 354 | image (np.ndarray): YCbCr image array data 355 | 356 | """ 357 | if only_use_y_channel: 358 | image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0 359 | else: 360 | image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [ 361 | 16, 128, 128] 362 | 363 | image /= 255. 364 | image = image.astype(np.float32) 365 | 366 | return image 367 | 368 | 369 | def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray: 370 | """Implementation of ycbcr2rgb function in Matlab under Python language. 371 | 372 | Args: 373 | image (np.ndarray): Image input in YCbCr format. 374 | 375 | Returns: 376 | image (np.ndarray): RGB image array data 377 | 378 | """ 379 | image_dtype = image.dtype 380 | image *= 255. 381 | 382 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 383 | [0, -0.00153632, 0.00791071], 384 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 385 | 386 | image /= 255. 387 | image = image.astype(image_dtype) 388 | 389 | return image 390 | 391 | 392 | def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray: 393 | """Implementation of ycbcr2bgr function in Matlab under Python language. 394 | 395 | Args: 396 | image (np.ndarray): Image input in YCbCr format. 397 | 398 | Returns: 399 | image (np.ndarray): BGR image array data 400 | 401 | """ 402 | image_dtype = image.dtype 403 | image *= 255. 404 | 405 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 406 | [0.00791071, -0.00153632, 0], 407 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] 408 | 409 | image /= 255. 
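    # The 3x3 matrix and offset used above are the inverse of the BT.601 forward transform
    # in bgr_to_ycbcr, so converting to YCbCr and back recovers the original [0, 1] image
    # up to floating-point error.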
410 | image = image.astype(image_dtype) 411 | 412 | return image 413 | 414 | 415 | def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor: 416 | """Implementation of rgb2ycbcr function in Matlab under PyTorch 417 | 418 | References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion` 419 | 420 | Args: 421 | tensor (Tensor): Image data in PyTorch format 422 | only_use_y_channel (bool): Extract only Y channel 423 | 424 | Returns: 425 | tensor (Tensor): YCbCr image data in PyTorch format 426 | 427 | """ 428 | if only_use_y_channel: 429 | weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor) 430 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 431 | else: 432 | weight = Tensor([[65.481, -37.797, 112.0], 433 | [128.553, -74.203, -93.786], 434 | [24.966, 112.0, -18.214]]).to(tensor) 435 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor) 436 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 437 | 438 | tensor /= 255. 439 | 440 | return tensor 441 | 442 | 443 | def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor: 444 | """Implementation of bgr2ycbcr function in Matlab under PyTorch 445 | 446 | References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion` 447 | 448 | Args: 449 | tensor (Tensor): Image data in PyTorch format 450 | only_use_y_channel (bool): Extract only Y channel 451 | 452 | Returns: 453 | tensor (Tensor): YCbCr image data in PyTorch format 454 | 455 | """ 456 | if only_use_y_channel: 457 | weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor) 458 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 459 | else: 460 | weight = Tensor([[24.966, 112.0, -18.214], 461 | [128.553, -74.203, -93.786], 462 | [65.481, -37.797, 112.0]]).to(tensor) 463 | bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor) 464 | tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 465 | 466 | tensor /= 255. 467 | 468 | return tensor 469 | 470 | 471 | def center_crop(image: np.ndarray, image_size: int) -> np.ndarray: 472 | """Crop small image patches from one image center area. 473 | 474 | Args: 475 | image (np.ndarray): The input image for `OpenCV.imread`. 476 | image_size (int): The size of the captured image area. 477 | 478 | Returns: 479 | patch_image (np.ndarray): Small patch image 480 | 481 | """ 482 | image_height, image_width = image.shape[:2] 483 | 484 | # Just need to find the top and left coordinates of the image 485 | top = (image_height - image_size) // 2 486 | left = (image_width - image_size) // 2 487 | 488 | # Crop image patch 489 | patch_image = image[top:top + image_size, left:left + image_size, ...] 490 | 491 | return patch_image 492 | 493 | 494 | def random_crop(image: np.ndarray, image_size: int) -> np.ndarray: 495 | """Crop small image patches from one image. 496 | 497 | Args: 498 | image (np.ndarray): The input image for `OpenCV.imread`. 499 | image_size (int): The size of the captured image area. 500 | 501 | Returns: 502 | patch_image (np.ndarray): Small patch image 503 | 504 | """ 505 | image_height, image_width = image.shape[:2] 506 | 507 | # Just need to find the top and left coordinates of the image 508 | top = random.randint(0, image_height - image_size) 509 | left = random.randint(0, image_width - image_size) 510 | 511 | # Crop image patch 512 | patch_image = image[top:top + image_size, left:left + image_size, ...] 
513 | 514 | return patch_image 515 | 516 | 517 | def random_rotate(image, 518 | angles: list, 519 | center: tuple[int, int] = None, 520 | scale_factor: float = 1.0) -> np.ndarray: 521 | """Rotate an image by a random angle 522 | 523 | Args: 524 | image (np.ndarray): Image read with OpenCV 525 | angles (list): Rotation angle range 526 | center (optional, tuple[int, int]): High resolution image selection center point. Default: ``None`` 527 | scale_factor (optional, float): scaling factor. Default: 1.0 528 | 529 | Returns: 530 | rotated_image (np.ndarray): image after rotation 531 | 532 | """ 533 | image_height, image_width = image.shape[:2] 534 | 535 | if center is None: 536 | center = (image_width // 2, image_height // 2) 537 | 538 | # Random select specific angle 539 | angle = random.choice(angles) 540 | matrix = cv2.getRotationMatrix2D(center, angle, scale_factor) 541 | rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height)) 542 | 543 | return rotated_image 544 | 545 | 546 | def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray: 547 | """Flip the image upside down randomly 548 | 549 | Args: 550 | image (np.ndarray): Image read with OpenCV 551 | p (optional, float): Horizontally flip probability. Default: 0.5 552 | 553 | Returns: 554 | horizontally_flip_image (np.ndarray): image after horizontally flip 555 | 556 | """ 557 | if random.random() < p: 558 | horizontally_flip_image = cv2.flip(image, 1) 559 | else: 560 | horizontally_flip_image = image 561 | 562 | return horizontally_flip_image 563 | 564 | 565 | def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray: 566 | """Flip an image horizontally randomly 567 | 568 | Args: 569 | image (np.ndarray): Image read with OpenCV 570 | p (optional, float): Vertically flip probability. Default: 0.5 571 | 572 | Returns: 573 | vertically_flip_image (np.ndarray): image after vertically flip 574 | 575 | """ 576 | if random.random() < p: 577 | vertically_flip_image = cv2.flip(image, 0) 578 | else: 579 | vertically_flip_image = image 580 | 581 | return vertically_flip_image 582 | 583 | 584 | def center_crop_torch( 585 | gt_images: ndarray | Tensor | list[ndarray] | list[Tensor], 586 | lr_images: ndarray | Tensor | list[ndarray] | list[Tensor], 587 | gt_patch_size: int, 588 | upscale_factor: int, 589 | ) -> [ndarray, ndarray] or [Tensor, Tensor] or [list[ndarray], list[ndarray]] or [list[Tensor], list[Tensor]]: 590 | if not isinstance(gt_images, list): 591 | gt_images = [gt_images] 592 | if not isinstance(lr_images, list): 593 | lr_images = [lr_images] 594 | 595 | # Detect input image data type 596 | input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy" 597 | 598 | if input_type == "Tensor": 599 | lr_image_height, lr_image_width = lr_images[0].size()[-2:] 600 | else: 601 | lr_image_height, lr_image_width = lr_images[0].shape[0:2] 602 | 603 | # Compute low-resolution image patch size 604 | lr_patch_size = gt_patch_size // upscale_factor 605 | 606 | # Calculate the start indices of the crop 607 | lr_top = (lr_image_height - lr_patch_size) // 2 608 | lr_left = (lr_image_width - lr_patch_size) // 2 609 | 610 | # Crop lr image patch 611 | if input_type == "Tensor": 612 | lr_images = [lr_image[ 613 | :, 614 | :, 615 | lr_top:lr_top + lr_patch_size, 616 | lr_left:lr_left + lr_patch_size] for lr_image in lr_images] 617 | else: 618 | lr_images = [lr_image[ 619 | lr_top:lr_top + lr_patch_size, 620 | lr_left:lr_left + lr_patch_size, 621 | ...] 
for lr_image in lr_images] 622 | 623 | # Crop gt image patch 624 | gt_top, gt_left = int(lr_top * upscale_factor), int(lr_left * upscale_factor) 625 | 626 | if input_type == "Tensor": 627 | gt_images = [v[ 628 | :, 629 | :, 630 | gt_top:gt_top + gt_patch_size, 631 | gt_left:gt_left + gt_patch_size] for v in gt_images] 632 | else: 633 | gt_images = [v[ 634 | gt_top:gt_top + gt_patch_size, 635 | gt_left:gt_left + gt_patch_size, 636 | ...] for v in gt_images] 637 | 638 | # When image number is 1 639 | if len(gt_images) == 1: 640 | gt_images = gt_images[0] 641 | if len(lr_images) == 1: 642 | lr_images = lr_images[0] 643 | 644 | return gt_images, lr_images 645 | 646 | 647 | def random_crop_torch( 648 | gt_images: ndarray | Tensor | list[ndarray] | list[Tensor], 649 | lr_images: ndarray | Tensor | list[ndarray] | list[Tensor], 650 | gt_patch_size: int, 651 | upscale_factor: int, 652 | ) -> [ndarray, ndarray] or [Tensor, Tensor] or [list[ndarray], list[ndarray]] or [list[Tensor], list[Tensor]]: 653 | if not isinstance(gt_images, list): 654 | gt_images = [gt_images] 655 | if not isinstance(lr_images, list): 656 | lr_images = [lr_images] 657 | 658 | # Detect input image data type 659 | input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy" 660 | 661 | if input_type == "Tensor": 662 | lr_image_height, lr_image_width = lr_images[0].size()[-2:] 663 | else: 664 | lr_image_height, lr_image_width = lr_images[0].shape[0:2] 665 | 666 | # Compute low-resolution image patch size 667 | lr_patch_size = gt_patch_size // upscale_factor 668 | 669 | # Just need to find the top and left coordinates of the image 670 | lr_top = random.randint(0, lr_image_height - lr_patch_size) 671 | lr_left = random.randint(0, lr_image_width - lr_patch_size) 672 | 673 | # Crop lr image patch 674 | if input_type == "Tensor": 675 | lr_images = [lr_image[ 676 | :, 677 | :, 678 | lr_top:lr_top + lr_patch_size, 679 | lr_left:lr_left + lr_patch_size] for lr_image in lr_images] 680 | else: 681 | lr_images = [lr_image[ 682 | lr_top:lr_top + lr_patch_size, 683 | lr_left:lr_left + lr_patch_size, 684 | ...] for lr_image in lr_images] 685 | 686 | # Crop gt image patch 687 | gt_top, gt_left = int(lr_top * upscale_factor), int(lr_left * upscale_factor) 688 | 689 | if input_type == "Tensor": 690 | gt_images = [v[ 691 | :, 692 | :, 693 | gt_top:gt_top + gt_patch_size, 694 | gt_left:gt_left + gt_patch_size] for v in gt_images] 695 | else: 696 | gt_images = [v[ 697 | gt_top:gt_top + gt_patch_size, 698 | gt_left:gt_left + gt_patch_size, 699 | ...] 
for v in gt_images] 700 | 701 | # When image number is 1 702 | if len(gt_images) == 1: 703 | gt_images = gt_images[0] 704 | if len(lr_images) == 1: 705 | lr_images = lr_images[0] 706 | 707 | return gt_images, lr_images 708 | 709 | 710 | def random_rotate_torch( 711 | gt_images: ndarray | Tensor | list[ndarray] | list[Tensor], 712 | lr_images: ndarray | Tensor | list[ndarray] | list[Tensor], 713 | upscale_factor: int, 714 | angles: list, 715 | gt_center: tuple = None, 716 | lr_center: tuple = None, 717 | rotate_scale_factor: float = 1.0 718 | ) -> [ndarray, ndarray] or [Tensor, Tensor] or [list[ndarray], list[ndarray]] or [list[Tensor], list[Tensor]]: 719 | # Random select specific angle 720 | angle = random.choice(angles) 721 | 722 | if not isinstance(gt_images, list): 723 | gt_images = [gt_images] 724 | if not isinstance(lr_images, list): 725 | lr_images = [lr_images] 726 | 727 | # Detect input image data type 728 | input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy" 729 | 730 | if input_type == "Tensor": 731 | lr_image_height, lr_image_width = lr_images[0].size()[-2:] 732 | else: 733 | lr_image_height, lr_image_width = lr_images[0].shape[0:2] 734 | 735 | # Rotate LR image 736 | if lr_center is None: 737 | lr_center = [lr_image_width // 2, lr_image_height // 2] 738 | 739 | lr_matrix = cv2.getRotationMatrix2D(lr_center, angle, rotate_scale_factor) 740 | 741 | if input_type == "Tensor": 742 | lr_images = [F_vision.rotate(lr_image, angle, center=lr_center) for lr_image in lr_images] 743 | else: 744 | lr_images = [cv2.warpAffine(lr_image, lr_matrix, (lr_image_width, lr_image_height)) for lr_image in lr_images] 745 | 746 | # Rotate GT image 747 | gt_image_width = int(lr_image_width * upscale_factor) 748 | gt_image_height = int(lr_image_height * upscale_factor) 749 | 750 | if gt_center is None: 751 | gt_center = [gt_image_width // 2, gt_image_height // 2] 752 | 753 | gt_matrix = cv2.getRotationMatrix2D(gt_center, angle, rotate_scale_factor) 754 | 755 | if input_type == "Tensor": 756 | gt_images = [F_vision.rotate(gt_image, angle, center=gt_center) for gt_image in gt_images] 757 | else: 758 | gt_images = [cv2.warpAffine(gt_image, gt_matrix, (gt_image_width, gt_image_height)) for gt_image in gt_images] 759 | 760 | # When image number is 1 761 | if len(gt_images) == 1: 762 | gt_images = gt_images[0] 763 | if len(lr_images) == 1: 764 | lr_images = lr_images[0] 765 | 766 | return gt_images, lr_images 767 | 768 | 769 | def random_horizontally_flip_torch( 770 | gt_images: ndarray | Tensor | list[ndarray] | list[Tensor], 771 | lr_images: ndarray | Tensor | list[ndarray] | list[Tensor], 772 | p: float = 0.5 773 | ) -> [ndarray, ndarray] or [Tensor, Tensor] or [list[ndarray], list[ndarray]] or [list[Tensor], list[Tensor]]: 774 | # Get horizontal flip probability 775 | flip_prob = random.random() 776 | 777 | if not isinstance(gt_images, list): 778 | gt_images = [gt_images] 779 | if not isinstance(lr_images, list): 780 | lr_images = [lr_images] 781 | 782 | # Detect input image data type 783 | input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy" 784 | 785 | if flip_prob > p: 786 | if input_type == "Tensor": 787 | lr_images = [F_vision.hflip(lr_image) for lr_image in lr_images] 788 | gt_images = [F_vision.hflip(gt_image) for gt_image in gt_images] 789 | else: 790 | lr_images = [cv2.flip(lr_image, 1) for lr_image in lr_images] 791 | gt_images = [cv2.flip(gt_image, 1) for gt_image in gt_images] 792 | 793 | # When image number is 1 794 | if len(gt_images) == 1: 795 | gt_images = 
gt_images[0] 796 | if len(lr_images) == 1: 797 | lr_images = lr_images[0] 798 | 799 | return gt_images, lr_images 800 | 801 | 802 | def random_vertically_flip_torch( 803 | gt_images: ndarray | Tensor | list[ndarray] | list[Tensor], 804 | lr_images: ndarray | Tensor | list[ndarray] | list[Tensor], 805 | p: float = 0.5 806 | ) -> [ndarray, ndarray] or [Tensor, Tensor] or [list[ndarray], list[ndarray]] or [list[Tensor], list[Tensor]]: 807 | # Get vertical flip probability 808 | flip_prob = random.random() 809 | 810 | if not isinstance(gt_images, list): 811 | gt_images = [gt_images] 812 | if not isinstance(lr_images, list): 813 | lr_images = [lr_images] 814 | 815 | # Detect input image data type 816 | input_type = "Tensor" if torch.is_tensor(lr_images[0]) else "Numpy" 817 | 818 | if flip_prob > p: 819 | if input_type == "Tensor": 820 | lr_images = [F_vision.vflip(lr_image) for lr_image in lr_images] 821 | gt_images = [F_vision.vflip(gt_image) for gt_image in gt_images] 822 | else: 823 | lr_images = [cv2.flip(lr_image, 0) for lr_image in lr_images] 824 | gt_images = [cv2.flip(gt_image, 0) for gt_image in gt_images] 825 | 826 | # When image number is 1 827 | if len(gt_images) == 1: 828 | gt_images = gt_images[0] 829 | if len(lr_images) == 1: 830 | lr_images = lr_images[0] 831 | 832 | return gt_images, lr_images 833 | -------------------------------------------------------------------------------- /image_quality_assessment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import collections.abc 15 | import math 16 | import typing 17 | import warnings 18 | from itertools import repeat 19 | from typing import Any 20 | 21 | import cv2 22 | import numpy as np 23 | import torch 24 | from numpy import ndarray 25 | from scipy.io import loadmat 26 | from scipy.ndimage.filters import convolve 27 | from scipy.special import gamma 28 | from torch import nn 29 | from torch.nn import functional as F 30 | 31 | from imgproc import image_resize, expand_y, bgr_to_ycbcr, rgb_to_ycbcr_torch 32 | 33 | __all__ = [ 34 | "psnr", "ssim", "niqe", 35 | "PSNR", "SSIM", "NIQE", 36 | ] 37 | 38 | _I = typing.Optional[int] 39 | _D = typing.Optional[torch.dtype] 40 | 41 | 42 | # The following is the implementation of IQA method in Python, using CPU as processing device 43 | def _check_image(raw_image: np.ndarray, dst_image: np.ndarray): 44 | """Check whether the size and type of the two images are the same 45 | 46 | Args: 47 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255] 48 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255] 49 | 50 | """ 51 | # check image scale 52 | assert raw_image.shape == dst_image.shape, \ 53 | f"Supplied images have different sizes {str(raw_image.shape)} and {str(dst_image.shape)}" 54 | 55 | # check image type 56 | if raw_image.dtype != dst_image.dtype: 57 | warnings.warn(f"Supplied images have different dtypes{str(raw_image.shape)} and {str(dst_image.shape)}") 58 | 59 | 60 | def psnr(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float: 61 | """Python implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 62 | 63 | Args: 64 | raw_image (np.ndarray): image data to be compared, BGR format, data range [0, 255] 65 | dst_image (np.ndarray): reference image data, BGR format, data range [0, 255] 66 | crop_border (int): crop border a few pixels 67 | only_test_y_channel (bool): Whether to test only the Y channel of the image. 68 | 69 | Returns: 70 | psnr_metrics (np.float64): PSNR metrics 71 | 72 | """ 73 | # Check if two images are similar in scale and type 74 | _check_image(raw_image, dst_image) 75 | 76 | # crop border pixels 77 | if crop_border > 0: 78 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...] 79 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...] 
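    # The value computed below is the standard peak signal-to-noise ratio,
    #   PSNR = 10 * log10(255^2 / MSE),
    # evaluated after the optional border crop and (optionally) on the Y channel only.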
80 | 81 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately 82 | if only_test_y_channel: 83 | raw_image = expand_y(raw_image) 84 | dst_image = expand_y(dst_image) 85 | 86 | # Convert data type to numpy.float64 bit 87 | raw_image = raw_image.astype(np.float64) 88 | dst_image = dst_image.astype(np.float64) 89 | 90 | psnr_metrics = 10 * np.log10((255.0 ** 2) / np.mean((raw_image - dst_image) ** 2) + 1e-8) 91 | 92 | return psnr_metrics 93 | 94 | 95 | def _ssim(raw_image: np.ndarray, dst_image: np.ndarray) -> float: 96 | """Python implements the SSIM (Structural Similarity) function, which only calculates single-channel data 97 | 98 | Args: 99 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 100 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255] 101 | 102 | Returns: 103 | ssim_metrics (float): SSIM metrics for single channel 104 | 105 | """ 106 | c1 = (0.01 * 255.0) ** 2 107 | c2 = (0.03 * 255.0) ** 2 108 | 109 | kernel = cv2.getGaussianKernel(11, 1.5) 110 | kernel_window = np.outer(kernel, kernel.transpose()) 111 | 112 | raw_mean = cv2.filter2D(raw_image, -1, kernel_window)[5:-5, 5:-5] 113 | dst_mean = cv2.filter2D(dst_image, -1, kernel_window)[5:-5, 5:-5] 114 | raw_mean_square = raw_mean ** 2 115 | dst_mean_square = dst_mean ** 2 116 | raw_dst_mean = raw_mean * dst_mean 117 | raw_variance = cv2.filter2D(raw_image ** 2, -1, kernel_window)[5:-5, 5:-5] - raw_mean_square 118 | dst_variance = cv2.filter2D(dst_image ** 2, -1, kernel_window)[5:-5, 5:-5] - dst_mean_square 119 | raw_dst_covariance = cv2.filter2D(raw_image * dst_image, -1, kernel_window)[5:-5, 5:-5] - raw_dst_mean 120 | 121 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2) 122 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2) 123 | 124 | ssim_metrics = ssim_molecular / ssim_denominator 125 | ssim_metrics = float(np.mean(ssim_metrics)) 126 | 127 | return ssim_metrics 128 | 129 | 130 | def ssim(raw_image: np.ndarray, dst_image: np.ndarray, crop_border: int, only_test_y_channel: bool) -> float: 131 | """Python implements the SSIM (Structural Similarity) function, which calculates single/multi-channel data 132 | 133 | Args: 134 | raw_image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 135 | dst_image (np.ndarray): reference image data, BGR format, data range is [0, 255] 136 | crop_border (int): crop border a few pixels 137 | only_test_y_channel (bool): Whether to test only the Y channel of the image 138 | 139 | Returns: 140 | ssim_metrics (float): SSIM metrics for single channel 141 | 142 | """ 143 | # Check if two images are similar in scale and type 144 | _check_image(raw_image, dst_image) 145 | 146 | # crop border pixels 147 | if crop_border > 0: 148 | raw_image = raw_image[crop_border:-crop_border, crop_border:-crop_border, ...] 149 | dst_image = dst_image[crop_border:-crop_border, crop_border:-crop_border, ...] 
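    # The per-channel SSIM computed by _ssim above follows Wang et al.: local statistics are
    # estimated with an 11x11 Gaussian window (sigma = 1.5) and stabilized with
    # C1 = (0.01 * 255)^2 and C2 = (0.03 * 255)^2; this wrapper averages the result over channels.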
150 | 151 | # If you only test the Y channel, you need to extract the Y channel data of the YCbCr channel data separately 152 | if only_test_y_channel: 153 | raw_image = expand_y(raw_image) 154 | dst_image = expand_y(dst_image) 155 | 156 | # Convert data type to numpy.float64 bit 157 | raw_image = raw_image.astype(np.float64) 158 | dst_image = dst_image.astype(np.float64) 159 | 160 | channels_ssim_metrics = [] 161 | for channel in range(raw_image.shape[2]): 162 | ssim_metrics = _ssim(raw_image[..., channel], dst_image[..., channel]) 163 | channels_ssim_metrics.append(ssim_metrics) 164 | ssim_metrics = np.mean(np.asarray(channels_ssim_metrics)) 165 | 166 | return float(ssim_metrics) 167 | 168 | 169 | def _estimate_aggd_parameters(vector: np.ndarray) -> [np.ndarray, float, float]: 170 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 171 | This function is used to estimate an asymmetric generalized Gaussian distribution 172 | 173 | Reference papers: 174 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 175 | 176 | Args: 177 | vector (np.ndarray): data vector 178 | 179 | Returns: 180 | aggd_parameters (np.ndarray): asymmetric generalized Gaussian distribution 181 | left_beta (float): symmetric left data vector variance mean product 182 | right_beta (float): symmetric right side data vector variance mean product 183 | 184 | """ 185 | # The following is obtained according to the formula and the method provided in the paper on WIki encyclopedia 186 | vector = vector.flatten() 187 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801 188 | gam_reciprocal = np.reciprocal(gam) 189 | r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) 190 | 191 | left_std = np.sqrt(np.mean(vector[vector < 0] ** 2)) 192 | right_std = np.sqrt(np.mean(vector[vector > 0] ** 2)) 193 | gamma_hat = left_std / right_std 194 | rhat = (np.mean(np.abs(vector))) ** 2 / np.mean(vector ** 2) 195 | rhat_norm = (rhat * (gamma_hat ** 3 + 1) * (gamma_hat + 1)) / ((gamma_hat ** 2 + 1) ** 2) 196 | array_position = np.argmin((r_gam - rhat_norm) ** 2) 197 | 198 | aggd_parameters = gam[array_position] 199 | left_beta = left_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters)) 200 | right_beta = right_std * np.sqrt(gamma(1 / aggd_parameters) / gamma(3 / aggd_parameters)) 201 | 202 | return aggd_parameters, left_beta, right_beta 203 | 204 | 205 | def _get_mscn_feature(image: np.ndarray) -> list[float | Any]: 206 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 207 | This function is used to calculate the MSCN feature map 208 | 209 | Reference papers: 210 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 211 | 212 | Args: 213 | image (np.ndarray): Grayscale image of MSCN feature to be calculated, BGR format, data range is [0, 255] 214 | 215 | Returns: 216 | mscn_feature (np.ndarray): MSCN feature map of the image 217 | 218 | """ 219 | mscn_feature = [] 220 | # Calculate the asymmetric generalized Gaussian distribution 221 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters(image) 222 | mscn_feature.extend([aggd_parameters, (left_beta + right_beta) / 2]) 223 | 224 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] 225 | for i in range(len(shifts)): 226 | shifted_block = np.roll(image, shifts[i], axis=(0, 1)) 227 | # Calculate the asymmetric generalized Gaussian distribution 228 | aggd_parameters, 
left_beta, right_beta = _estimate_aggd_parameters(image * shifted_block) 229 | mean = (right_beta - left_beta) * (gamma(2 / aggd_parameters) / gamma(1 / aggd_parameters)) 230 | mscn_feature.extend([aggd_parameters, mean, left_beta, right_beta]) 231 | 232 | return mscn_feature 233 | 234 | 235 | def _fit_mscn_ipac(image: np.ndarray, 236 | mu_pris_param: np.ndarray, 237 | cov_pris_param: np.ndarray, 238 | gaussian_window: np.ndarray, 239 | block_size_height: int, 240 | block_size_width: int) -> float: 241 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 242 | This function is used to fit the inner product of adjacent coefficients of MSCN 243 | 244 | Reference papers: 245 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 246 | 247 | Args: 248 | image (np.ndarray): The image data of the NIQE to be tested, in BGR format, the data range is [0, 255] 249 | mu_pris_param (np.ndarray): Mean of predefined multivariate Gaussians, model computed on original dataset. 250 | cov_pris_param (np.ndarray): Covariance of predefined multivariate Gaussian model computed on original dataset. 251 | gaussian_window (np.ndarray): 7x7 Gaussian window for smoothing the image 252 | block_size_height (int): the height of the block into which the image is divided 253 | block_size_width (int): The width of the block into which the image is divided 254 | 255 | Returns: 256 | niqe_metric (np.ndarray): NIQE score 257 | 258 | """ 259 | image_height, image_width = image.shape 260 | num_block_height = math.floor(image_height / block_size_height) 261 | num_block_width = math.floor(image_width / block_size_width) 262 | image = image[0:num_block_height * block_size_height, 0:num_block_width * block_size_width] 263 | 264 | features_parameters = [] 265 | for scale in (1, 2): 266 | mu = convolve(image, gaussian_window, mode="nearest") 267 | sigma = np.sqrt(np.abs(convolve(np.square(image), gaussian_window, mode="nearest") - np.square(mu))) 268 | image_norm = (image - mu) / (sigma + 1) 269 | 270 | feature = [] 271 | for idx_w in range(num_block_width): 272 | for idx_h in range(num_block_height): 273 | vector = image_norm[ 274 | idx_h * block_size_height // scale:(idx_h + 1) * block_size_height // scale, 275 | idx_w * block_size_width // scale:(idx_w + 1) * block_size_width // scale] 276 | feature.append(_get_mscn_feature(vector)) 277 | 278 | features_parameters.append(np.array(feature)) 279 | 280 | if scale == 1: 281 | image = image_resize(image / 255., scale_factor=0.5, antialiasing=True) 282 | image = image * 255. 
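    # The remaining steps score the image as a Mahalanobis-like distance between the
    # pristine multivariate Gaussian model (mu_pris_param, cov_pris_param) and the model
    # fitted below to the distorted-patch features:
    #     d = sqrt((mu_p - mu_d)^T * ((cov_p + cov_d) / 2)^(-1) * (mu_p - mu_d))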
283 | 284 | features_parameters = np.concatenate(features_parameters, axis=1) 285 | 286 | # Fitting a multivariate Gaussian kernel model to distorted patch features 287 | mu_distparam = np.nanmean(features_parameters, axis=0) 288 | distparam_no_nan = features_parameters[~np.isnan(features_parameters).any(axis=1)] 289 | cov_distparam = np.cov(distparam_no_nan, rowvar=False) 290 | 291 | invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) 292 | niqe_metric = np.matmul(np.matmul((mu_pris_param - mu_distparam), invcov_param), 293 | np.transpose((mu_pris_param - mu_distparam))) 294 | 295 | niqe_metric = np.sqrt(niqe_metric) 296 | niqe_metric = float(np.squeeze(niqe_metric)) 297 | 298 | return niqe_metric 299 | 300 | 301 | def niqe(image: np.ndarray, 302 | crop_border: int, 303 | niqe_model_path: str, 304 | block_size_height: int = 96, 305 | block_size_width: int = 96) -> float: 306 | """Python implements the NIQE (Natural Image Quality Evaluator) function, 307 | This function computes single/multi-channel data 308 | 309 | Args: 310 | image (np.ndarray): The image data to be compared, in BGR format, the data range is [0, 255] 311 | crop_border (int): crop border a few pixels 312 | niqe_model_path: NIQE estimator model address 313 | block_size_height (int): The height of the block the image is divided into. Default: 96 314 | block_size_width (int): The width of the block the image is divided into. Default: 96 315 | 316 | Returns: 317 | niqe_metrics (float): NIQE indicator under single channel 318 | 319 | """ 320 | # crop border pixels 321 | if crop_border > 0: 322 | image = image[crop_border:-crop_border, crop_border:-crop_border, ...] 323 | 324 | # Defining the NIQE Feature Extraction Model 325 | niqe_model = np.load(niqe_model_path) 326 | 327 | mu_pris_param = niqe_model["mu_pris_param"] 328 | cov_pris_param = niqe_model["cov_pris_param"] 329 | gaussian_window = niqe_model["gaussian_window"] 330 | 331 | # NIQE only tests on Y channel images and needs to convert the images 332 | y_image = bgr_to_ycbcr(image, only_use_y_channel=True) 333 | 334 | # Convert data type to numpy.float64 bit 335 | y_image = y_image.astype(np.float64) 336 | 337 | niqe_metric = _fit_mscn_ipac(y_image, 338 | mu_pris_param, 339 | cov_pris_param, 340 | gaussian_window, 341 | block_size_height, 342 | block_size_width) 343 | 344 | return niqe_metric 345 | 346 | 347 | # The following is the IQA method implemented by PyTorch, using CUDA as the processing device 348 | def _check_tensor_shape(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): 349 | """Check if the dimensions of the two tensors are the same 350 | 351 | Args: 352 | raw_tensor (np.ndarray or torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 353 | dst_tensor (np.ndarray or torch.Tensor): reference image tensorflow, RGB format, data range [0, 1] 354 | 355 | """ 356 | # Check if tensor scales are consistent 357 | assert raw_tensor.shape == dst_tensor.shape, \ 358 | f"Supplied images have different sizes {str(raw_tensor.shape)} and {str(dst_tensor.shape)}" 359 | 360 | 361 | def _psnr_torch(raw_tensor: torch.Tensor, dst_tensor: torch.Tensor, crop_border: int, 362 | only_test_y_channel: bool) -> float: 363 | """PyTorch implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 364 | 365 | Args: 366 | raw_tensor (torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 367 | dst_tensor (torch.Tensor): reference image tensorflow, RGB format, data range [0, 1] 368 | crop_border (int): 
crop border a few pixels 369 | only_test_y_channel (bool): Whether to test only the Y channel of the image 370 | 371 | Returns: 372 | psnr_metrics (torch.Tensor): PSNR metrics 373 | 374 | """ 375 | # Check if two tensor scales are similar 376 | _check_tensor_shape(raw_tensor, dst_tensor) 377 | 378 | # crop border pixels 379 | if crop_border > 0: 380 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 381 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 382 | 383 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data 384 | if only_test_y_channel: 385 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True) 386 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True) 387 | 388 | # Convert data type to torch.float64 bit 389 | raw_tensor = raw_tensor.to(torch.float64) 390 | dst_tensor = dst_tensor.to(torch.float64) 391 | 392 | mse_value = torch.mean((raw_tensor * 255.0 - dst_tensor * 255.0) ** 2 + 1e-8, dim=[1, 2, 3]) 393 | psnr_metrics = 10 * torch.log10_(255.0 ** 2 / mse_value) 394 | 395 | return psnr_metrics 396 | 397 | 398 | class PSNR(nn.Module): 399 | """PyTorch implements PSNR (Peak Signal-to-Noise Ratio, peak signal-to-noise ratio) function 400 | 401 | Attributes: 402 | crop_border (int): crop border a few pixels 403 | only_test_y_channel (bool): Whether to test only the Y channel of the image 404 | 405 | Returns: 406 | psnr_metrics (torch.Tensor): PSNR metrics 407 | 408 | """ 409 | 410 | def __init__(self, crop_border: int, only_test_y_channel: bool) -> None: 411 | super().__init__() 412 | self.crop_border = crop_border 413 | self.only_test_y_channel = only_test_y_channel 414 | 415 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float: 416 | psnr_metrics = _psnr_torch(raw_tensor, dst_tensor, self.crop_border, self.only_test_y_channel) 417 | 418 | return psnr_metrics 419 | 420 | 421 | def _ssim_torch(raw_tensor: torch.Tensor, 422 | dst_tensor: torch.Tensor, 423 | window_size: int, 424 | gaussian_kernel_window: np.ndarray) -> float: 425 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 426 | 427 | Args: 428 | raw_tensor (torch.Tensor): image tensor flow to be compared, RGB format, data range [0, 255] 429 | dst_tensor (torch.Tensor): reference image tensorflow, RGB format, data range [0, 255] 430 | window_size (int): Gaussian filter size 431 | gaussian_kernel_window (np.ndarray): Gaussian filter 432 | 433 | Returns: 434 | ssim_metrics (torch.Tensor): SSIM metrics 435 | 436 | """ 437 | c1 = (0.01 * 255.0) ** 2 438 | c2 = (0.03 * 255.0) ** 2 439 | 440 | gaussian_kernel_window = torch.from_numpy(gaussian_kernel_window).view(1, 1, window_size, window_size) 441 | gaussian_kernel_window = gaussian_kernel_window.expand(raw_tensor.size(1), 1, window_size, window_size) 442 | gaussian_kernel_window = gaussian_kernel_window.to(device=raw_tensor.device, dtype=raw_tensor.dtype) 443 | 444 | raw_mean = F.conv2d(raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=raw_tensor.shape[1]) 445 | dst_mean = F.conv2d(dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), groups=dst_tensor.shape[1]) 446 | raw_mean_square = raw_mean ** 2 447 | dst_mean_square = dst_mean ** 2 448 | raw_dst_mean = raw_mean * dst_mean 449 | raw_variance = F.conv2d(raw_tensor * raw_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), 450 | groups=raw_tensor.shape[1]) - raw_mean_square 451 | 
dst_variance = F.conv2d(dst_tensor * dst_tensor, gaussian_kernel_window, stride=(1, 1), padding=(0, 0), 452 | groups=raw_tensor.shape[1]) - dst_mean_square 453 | raw_dst_covariance = F.conv2d(raw_tensor * dst_tensor, gaussian_kernel_window, stride=1, padding=(0, 0), 454 | groups=raw_tensor.shape[1]) - raw_dst_mean 455 | 456 | ssim_molecular = (2 * raw_dst_mean + c1) * (2 * raw_dst_covariance + c2) 457 | ssim_denominator = (raw_mean_square + dst_mean_square + c1) * (raw_variance + dst_variance + c2) 458 | 459 | ssim_metrics = ssim_molecular / ssim_denominator 460 | ssim_metrics = torch.mean(ssim_metrics, [1, 2, 3]).float() 461 | 462 | return ssim_metrics 463 | 464 | 465 | def _ssim_single_torch(raw_tensor: torch.Tensor, 466 | dst_tensor: torch.Tensor, 467 | crop_border: int, 468 | only_test_y_channel: bool, 469 | window_size: int, 470 | gaussian_kernel_window: ndarray) -> float: 471 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 472 | 473 | Args: 474 | raw_tensor (Tensor): image tensor flow to be compared, RGB format, data range [0, 1] 475 | dst_tensor (Tensor): reference image tensorflow, RGB format, data range [0, 1] 476 | crop_border (int): crop border a few pixels 477 | only_test_y_channel (bool): Whether to test only the Y channel of the image 478 | window_size (int): Gaussian filter size 479 | gaussian_kernel_window (ndarray): Gaussian filter 480 | 481 | Returns: 482 | ssim_metrics (torch.Tensor): SSIM metrics 483 | 484 | """ 485 | # Check if two tensor scales are similar 486 | _check_tensor_shape(raw_tensor, dst_tensor) 487 | 488 | # crop border pixels 489 | if crop_border > 0: 490 | raw_tensor = raw_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 491 | dst_tensor = dst_tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 492 | 493 | # Convert RGB tensor data to YCbCr tensor, and extract only Y channel data 494 | if only_test_y_channel: 495 | raw_tensor = rgb_to_ycbcr_torch(raw_tensor, only_use_y_channel=True) 496 | dst_tensor = rgb_to_ycbcr_torch(dst_tensor, only_use_y_channel=True) 497 | 498 | # Convert data type to torch.float64 bit 499 | raw_tensor = raw_tensor.to(torch.float64) 500 | dst_tensor = dst_tensor.to(torch.float64) 501 | 502 | ssim_metrics = _ssim_torch(raw_tensor * 255.0, dst_tensor * 255.0, window_size, gaussian_kernel_window) 503 | 504 | return ssim_metrics 505 | 506 | 507 | class SSIM(nn.Module): 508 | """PyTorch implements the SSIM (Structural Similarity) function, which only calculates single-channel data 509 | 510 | Args: 511 | crop_border (int): crop border a few pixels 512 | only_only_test_y_channel (bool): Whether to test only the Y channel of the image 513 | window_size (int): Gaussian filter size 514 | gaussian_sigma (float): sigma parameter in Gaussian filter 515 | 516 | Returns: 517 | ssim_metrics (torch.Tensor): SSIM metrics 518 | 519 | """ 520 | 521 | def __init__(self, crop_border: int, 522 | only_only_test_y_channel: bool, 523 | window_size: int = 11, 524 | gaussian_sigma: float = 1.5) -> None: 525 | super().__init__() 526 | self.crop_border = crop_border 527 | self.only_test_y_channel = only_only_test_y_channel 528 | self.window_size = window_size 529 | 530 | gaussian_kernel = cv2.getGaussianKernel(window_size, gaussian_sigma) 531 | self.gaussian_kernel_window = np.outer(gaussian_kernel, gaussian_kernel.transpose()) 532 | 533 | def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor) -> float: 534 | ssim_metrics = _ssim_single_torch(raw_tensor, 535 
| dst_tensor, 536 | self.crop_border, 537 | self.only_test_y_channel, 538 | self.window_size, 539 | self.gaussian_kernel_window) 540 | 541 | return ssim_metrics 542 | 543 | 544 | def _fspecial_gaussian_torch(window_size: int, sigma: float, channels: int): 545 | """PyTorch implements the fspecial_gaussian() function in MATLAB 546 | 547 | Args: 548 | window_size (int): Gaussian filter size 549 | sigma (float): sigma parameter in Gaussian filter 550 | channels (int): number of input image channels 551 | 552 | Returns: 553 | gaussian_kernel_window (torch.Tensor): Gaussian filter in Tensor format 554 | 555 | """ 556 | if type(window_size) is int: 557 | shape = (window_size, window_size) 558 | else: 559 | shape = window_size 560 | m, n = [(ss - 1.) / 2. for ss in shape] 561 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 562 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 563 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 564 | sumh = h.sum() 565 | 566 | if sumh != 0: 567 | h /= sumh 568 | 569 | gaussian_kernel_window = torch.from_numpy(h).float().repeat(channels, 1, 1, 1) 570 | 571 | return gaussian_kernel_window 572 | 573 | 574 | def _to_tuple(n): 575 | def parse(x): 576 | if isinstance(x, collections.abc.Iterable): 577 | return x 578 | return tuple(repeat(x, n)) 579 | 580 | return parse 581 | 582 | 583 | def _excact_padding_2d(tensor: torch.Tensor, 584 | kernel: torch.Tensor | tuple, 585 | stride: int = 1, 586 | dilation: int = 1, 587 | mode: str = "same") -> torch.Tensor: 588 | assert len(tensor.shape) == 4, f"Only support 4D tensor input, but got {tensor.shape}" 589 | kernel = _to_tuple(2)(kernel) 590 | stride = _to_tuple(2)(stride) 591 | dilation = _to_tuple(2)(dilation) 592 | b, c, h, w = tensor.shape 593 | h2 = math.ceil(h / stride[0]) 594 | w2 = math.ceil(w / stride[1]) 595 | pad_row = (h2 - 1) * stride[0] + (kernel[0] - 1) * dilation[0] + 1 - h 596 | pad_col = (w2 - 1) * stride[1] + (kernel[1] - 1) * dilation[1] + 1 - w 597 | pad_l, pad_r, pad_t, pad_b = (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2) 598 | 599 | mode = mode if mode != "same" else "constant" 600 | if mode != "symmetric": 601 | tensor = F.pad(tensor, (pad_l, pad_r, pad_t, pad_b), mode=mode) 602 | elif mode == "symmetric": 603 | sym_h = torch.flip(tensor, [2]) 604 | sym_w = torch.flip(tensor, [3]) 605 | sym_hw = torch.flip(tensor, [2, 3]) 606 | 607 | row1 = torch.cat((sym_hw, sym_h, sym_hw), dim=3) 608 | row2 = torch.cat((sym_w, tensor, sym_w), dim=3) 609 | row3 = torch.cat((sym_hw, sym_h, sym_hw), dim=3) 610 | 611 | whole_map = torch.cat((row1, row2, row3), dim=2) 612 | 613 | tensor = whole_map[:, :, h - pad_t:2 * h + pad_b, w - pad_l:2 * w + pad_r, ] 614 | 615 | return tensor 616 | 617 | 618 | class ExactPadding2d(nn.Module): 619 | r"""This function calculate exact padding values for 4D tensor inputs, 620 | and support the same padding mode as tensorflow. 621 | 622 | Args: 623 | kernel (int or tuple): kernel size. 624 | stride (int or tuple): stride size. 625 | dilation (int or tuple): dilation size, default with 1. 
626 | mode (srt): padding mode can be ('same', 'symmetric', 'replicate', 'circular') 627 | 628 | """ 629 | 630 | def __init__(self, kernel, stride=1, dilation=1, mode="same") -> None: 631 | super().__init__() 632 | self.kernel = _to_tuple(2)(kernel) 633 | self.stride = _to_tuple(2)(stride) 634 | self.dilation = _to_tuple(2)(dilation) 635 | self.mode = mode 636 | 637 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 638 | return _excact_padding_2d(tensor, self.kernel, self.stride, self.dilation, self.mode) 639 | 640 | 641 | def _image_filter(tensor: torch.Tensor, 642 | weight: torch.Tensor, 643 | bias=None, 644 | stride: int = 1, 645 | padding: str = "same", 646 | dilation: int = 1, 647 | groups: int = 1): 648 | """PyTorch implements the imfilter() function in MATLAB 649 | 650 | Args: 651 | tensor (torch.Tensor): Tensor image data 652 | weight (torch.Tensor): filter weight 653 | padding (str): how to pad pixels. Default: ``same`` 654 | dilation (int): convolution dilation scale 655 | groups (int): number of grouped convolutions 656 | 657 | """ 658 | kernel_size = weight.shape[2:] 659 | exact_padding_2d = ExactPadding2d(kernel_size, stride, dilation, mode=padding) 660 | 661 | return F.conv2d(exact_padding_2d(tensor), weight, bias, stride, dilation=dilation, groups=groups) 662 | 663 | 664 | def _reshape_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]: 665 | if tensor.dim() == 4: 666 | b, c, h, w = tensor.size() 667 | elif tensor.dim() == 3: 668 | c, h, w = tensor.size() 669 | b = None 670 | elif tensor.dim() == 2: 671 | h, w = tensor.size() 672 | b = c = None 673 | else: 674 | raise ValueError('{}-dim Tensor is not supported!'.format(tensor.dim())) 675 | 676 | tensor = tensor.view(-1, 1, h, w) 677 | return tensor, b, c, h, w 678 | 679 | 680 | def _reshape_output_torch(tensor: torch.Tensor, b: _I, c: _I) -> torch.Tensor: 681 | rh = tensor.size(-2) 682 | rw = tensor.size(-1) 683 | # Back to the original dimension 684 | if b is not None: 685 | tensor = tensor.view(b, c, rh, rw) # 4-dim 686 | else: 687 | if c is not None: 688 | tensor = tensor.view(c, rh, rw) # 3-dim 689 | else: 690 | tensor = tensor.view(rh, rw) # 2-dim 691 | 692 | return tensor 693 | 694 | 695 | def _cast_input_torch(tensor: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: 696 | if tensor.dtype != torch.float32 or tensor.dtype != torch.float64: 697 | dtype = tensor.dtype 698 | tensor = tensor.float() 699 | else: 700 | dtype = None 701 | 702 | return tensor, dtype 703 | 704 | 705 | def _cast_output_torch(tensor: torch.Tensor, dtype: _D) -> torch.Tensor: 706 | if dtype is not None: 707 | if not dtype.is_floating_point: 708 | tensor = tensor.round() 709 | # To prevent over/underflow when converting types 710 | if dtype is torch.uint8: 711 | tensor = tensor.clamp(0, 255) 712 | 713 | tensor = tensor.to(dtype=dtype) 714 | 715 | return tensor 716 | 717 | 718 | def _cubic_contribution_torch(tensor: torch.Tensor, a: float = -0.5) -> torch.Tensor: 719 | ax = tensor.abs() 720 | ax2 = ax * ax 721 | ax3 = ax * ax2 722 | 723 | range_01 = ax.le(1) 724 | range_12 = torch.logical_and(ax.gt(1), ax.le(2)) 725 | 726 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 727 | cont_01 = cont_01 * range_01.to(dtype=tensor.dtype) 728 | 729 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 730 | cont_12 = cont_12 * range_12.to(dtype=tensor.dtype) 731 | 732 | cont = cont_01 + cont_12 733 | return cont 734 | 735 | 736 | def _gaussian_contribution_torch(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: 
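    # Unnormalized Gaussian weight exp(-x^2 / (2 * sigma^2)); contributions outside
    # |x| <= 3 * sigma + 1 are zeroed, and _get_weight_torch later rescales the taps to sum to 1.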
737 | range_3sigma = (x.abs() <= 3 * sigma + 1) 738 | # Normalization will be done after 739 | cont = torch.exp(-x.pow(2) / (2 * sigma ** 2)) 740 | cont = cont * range_3sigma.to(dtype=x.dtype) 741 | return cont 742 | 743 | 744 | def _reflect_padding_torch(tensor: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor: 745 | """ 746 | Apply reflect padding to the given Tensor. 747 | Note that it is slightly different from the PyTorch functional.pad, 748 | where boundary elements are used only once. 749 | Instead, we follow the MATLAB implementation 750 | which uses boundary elements twice. 751 | For example, 752 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 753 | while our implementation yields [a, a, b, c, d, d]. 754 | """ 755 | b, c, h, w = tensor.size() 756 | if dim == 2 or dim == -2: 757 | padding_buffer = tensor.new_zeros(b, c, h + pad_pre + pad_post, w) 758 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(tensor) 759 | for p in range(pad_pre): 760 | padding_buffer[..., pad_pre - p - 1, :].copy_(tensor[..., p, :]) 761 | for p in range(pad_post): 762 | padding_buffer[..., h + pad_pre + p, :].copy_(tensor[..., -(p + 1), :]) 763 | else: 764 | padding_buffer = tensor.new_zeros(b, c, h, w + pad_pre + pad_post) 765 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(tensor) 766 | for p in range(pad_pre): 767 | padding_buffer[..., pad_pre - p - 1].copy_(tensor[..., p]) 768 | for p in range(pad_post): 769 | padding_buffer[..., w + pad_pre + p].copy_(tensor[..., -(p + 1)]) 770 | 771 | return padding_buffer 772 | 773 | 774 | def _padding_torch(tensor: torch.Tensor, 775 | dim: int, 776 | pad_pre: int, 777 | pad_post: int, 778 | padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: 779 | if padding_type is None: 780 | return tensor 781 | elif padding_type == 'reflect': 782 | x_pad = _reflect_padding_torch(tensor, dim, pad_pre, pad_post) 783 | else: 784 | raise ValueError('{} padding is not supported!'.format(padding_type)) 785 | 786 | return x_pad 787 | 788 | 789 | def _get_padding_torch(tensor: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 790 | tensor = tensor.long() 791 | r_min = tensor.min() 792 | r_max = tensor.max() + kernel_size - 1 793 | 794 | if r_min <= 0: 795 | pad_pre = -r_min 796 | pad_pre = pad_pre.item() 797 | tensor += pad_pre 798 | else: 799 | pad_pre = 0 800 | 801 | if r_max >= x_size: 802 | pad_post = r_max - x_size + 1 803 | pad_post = pad_post.item() 804 | else: 805 | pad_post = 0 806 | 807 | return pad_pre, pad_post, tensor 808 | 809 | 810 | def _get_weight_torch(tensor: torch.Tensor, 811 | kernel_size: int, 812 | kernel: str = "cubic", 813 | sigma: float = 2.0, 814 | antialiasing_factor: float = 1) -> torch.Tensor: 815 | buffer_pos = tensor.new_zeros(kernel_size, len(tensor)) 816 | for idx, buffer_sub in enumerate(buffer_pos): 817 | buffer_sub.copy_(tensor - idx) 818 | 819 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 
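    # When downsampling with antialiasing, antialiasing_factor == scale < 1, so the tap
    # positions are compressed and the (already enlarged) kernel acts as a low-pass filter;
    # otherwise antialiasing_factor == 1 and the kernel support is left unchanged.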
820 |     buffer_pos *= antialiasing_factor
821 |     if kernel == 'cubic':
822 |         weight = _cubic_contribution_torch(buffer_pos)
823 |     elif kernel == 'gaussian':
824 |         weight = _gaussian_contribution_torch(buffer_pos, sigma=sigma)
825 |     else:
826 |         raise ValueError('{} kernel is not supported!'.format(kernel))
827 | 
828 |     weight /= weight.sum(dim=0, keepdim=True)
829 |     return weight
830 | 
831 | 
832 | def _reshape_tensor_torch(tensor: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
833 |     # Resize height
834 |     if dim == 2 or dim == -2:
835 |         k = (kernel_size, 1)
836 |         h_out = tensor.size(-2) - kernel_size + 1
837 |         w_out = tensor.size(-1)
838 |     # Resize width
839 |     else:
840 |         k = (1, kernel_size)
841 |         h_out = tensor.size(-2)
842 |         w_out = tensor.size(-1) - kernel_size + 1
843 | 
844 |     unfold = F.unfold(tensor, k)
845 |     unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
846 |     return unfold
847 | 
848 | 
849 | def _resize_1d_torch(tensor: torch.Tensor,
850 |                      dim: int,
851 |                      size: int,
852 |                      scale: float,
853 |                      kernel: str = 'cubic',
854 |                      sigma: float = 2.0,
855 |                      padding_type: str = 'reflect',
856 |                      antialiasing: bool = True) -> torch.Tensor:
857 |     """
858 |     Args:
859 |         tensor (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
860 |         dim (int): The spatial dimension to resize (-2 for height, -1 for width).
861 |         scale (float): Resampling scale factor along ``dim``.
862 |         size (int): Output length along ``dim``.
863 |     Return:
864 |     """
865 |     # Identity case
866 |     if scale == 1:
867 |         return tensor
868 | 
869 |     # Default bicubic kernel with antialiasing (only when downsampling)
870 |     if kernel == 'cubic':
871 |         kernel_size = 4
872 |     else:
873 |         kernel_size = math.floor(6 * sigma)
874 | 
875 |     if antialiasing and (scale < 1):
876 |         antialiasing_factor = scale
877 |         kernel_size = math.ceil(kernel_size / antialiasing_factor)
878 |     else:
879 |         antialiasing_factor = 1
880 | 
881 |     # We allow margin to both sides
882 |     kernel_size += 2
883 | 
884 |     # Weights only depend on the shape of input and output,
885 |     # so we do not calculate gradients here.
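    # The sampling geometry follows the MATLAB-style imresize mapping: each output index i
    # maps to input coordinate pos = (i + 0.5) / scale - 0.5, `base` is the index of the
    # first kernel tap, and `dist = pos - base` is the offset at which the taps are evaluated.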
886 | with torch.no_grad(): 887 | pos = torch.linspace( 888 | 0, 889 | size - 1, 890 | steps=size, 891 | dtype=tensor.dtype, 892 | device=tensor.device, 893 | ) 894 | pos = (pos + 0.5) / scale - 0.5 895 | base = pos.floor() - (kernel_size // 2) + 1 896 | dist = pos - base 897 | weight = _get_weight_torch( 898 | dist, 899 | kernel_size, 900 | kernel=kernel, 901 | sigma=sigma, 902 | antialiasing_factor=antialiasing_factor, 903 | ) 904 | pad_pre, pad_post, base = _get_padding_torch(base, kernel_size, tensor.size(dim)) 905 | 906 | # To back-propagate through x 907 | x_pad = _padding_torch(tensor, dim, pad_pre, pad_post, padding_type=padding_type) 908 | unfold = _reshape_tensor_torch(x_pad, dim, kernel_size) 909 | # Subsampling first 910 | if dim == 2 or dim == -2: 911 | sample = unfold[..., base, :] 912 | weight = weight.view(1, kernel_size, sample.size(2), 1) 913 | else: 914 | sample = unfold[..., base] 915 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 916 | 917 | # Apply the kernel 918 | tensor = sample * weight 919 | tensor = tensor.sum(dim=1, keepdim=True) 920 | return tensor 921 | 922 | 923 | def _downsampling_2d_torch(tensor: torch.Tensor, k: torch.Tensor, scale: int, 924 | padding_type: str = 'reflect') -> torch.Tensor: 925 | c = tensor.size(1) 926 | k_h = k.size(-2) 927 | k_w = k.size(-1) 928 | 929 | k = k.to(dtype=tensor.dtype, device=tensor.device) 930 | k = k.view(1, 1, k_h, k_w) 931 | k = k.repeat(c, c, 1, 1) 932 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 933 | e = e.view(c, c, 1, 1) 934 | k = k * e 935 | 936 | pad_h = (k_h - scale) // 2 937 | pad_w = (k_w - scale) // 2 938 | tensor = _padding_torch(tensor, -2, pad_h, pad_h, padding_type=padding_type) 939 | tensor = _padding_torch(tensor, -1, pad_w, pad_w, padding_type=padding_type) 940 | y = F.conv2d(tensor, k, padding=0, stride=scale) 941 | return y 942 | 943 | 944 | def _cov_torch(tensor, rowvar=True, bias=False): 945 | r"""Estimate a covariance matrix (np.cov) 946 | Ref: https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110 947 | """ 948 | tensor = tensor if rowvar else tensor.transpose(-1, -2) 949 | tensor = tensor - tensor.mean(dim=-1, keepdim=True) 950 | if tensor.shape[-1] - int(not bool(bias)) == 0: 951 | factor = 1 952 | else: 953 | factor = 1 / (tensor.shape[-1] - int(not bool(bias))) 954 | return factor * tensor @ tensor.transpose(-1, -2) 955 | 956 | 957 | def _nancov_torch(x): 958 | r"""Calculate nancov for batched tensor, rows that contains nan value 959 | will be removed. 960 | Args: 961 | x (tensor): (B, row_num, feat_dim) 962 | Return: 963 | cov (tensor): (B, feat_dim, feat_dim) 964 | """ 965 | assert len(x.shape) == 3, f'Shape of input should be (batch_size, row_num, feat_dim), but got {x.shape}' 966 | b, rownum, feat_dim = x.shape 967 | nan_mask = torch.isnan(x).any(dim=2, keepdim=True) 968 | cov_x = [] 969 | for i in range(b): 970 | x_no_nan = x[i].masked_select(~nan_mask[i]).reshape(-1, feat_dim) 971 | 972 | cov_x.append(_cov_torch(x_no_nan, rowvar=False)) 973 | return torch.stack(cov_x) 974 | 975 | 976 | def _nanmean_torch(v, *args, inplace=False, **kwargs): 977 | r"""nanmean same as matlab function: calculate mean values by removing all nan. 978 | """ 979 | if not inplace: 980 | v = v.clone() 981 | is_nan = torch.isnan(v) 982 | v[is_nan] = 0 983 | return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs) 984 | 985 | 986 | def _symm_pad_torch(im: torch.Tensor, padding: [int, int, int, int]): 987 | """Symmetric padding same as tensorflow. 
988 | Ref: https://discuss.pytorch.org/t/symmetric-padding/19866/3 989 | """ 990 | h, w = im.shape[-2:] 991 | left, right, top, bottom = padding 992 | 993 | x_idx = np.arange(-left, w + right) 994 | y_idx = np.arange(-top, h + bottom) 995 | 996 | def reflect(x, minx, maxx): 997 | """ Reflects an array around two points making a triangular waveform that ramps up 998 | and down, allowing for pad lengths greater than the input length """ 999 | rng = maxx - minx 1000 | double_rng = 2 * rng 1001 | mod = np.fmod(x - minx, double_rng) 1002 | normed_mod = np.where(mod < 0, mod + double_rng, mod) 1003 | out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx 1004 | return np.array(out, dtype=x.dtype) 1005 | 1006 | x_pad = reflect(x_idx, -0.5, w - 0.5) 1007 | y_pad = reflect(y_idx, -0.5, h - 0.5) 1008 | xx, yy = np.meshgrid(x_pad, y_pad) 1009 | return im[..., yy, xx] 1010 | 1011 | 1012 | def _blockproc_torch(x, kernel, fun, border_size=None, pad_partial=False, pad_method='zero'): 1013 | r"""blockproc function like matlab 1014 | 1015 | Difference: 1016 | - Partial blocks is discarded (if exist) for fast GPU process. 1017 | 1018 | Args: 1019 | x (tensor): shape (b, c, h, w) 1020 | kernel (int or tuple): block size 1021 | func (function): function to process each block 1022 | border_size (int or tuple): border pixels to each block 1023 | pad_partial: pad partial blocks to make them full-sized, default False 1024 | pad_method: [zero, replicate, symmetric] how to pad partial block when pad_partial is set True 1025 | 1026 | Return: 1027 | results (tensor): concatenated results of each block 1028 | 1029 | """ 1030 | assert len(x.shape) == 4, f'Shape of input has to be (b, c, h, w) but got {x.shape}' 1031 | kernel = _to_tuple(2)(kernel) 1032 | if pad_partial: 1033 | b, c, h, w = x.shape 1034 | stride = kernel 1035 | h2 = math.ceil(h / stride[0]) 1036 | w2 = math.ceil(w / stride[1]) 1037 | pad_row = (h2 - 1) * stride[0] + kernel[0] - h 1038 | pad_col = (w2 - 1) * stride[1] + kernel[1] - w 1039 | padding = (0, pad_col, 0, pad_row) 1040 | if pad_method == 'zero': 1041 | x = F.pad(x, padding, mode='constant') 1042 | elif pad_method == 'symmetric': 1043 | x = _symm_pad_torch(x, padding) 1044 | else: 1045 | x = F.pad(x, padding, mode=pad_method) 1046 | 1047 | if border_size is not None: 1048 | raise NotImplementedError('Blockproc with border is not implemented yet') 1049 | else: 1050 | b, c, h, w = x.shape 1051 | block_size_h, block_size_w = kernel 1052 | num_block_h = math.floor(h / block_size_h) 1053 | num_block_w = math.floor(w / block_size_w) 1054 | 1055 | # extract blocks in (row, column) manner, i.e., stored with column first 1056 | blocks = F.unfold(x, kernel, stride=kernel) 1057 | blocks = blocks.reshape(b, c, *kernel, num_block_h, num_block_w) 1058 | blocks = blocks.permute(5, 4, 0, 1, 2, 3).reshape(num_block_h * num_block_w * b, c, *kernel) 1059 | 1060 | results = fun(blocks) 1061 | results = results.reshape(num_block_h * num_block_w, b, *results.shape[1:]).transpose(0, 1) 1062 | return results 1063 | 1064 | 1065 | def _image_resize_torch(tensor: torch.Tensor, 1066 | scale_factor: typing.Optional[float] = None, 1067 | sizes: typing.Optional[typing.Tuple[int, int]] = None, 1068 | kernel: typing.Union[str, torch.Tensor] = "cubic", 1069 | sigma: float = 2, 1070 | rotation_degree: float = 0, 1071 | padding_type: str = "reflect", 1072 | antialiasing: bool = True) -> torch.Tensor: 1073 | """ 1074 | Args: 1075 | tensor (torch.Tensor): 1076 | scale_factor (float): 1077 | sizes (tuple(int, 
int)): 1078 | kernel (str, default='cubic'): 1079 | sigma (float, default=2): 1080 | rotation_degree (float, default=0): 1081 | padding_type (str, default='reflect'): 1082 | antialiasing (bool, default=True): 1083 | Return: 1084 | torch.Tensor: 1085 | """ 1086 | scales = (scale_factor, scale_factor) 1087 | 1088 | if scale_factor is None and sizes is None: 1089 | raise ValueError('One of scale or sizes must be specified!') 1090 | if scale_factor is not None and sizes is not None: 1091 | raise ValueError('Please specify scale or sizes to avoid conflict!') 1092 | 1093 | tensor, b, c, h, w = _reshape_input_torch(tensor) 1094 | 1095 | if sizes is None and scale_factor is not None: 1096 | ''' 1097 | # Check if we can apply the convolution algorithm 1098 | scale_inv = 1 / scale 1099 | if isinstance(kernel, str) and scale_inv.is_integer(): 1100 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 1101 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 1102 | raise ValueError( 1103 | 'An integer downsampling factor ' 1104 | 'should be used with a predefined kernel!' 1105 | ) 1106 | ''' 1107 | # Determine output size 1108 | sizes = (math.ceil(h * scale_factor), math.ceil(w * scale_factor)) 1109 | scales = (scale_factor, scale_factor) 1110 | 1111 | if scale_factor is None and sizes is not None: 1112 | scales = (sizes[0] / h, sizes[1] / w) 1113 | 1114 | tensor, dtype = _cast_input_torch(tensor) 1115 | 1116 | if isinstance(kernel, str) and sizes is not None: 1117 | # Core resizing module 1118 | tensor = _resize_1d_torch( 1119 | tensor, 1120 | -2, 1121 | size=sizes[0], 1122 | scale=scales[0], 1123 | kernel=kernel, 1124 | sigma=sigma, 1125 | padding_type=padding_type, 1126 | antialiasing=antialiasing) 1127 | tensor = _resize_1d_torch( 1128 | tensor, 1129 | -1, 1130 | size=sizes[1], 1131 | scale=scales[1], 1132 | kernel=kernel, 1133 | sigma=sigma, 1134 | padding_type=padding_type, 1135 | antialiasing=antialiasing) 1136 | elif isinstance(kernel, torch.Tensor) and scale_factor is not None: 1137 | tensor = _downsampling_2d_torch(tensor, kernel, scale=int(1 / scale_factor)) 1138 | 1139 | tensor = _reshape_output_torch(tensor, b, c) 1140 | tensor = _cast_output_torch(tensor, dtype) 1141 | return tensor 1142 | 1143 | 1144 | def _estimate_aggd_parameters_torch(tensor: torch.Tensor, 1145 | get_sigma: bool) -> [torch.Tensor, torch.Tensor, torch.Tensor]: 1146 | """PyTorch implements the BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) function 1147 | This function is used to estimate an asymmetric generalized Gaussian distribution 1148 | 1149 | Reference papers: 1150 | `No-Reference Image Quality Assessment in the Spatial Domain` 1151 | `Referenceless Image Spatial Quality Evaluation Engine` 1152 | 1153 | Args: 1154 | tensor (torch.Tensor): data vector 1155 | get_sigma (bool): whether to return the covariance mean 1156 | 1157 | Returns: 1158 | aggd_parameters (torch.Tensor): asymmetric generalized Gaussian distribution 1159 | left_std (torch.Tensor): symmetric left data vector variance mean 1160 | right_std (torch.Tensor): Symmetric right side data vector variance mean 1161 | 1162 | """ 1163 | # The following is obtained according to the formula and the method provided in the paper on WIki encyclopedia 1164 | aggd = torch.arange(0.2, 10 + 0.001, 0.001).to(tensor) 1165 | r_gam = (2 * torch.lgamma(2. / aggd) - (torch.lgamma(1. / aggd) + torch.lgamma(3. 
/ aggd))).exp() 1166 | r_gam = r_gam.repeat(tensor.size(0), 1) 1167 | 1168 | mask_left = tensor < 0 1169 | mask_right = tensor > 0 1170 | count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32) 1171 | count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32) 1172 | 1173 | left_std = torch.sqrt_((tensor * mask_left).pow(2).sum(dim=(-1, -2)) / (count_left + 1e-8)) 1174 | right_std = torch.sqrt_((tensor * mask_right).pow(2).sum(dim=(-1, -2)) / (count_right + 1e-8)) 1175 | gamma_hat = left_std / right_std 1176 | rhat = tensor.abs().mean(dim=(-1, -2)).pow(2) / tensor.pow(2).mean(dim=(-1, -2)) 1177 | rhat_norm = (rhat * (gamma_hat.pow(3) + 1) * (gamma_hat + 1)) / (gamma_hat.pow(2) + 1).pow(2) 1178 | 1179 | array_position = (r_gam - rhat_norm).abs().argmin(dim=-1) 1180 | aggd_parameters = aggd[array_position] 1181 | 1182 | if get_sigma: 1183 | left_beta = left_std.squeeze(-1) * ( 1184 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt() 1185 | right_beta = right_std.squeeze(-1) * ( 1186 | torch.lgamma(1 / aggd_parameters) - torch.lgamma(3 / aggd_parameters)).exp().sqrt() 1187 | return aggd_parameters, left_beta, right_beta 1188 | 1189 | else: 1190 | left_std = left_std.squeeze_(-1) 1191 | right_std = right_std.squeeze_(-1) 1192 | return aggd_parameters, left_std, right_std 1193 | 1194 | 1195 | def _get_mscn_feature_torch(tensor: torch.Tensor) -> np.ndarray: 1196 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function, 1197 | This function is used to calculate the feature map 1198 | 1199 | Reference papers: 1200 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 1201 | 1202 | Args: 1203 | tensor (torch.Tensor): The image to be evaluated for NIQE sharpness 1204 | 1205 | Returns: 1206 | feature (torch.Tensor): image feature map 1207 | 1208 | """ 1209 | batch_size = tensor.shape[0] 1210 | aggd_block = tensor[:, [0]] 1211 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block, True) 1212 | feature = [aggd_parameters, (left_beta + right_beta) / 2] 1213 | 1214 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] 1215 | for i in range(len(shifts)): 1216 | shifted_block = torch.roll(aggd_block, shifts[i], dims=(2, 3)) 1217 | aggd_parameters, left_beta, right_beta = _estimate_aggd_parameters_torch(aggd_block * shifted_block, True) 1218 | mean = (right_beta - left_beta) * (torch.lgamma(2 / aggd_parameters) - torch.lgamma(1 / aggd_parameters)).exp() 1219 | feature.extend((aggd_parameters, mean, left_beta, right_beta)) 1220 | 1221 | feature = [x.reshape(batch_size, 1) for x in feature] 1222 | feature = torch.cat(feature, dim=-1) 1223 | 1224 | return feature 1225 | 1226 | 1227 | def _fit_mscn_ipac_torch(tensor: torch.Tensor, 1228 | mu_pris_param: torch.Tensor, 1229 | cov_pris_param: torch.Tensor, 1230 | block_size_height: int, 1231 | block_size_width: int, 1232 | kernel_size: int = 7, 1233 | kernel_sigma: float = 7. 
/ 6, 1234 | padding: str = "replicate") -> float: 1235 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function, 1236 | This function is used to fit the inner product of adjacent coefficients of MSCN 1237 | 1238 | Reference papers: 1239 | `Estimation of shape parameter for generalized Gaussian distributions in subband decompositions of video` 1240 | 1241 | Args: 1242 | tensor (torch.Tensor): The image to be evaluated for NIQE sharpness 1243 | mu_pris_param (torch.Tensor): mean of predefined multivariate Gaussians, model computed on original dataset 1244 | cov_pris_param (torch.Tensor): Covariance of predefined multivariate Gaussian model computed on original dataset 1245 | block_size_height (int): the height of the block into which the image is divided 1246 | block_size_width (int): The width of the block into which the image is divided 1247 | kernel_size (int): Gaussian filter size 1248 | kernel_sigma (int): sigma value in Gaussian filter 1249 | padding (str): how to pad pixels. Default: ``replicate`` 1250 | 1251 | Returns: 1252 | niqe_metric (torch.Tensor): NIQE score 1253 | 1254 | """ 1255 | # crop image 1256 | b, c, h, w = tensor.shape 1257 | num_block_height = math.floor(h / block_size_height) 1258 | num_block_width = math.floor(w / block_size_width) 1259 | tensor = tensor[..., 0:num_block_height * block_size_height, 0:num_block_width * block_size_width] 1260 | 1261 | distparam = [] 1262 | for scale in (1, 2): 1263 | kernel = _fspecial_gaussian_torch(kernel_size, kernel_sigma, 1).to(tensor) 1264 | mu = _image_filter(tensor, kernel, padding=padding) 1265 | std = _image_filter(tensor ** 2, kernel, padding=padding) 1266 | sigma = torch.sqrt_((std - mu ** 2).abs() + 1e-8) 1267 | structdis = (tensor - mu) / (sigma + 1) 1268 | 1269 | distparam.append(_blockproc_torch(structdis, 1270 | [block_size_height // scale, block_size_width // scale], 1271 | fun=_get_mscn_feature_torch)) 1272 | 1273 | if scale == 1: 1274 | tensor = _image_resize_torch(tensor / 255., scale_factor=0.5, antialiasing=True) 1275 | tensor = tensor * 255. 1276 | 1277 | distparam = torch.cat(distparam, -1) 1278 | 1279 | # Fit MVG (Multivariate Gaussian) model to distorted patch features 1280 | mu_distparam = _nanmean_torch(distparam, dim=1) 1281 | cov_distparam = _nancov_torch(distparam) 1282 | 1283 | invcov_param = torch.linalg.pinv((cov_pris_param + cov_distparam) / 2) 1284 | diff = (mu_pris_param - mu_distparam).unsqueeze(1) 1285 | niqe_metric = torch.bmm(torch.bmm(diff, invcov_param), diff.transpose(1, 2)).squeeze() 1286 | niqe_metric = torch.sqrt(niqe_metric) 1287 | 1288 | return niqe_metric 1289 | 1290 | 1291 | def _niqe_torch(tensor: torch.Tensor, 1292 | crop_border: int, 1293 | niqe_model_path: str, 1294 | block_size_height: int = 96, 1295 | block_size_width: int = 96 1296 | ) -> float: 1297 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function, 1298 | 1299 | Attributes: 1300 | tensor (torch.Tensor): The image to evaluate the sharpness of the BRISQUE 1301 | crop_border (int): crop border a few pixels 1302 | niqe_model_path (str): NIQE model estimator weight address 1303 | block_size_height (int): The height of the block the image is divided into. Default: 96 1304 | block_size_width (int): The width of the block the image is divided into. 
Default: 96 1305 | 1306 | Returns: 1307 | niqe_metrics (torch.Tensor): NIQE metrics 1308 | 1309 | """ 1310 | # crop border pixels 1311 | if crop_border > 0: 1312 | tensor = tensor[:, :, crop_border:-crop_border, crop_border:-crop_border] 1313 | 1314 | # Load the NIQE feature extraction model 1315 | niqe_model = loadmat(niqe_model_path) 1316 | 1317 | mu_pris_param = np.ravel(niqe_model["mu_prisparam"]) 1318 | cov_pris_param = niqe_model["cov_prisparam"] 1319 | mu_pris_param = torch.from_numpy(mu_pris_param).to(tensor) 1320 | cov_pris_param = torch.from_numpy(cov_pris_param).to(tensor) 1321 | 1322 | mu_pris_param = mu_pris_param.repeat(tensor.size(0), 1) 1323 | cov_pris_param = cov_pris_param.repeat(tensor.size(0), 1, 1) 1324 | 1325 | # NIQE only tests on Y channel images and needs to convert the images 1326 | y_tensor = rgb_to_ycbcr_torch(tensor, only_use_y_channel=True) 1327 | y_tensor *= 255.0 1328 | y_tensor = y_tensor.round() 1329 | 1330 | # Convert data type to torch.float64 bit 1331 | y_tensor = y_tensor.to(torch.float64) 1332 | 1333 | niqe_metric = _fit_mscn_ipac_torch(y_tensor, 1334 | mu_pris_param, 1335 | cov_pris_param, 1336 | block_size_height, 1337 | block_size_width) 1338 | 1339 | return niqe_metric 1340 | 1341 | 1342 | class NIQE(nn.Module): 1343 | """PyTorch implements the NIQE (Natural Image Quality Evaluator) function, 1344 | 1345 | Attributes: 1346 | crop_border (int): crop border a few pixels 1347 | niqe_model_path (str): NIQE model address 1348 | block_size_height (int): The height of the block the image is divided into. Default: 96 1349 | block_size_width (int): The width of the block the image is divided into. Default: 96 1350 | 1351 | Returns: 1352 | niqe_metrics (torch.Tensor): NIQE metrics 1353 | 1354 | """ 1355 | 1356 | def __init__(self, crop_border: int, 1357 | niqe_model_path: str, 1358 | block_size_height: int = 96, 1359 | block_size_width: int = 96) -> None: 1360 | super().__init__() 1361 | self.crop_border = crop_border 1362 | self.niqe_model_path = niqe_model_path 1363 | self.block_size_height = block_size_height 1364 | self.block_size_width = block_size_width 1365 | 1366 | def forward(self, raw_tensor: torch.Tensor) -> float: 1367 | niqe_metrics = _niqe_torch(raw_tensor, 1368 | self.crop_border, 1369 | self.niqe_model_path, 1370 | self.block_size_height, 1371 | self.block_size_width) 1372 | 1373 | return niqe_metrics 1374 | --------------------------------------------------------------------------------
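For reference, the public entry points defined in `image_quality_assessment.py` can be exercised as in the hypothetical sketch below. The sample images reuse the repository's `figure/` files, while the NIQE parameter paths (an `.npz` archive for the NumPy path, a `.mat` file for the PyTorch path) are placeholders that must point to pretrained NIQE model data; device placement is omitted.

```python
# Hypothetical usage sketch (not part of the repository); the import path, crop_border
# value, and NIQE model paths are assumptions for illustration only.
import cv2
import torch

from image_quality_assessment import ssim, niqe, PSNR, SSIM, NIQE

# ---- NumPy/OpenCV metrics: BGR arrays, data range [0, 255] ----
sr_bgr = cv2.imread("./figure/sr_comic.png")
gt_bgr = cv2.imread("./figure/comic.png")

print("SSIM:", ssim(sr_bgr, gt_bgr, crop_border=4, only_test_y_channel=True))
print("NIQE:", niqe(sr_bgr, crop_border=4, niqe_model_path="./niqe_model.npz"))  # placeholder .npz

# ---- PyTorch metrics: RGB tensors, data range [0, 1], shape (B, C, H, W) ----
def to_tensor(bgr_image):
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(rgb_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0

sr_tensor, gt_tensor = to_tensor(sr_bgr), to_tensor(gt_bgr)

psnr_model = PSNR(crop_border=4, only_test_y_channel=True)
ssim_model = SSIM(4, True)  # positional: (crop_border, only_only_test_y_channel)
niqe_model = NIQE(crop_border=4, niqe_model_path="./niqe_model.mat")  # placeholder .mat

print("PSNR:", psnr_model(sr_tensor, gt_tensor))
print("SSIM:", ssim_model(sr_tensor, gt_tensor))
print("NIQE:", niqe_model(sr_tensor))
```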