├── .gitattributes ├── pyproject.toml ├── vsdpir │   ├── __main__.py │   ├── network_unet.py │   ├── __init__.py │   └── basicblock.py ├── LICENSE ├── README.md └── .gitignore /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "vsdpir" 7 | version = "4.3.0" 8 | description = "DPIR function for VapourSynth" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license-files = { paths = ["LICENSE"] } 12 | authors = [ 13 | { name = "HolyWu", email = "holywu@gmail.com" }, 14 | ] 15 | keywords = [ 16 | "DPIR", 17 | "PyTorch", 18 | "TensorRT", 19 | "VapourSynth", 20 | ] 21 | classifiers = [ 22 | "Environment :: GPU :: NVIDIA CUDA", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Topic :: Multimedia :: Video", 27 | ] 28 | dependencies = [ 29 | "numpy", 30 | "requests", 31 | "torch>=2.6.0", 32 | "tqdm", 33 | "VapourSynth>=66", 34 | ] 35 | 36 | [project.urls] 37 | Homepage = "https://github.com/HolyWu/vs-dpir" 38 | Issues = "https://github.com/HolyWu/vs-dpir/issues" 39 | -------------------------------------------------------------------------------- /vsdpir/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from tqdm import tqdm 5 | 6 | 7 | def download_model(url: str) -> None: 8 | model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") 9 | os.makedirs(model_dir, exist_ok=True) 10 | filename = url.split("/")[-1] 11 | r = requests.get(url, stream=True) 12 | with open(os.path.join(model_dir, filename), "wb") as f: 13 | with tqdm( 14 | unit="B", 15 | unit_scale=True, 16 | unit_divisor=1024, 17 | miniters=1, 18 | desc=filename, 19 | total=int(r.headers.get("content-length", 0)), 20 | ) as pbar: 21 | for chunk in r.iter_content(chunk_size=4096): 22 | f.write(chunk) 23 | pbar.update(len(chunk)) 24 | 25 | 26 | if __name__ == "__main__": 27 | url = "https://github.com/HolyWu/vs-dpir/releases/download/model/" 28 | models = ["drunet_color", "drunet_deblocking_color", "drunet_deblocking_gray", "drunet_gray"] 29 | for model in models: 30 | download_model(url + model + ".pth") 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 HolyWu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPIR 2 | Plug-and-Play Image Restoration with Deep Denoiser Prior, based on https://github.com/cszn/DPIR. 3 | 4 | 5 | ## Dependencies 6 | - [PyTorch](https://pytorch.org/get-started/) 2.6.0 or later 7 | - [VapourSynth](http://www.vapoursynth.com/) R66 or later 8 | 9 | `trt=True` requires additional packages: 10 | - [TensorRT](https://developer.nvidia.com/tensorrt) 10.7.0.post1 or later 11 | - [Torch-TensorRT](https://pytorch.org/TensorRT/) 2.6.0 or later 12 | 13 | To install the latest stable version of PyTorch and Torch-TensorRT, run: 14 | ``` 15 | pip install -U packaging setuptools wheel 16 | pip install -U torch torchvision torch_tensorrt --index-url https://download.pytorch.org/whl/cu126 --extra-index-url https://pypi.nvidia.com 17 | ``` 18 | 19 | 20 | ## Installation 21 | ``` 22 | pip install -U vsdpir 23 | ``` 24 | 25 | If you want to download all models at once, run `python -m vsdpir`. If you prefer to download only the 26 | specified model on first use, set `auto_download=True` in `dpir()`. 27 | 28 | 29 | ## Usage 30 | ```python 31 | from vsdpir import dpir 32 | 33 | ret = dpir(clip) 34 | ``` 35 | 36 | See `__init__.py` for the description of the parameters. 37 | -------------------------------------------------------------------------------- /vsdpir/network_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .
import basicblock as B 4 | 5 | 6 | class UNetRes(nn.Module): 7 | def __init__( 8 | self, 9 | in_nc=1, 10 | out_nc=1, 11 | nc=[64, 128, 256, 512], 12 | nb=4, 13 | act_mode="R", 14 | downsample_mode="strideconv", 15 | upsample_mode="convtranspose", 16 | ): 17 | super(UNetRes, self).__init__() 18 | 19 | self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C") 20 | 21 | # downsample 22 | if downsample_mode == "avgpool": 23 | downsample_block = B.downsample_avgpool 24 | elif downsample_mode == "maxpool": 25 | downsample_block = B.downsample_maxpool 26 | elif downsample_mode == "strideconv": 27 | downsample_block = B.downsample_strideconv 28 | else: 29 | raise NotImplementedError("downsample mode [{:s}] is not found".format(downsample_mode)) 30 | 31 | self.m_down1 = B.sequential( 32 | *[B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], 33 | downsample_block(nc[0], nc[1], bias=False, mode="2") 34 | ) 35 | self.m_down2 = B.sequential( 36 | *[B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], 37 | downsample_block(nc[1], nc[2], bias=False, mode="2") 38 | ) 39 | self.m_down3 = B.sequential( 40 | *[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], 41 | downsample_block(nc[2], nc[3], bias=False, mode="2") 42 | ) 43 | 44 | self.m_body = B.sequential( 45 | *[B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] 46 | ) 47 | 48 | # upsample 49 | if upsample_mode == "upconv": 50 | upsample_block = B.upsample_upconv 51 | elif upsample_mode == "pixelshuffle": 52 | upsample_block = B.upsample_pixelshuffle 53 | elif upsample_mode == "convtranspose": 54 | upsample_block = B.upsample_convtranspose 55 | else: 56 | raise NotImplementedError("upsample mode [{:s}] is not found".format(upsample_mode)) 57 | 58 | self.m_up3 = B.sequential( 59 | upsample_block(nc[3], nc[2], bias=False, mode="2"), 60 | *[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] 61 | ) 62 | self.m_up2 = B.sequential( 63 | upsample_block(nc[2], nc[1], bias=False, mode="2"), 64 | *[B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] 65 | ) 66 | self.m_up1 = B.sequential( 67 | upsample_block(nc[1], nc[0], bias=False, mode="2"), 68 | *[B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] 69 | ) 70 | 71 | self.m_tail = B.conv(nc[0], out_nc, bias=False, mode="C") 72 | 73 | def forward(self, x0): 74 | x1 = self.m_head(x0) 75 | x2 = self.m_down1(x1) 76 | x3 = self.m_down2(x2) 77 | x4 = self.m_down3(x3) 78 | x = self.m_body(x4) 79 | x = self.m_up3(x + x4) 80 | x = self.m_up2(x + x3) 81 | x = self.m_up1(x + x2) 82 | x = self.m_tail(x + x1) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /vsdpir/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import os 5 | import warnings 6 | from contextlib import contextmanager 7 | from dataclasses import dataclass 8 | from threading import Lock 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import vapoursynth as vs 14 | 15 | from .__main__ import download_model 16 | from .network_unet import UNetRes 17 | 18 | __version__ = "4.3.0" 19 | 20 | os.environ["CI_BUILD"] = "1" 21 | os.environ["CUDA_MODULE_LOADING"] = "LAZY" 22 | 23 | warnings.filterwarnings("ignore", "The given NumPy array is not writable") 24 | 25 | model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") 26 | 27 | 28 | class Backend: 29 | @dataclass 30 | class Torch: 31 | module: torch.nn.Module 32 | 33 | @dataclass 34 | class TensorRT: 35 | module: list[torch.nn.Module] 36 | 37 | 38 | @contextmanager 39 | def redirect_stdout_to_stderr(): 40 | original_stdout = os.dup(1) 41 | try: 42 | os.dup2(2, 1) 43 | yield 44 | finally: 45 | os.dup2(original_stdout, 1) 46 | os.close(original_stdout) 47 | 48 | 49 | @redirect_stdout_to_stderr() 50 | @torch.inference_mode() 51 | def dpir( 52 | clip: vs.VideoNode, 53 | device_index: int = 0, 54 | num_streams: int = 1, 55 | batch_size: int = 1, 56 | task: str = "deblock", 57 | auto_download: bool = False, 58 | strength: float | vs.VideoNode | None = None, 59 | tile: list[int] = [0, 0], 60 | tile_pad: int = 8, 61 | trt: bool = False, 62 | trt_static_shape: bool = True, 63 | trt_min_shape: list[int] = [128, 128], 64 | trt_opt_shape: list[int] = [1920, 1080], 65 | trt_max_shape: list[int] = [1920, 1080], 66 | trt_debug: bool = False, 67 | trt_workspace_size: int = 0, 68 | trt_max_aux_streams: int | None = None, 69 | trt_optimization_level: int | None = None, 70 | trt_cache_dir: str = model_dir, 71 | ) -> vs.VideoNode: 72 | """Deep Plug-and-Play Image Restoration 73 | 74 | :param clip: Clip to process. Only RGBH/RGBS/GRAYH/GRAYS formats are supported. RGBH/GRAYH 75 | perform inference in FP16 mode while RGBS/GRAYS perform inference in FP32 mode. 76 | :param device_index: Device ordinal of the GPU. 77 | :param num_streams: Number of CUDA streams to enqueue the kernels. 78 | :param batch_size: Number of frames per batch. 79 | :param task: Task to perform. Must be 'deblock' or 'denoise'. 80 | :param auto_download: Automatically download the specified model if the file has not been downloaded. 81 | :param strength: Strength for deblocking/denoising. 82 | Defaults to 50.0 for 'deblock', 5.0 for 'denoise'. 83 | Also accepts a clip of GRAY format for varying strength. 84 | :param tile: Tile width and height. Because overly large images can cause out-of-GPU-memory issues, 85 | this option first crops the input image into tiles, processes each tile, and then 86 | merges them back into one image. 0 denotes that tiling is not used. 87 | :param tile_pad: Pad size for each tile, to remove border artifacts. 88 | :param trt: Use TensorRT for high-performance inference. 89 | :param trt_static_shape: Build with static or dynamic shapes.
90 | :param trt_min_shape: Min size of dynamic shapes. Ignored if trt_static_shape=True. 91 | :param trt_opt_shape: Opt size of dynamic shapes. Ignored if trt_static_shape=True. 92 | :param trt_max_shape: Max size of dynamic shapes. Ignored if trt_static_shape=True. 93 | :param trt_debug: Print out verbose debugging information. 94 | :param trt_workspace_size: Size constraints of workspace memory pool. 95 | :param trt_max_aux_streams: Maximum number of auxiliary streams per inference stream that TRT is allowed to use 96 | to run kernels in parallel if the network contains ops that can run in parallel, 97 | at the cost of more memory usage. Set this to 0 for optimal memory usage. 98 | (default = using heuristics) 99 | :param trt_optimization_level: Builder optimization level. A higher level allows TensorRT to spend more build time 100 | on more optimization options. Valid values are integers from 0 to the maximum 101 | optimization level, which is currently 5. (default is 3) 102 | :param trt_cache_dir: Directory for the TensorRT engine file. The engine is cached when it's built for the 103 | first time. Note that each engine is created for specific settings such as model 104 | path/name, precision, workspace etc, and for a specific GPU, so it is not portable. 105 | """ 106 | if not isinstance(clip, vs.VideoNode): 107 | raise vs.Error("dpir: this is not a clip") 108 | 109 | if clip.format.id not in [vs.RGBH, vs.RGBS, vs.GRAYH, vs.GRAYS]: 110 | raise vs.Error("dpir: only RGBH/RGBS/GRAYH/GRAYS formats are supported") 111 | 112 | if not torch.cuda.is_available(): 113 | raise vs.Error("dpir: CUDA is not available") 114 | 115 | if num_streams < 1: 116 | raise vs.Error("dpir: num_streams must be at least 1") 117 | 118 | if batch_size < 1: 119 | raise vs.Error("dpir: batch_size must be at least 1") 120 | 121 | task = task.lower() 122 | 123 | if task not in ["deblock", "denoise"]: 124 | raise vs.Error("dpir: task must be 'deblock' or 'denoise'") 125 | 126 | if isinstance(strength, vs.VideoNode): 127 | if strength.format.color_family != vs.GRAY: 128 | raise vs.Error("dpir: strength must be of GRAY format") 129 | 130 | if strength.width != clip.width or strength.height != clip.height or strength.num_frames != clip.num_frames: 131 | raise vs.Error("dpir: strength must have the same dimensions and number of frames as main clip") 132 | 133 | if not isinstance(tile, list) or len(tile) != 2: 134 | raise vs.Error("dpir: tile must be a list with 2 items") 135 | 136 | if not trt_static_shape: 137 | if not isinstance(trt_min_shape, list) or len(trt_min_shape) != 2: 138 | raise vs.Error("dpir: trt_min_shape must be a list with 2 items") 139 | 140 | if any(trt_min_shape[i] < 1 for i in range(2)): 141 | raise vs.Error("dpir: trt_min_shape must be at least 1") 142 | 143 | if not isinstance(trt_opt_shape, list) or len(trt_opt_shape) != 2: 144 | raise vs.Error("dpir: trt_opt_shape must be a list with 2 items") 145 | 146 | if any(trt_opt_shape[i] < 1 for i in range(2)): 147 | raise vs.Error("dpir: trt_opt_shape must be at least 1") 148 | 149 | if not isinstance(trt_max_shape, list) or len(trt_max_shape) != 2: 150 | raise vs.Error("dpir: trt_max_shape must be a list with 2 items") 151 | 152 | if any(trt_max_shape[i] < 1 for i in range(2)): 153 | raise vs.Error("dpir: trt_max_shape must be at least 1") 154 | 155 | if any(trt_min_shape[i] >= trt_max_shape[i] for i in range(2)): 156 | raise vs.Error("dpir: trt_min_shape must be less than trt_max_shape") 157 | 158 | torch.set_float32_matmul_precision("high") 159 | 160 |
color_or_gray = "color" if clip.format.color_family == vs.RGB else "gray" 161 | in_nc = clip.format.num_planes + 1 162 | 163 | fp16 = clip.format.bits_per_sample == 16 164 | if fp16: 165 | dtype = torch.half 166 | noise_format = vs.GRAYH 167 | else: 168 | dtype = torch.float 169 | noise_format = vs.GRAYS 170 | 171 | device = torch.device("cuda", device_index) 172 | 173 | if task == "deblock": 174 | model_name = f"drunet_deblocking_{color_or_gray}.pth" 175 | 176 | if isinstance(strength, vs.VideoNode): 177 | noise = strength.std.Expr("x 100 /", format=noise_format) 178 | else: 179 | noise = clip.std.BlankClip( 180 | format=noise_format, color=(50.0 if strength is None else strength) / 100, keep=True 181 | ) 182 | else: 183 | model_name = f"drunet_{color_or_gray}.pth" 184 | 185 | if isinstance(strength, vs.VideoNode): 186 | noise = strength.std.Expr("x 255 /", format=noise_format) 187 | else: 188 | noise = clip.std.BlankClip( 189 | format=noise_format, color=(5.0 if strength is None else strength) / 255, keep=True 190 | ) 191 | 192 | if not os.path.isfile(os.path.join(model_dir, model_name)): 193 | if auto_download: 194 | download_model(f"https://github.com/HolyWu/vs-dpir/releases/download/model/{model_name}") 195 | else: 196 | raise vs.Error( 197 | "dpir: model file has not been downloaded. run `python -m vsdpir` to download all models, or set " 198 | "`auto_download=True` to only download the specified model" 199 | ) 200 | 201 | if all(t > 0 for t in tile): 202 | pad_w = math.ceil(min(tile[0] + 2 * tile_pad, clip.width) / 8) * 8 203 | pad_h = math.ceil(min(tile[1] + 2 * tile_pad, clip.height) / 8) * 8 204 | else: 205 | pad_w = math.ceil(clip.width / 8) * 8 206 | pad_h = math.ceil(clip.height / 8) * 8 207 | 208 | if trt: 209 | import tensorrt 210 | import torch_tensorrt 211 | 212 | if trt_static_shape: 213 | dimensions = f"{pad_w}x{pad_h}" 214 | else: 215 | for i in range(2): 216 | trt_min_shape[i] = math.ceil(trt_min_shape[i] / 8) * 8 217 | trt_opt_shape[i] = math.ceil(trt_opt_shape[i] / 8) * 8 218 | trt_max_shape[i] = math.ceil(trt_max_shape[i] / 8) * 8 219 | 220 | dimensions = ( 221 | f"min-{trt_min_shape[0]}x{trt_min_shape[1]}" 222 | f"_opt-{trt_opt_shape[0]}x{trt_opt_shape[1]}" 223 | f"_max-{trt_max_shape[0]}x{trt_max_shape[1]}" 224 | ) 225 | 226 | trt_engine_path = os.path.join( 227 | os.path.realpath(trt_cache_dir), 228 | ( 229 | f"{model_name}" 230 | + f"_batch-{batch_size}" 231 | + f"_{dimensions}" 232 | + f"_{'fp16' if fp16 else 'fp32'}" 233 | + f"_{torch.cuda.get_device_name(device)}" 234 | + f"_trt-{tensorrt.__version__}" 235 | + (f"_workspace-{trt_workspace_size}" if trt_workspace_size > 0 else "") 236 | + (f"_aux-{trt_max_aux_streams}" if trt_max_aux_streams is not None else "") 237 | + (f"_level-{trt_optimization_level}" if trt_optimization_level is not None else "") 238 | + ".ts" 239 | ), 240 | ) 241 | 242 | if not os.path.isfile(trt_engine_path): 243 | module = init_module(model_name, in_nc, device, dtype) 244 | inputs = (torch.zeros((batch_size, in_nc, pad_h, pad_w), dtype=dtype, device=device),) 245 | 246 | if trt_static_shape: 247 | dynamic_shapes = None 248 | else: 249 | trt_min_shape.reverse() 250 | trt_opt_shape.reverse() 251 | trt_max_shape.reverse() 252 | 253 | _height = torch.export.Dim("height", min=trt_min_shape[0] // 8, max=trt_max_shape[0] // 8) 254 | _width = torch.export.Dim("width", min=trt_min_shape[1] // 8, max=trt_max_shape[1] // 8) 255 | dim_height = _height * 8 256 | dim_width = _width * 8 257 | dynamic_shapes = {"x0": {2: dim_height, 3: dim_width}} 258 
| 259 | exported_program = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes) 260 | 261 | module = torch_tensorrt.dynamo.compile( 262 | exported_program, 263 | inputs, 264 | device=device, 265 | enabled_precisions={dtype}, 266 | debug=trt_debug, 267 | num_avg_timing_iters=4, 268 | workspace_size=trt_workspace_size, 269 | min_block_size=1, 270 | max_aux_streams=trt_max_aux_streams, 271 | optimization_level=trt_optimization_level, 272 | ) 273 | 274 | torch_tensorrt.save(module, trt_engine_path, output_format="torchscript", inputs=inputs) 275 | 276 | module = [torch.jit.load(trt_engine_path).eval() for _ in range(num_streams)] 277 | backend = Backend.TensorRT(module) 278 | else: 279 | module = init_module(model_name, in_nc, device, dtype) 280 | backend = Backend.Torch(module) 281 | 282 | index = -1 283 | index_lock = Lock() 284 | 285 | inf_streams = [torch.cuda.Stream(device) for _ in range(num_streams)] 286 | f2t_streams = [torch.cuda.Stream(device) for _ in range(num_streams)] 287 | t2f_streams = [torch.cuda.Stream(device) for _ in range(num_streams)] 288 | 289 | inf_stream_locks = [Lock() for _ in range(num_streams)] 290 | f2t_stream_locks = [Lock() for _ in range(num_streams)] 291 | t2f_stream_locks = [Lock() for _ in range(num_streams)] 292 | 293 | @torch.inference_mode() 294 | def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: 295 | nonlocal index 296 | with index_lock: 297 | index = (index + 1) % num_streams 298 | local_index = index 299 | 300 | with f2t_stream_locks[local_index], torch.cuda.stream(f2t_streams[local_index]): 301 | img = torch.stack([frame_to_tensor(f[i], device) for i in range(batch_size)]).clamp(0.0, 1.0) 302 | noise_level_map = torch.stack([frame_to_tensor(f[i + batch_size], device) for i in range(batch_size)]) 303 | img = torch.cat([img, noise_level_map], dim=1) 304 | 305 | f2t_streams[local_index].synchronize() 306 | 307 | with inf_stream_locks[local_index], torch.cuda.stream(inf_streams[local_index]): 308 | if all(t > 0 for t in tile): 309 | output = tile_process(img, tile, tile_pad, pad_w, pad_h, backend, local_index) 310 | else: 311 | h, w = img.shape[2:] 312 | if need_pad := pad_w - w > 0 or pad_h - h > 0: 313 | img = F.pad(img, (0, pad_w - w, 0, pad_h - h), "replicate") 314 | 315 | if trt: 316 | output = module[local_index](img) 317 | else: 318 | output = module(img) 319 | 320 | if need_pad: 321 | output = output[:, :, :h, :w] 322 | 323 | inf_streams[local_index].synchronize() 324 | 325 | with t2f_stream_locks[local_index], torch.cuda.stream(t2f_streams[local_index]): 326 | frame = tensor_to_frame(output[0], f[0].copy(), t2f_streams[local_index]) 327 | for i in range(1, batch_size): 328 | frame.props[f"vsdpir_batch_frame{i}"] = tensor_to_frame( 329 | output[i], f[0].copy(), t2f_streams[local_index] 330 | ) 331 | return frame 332 | 333 | if (pad := (batch_size - clip.num_frames % batch_size) % batch_size) > 0: 334 | clip = clip.std.DuplicateFrames([clip.num_frames - 1] * pad) 335 | noise = noise.std.DuplicateFrames([noise.num_frames - 1] * pad) 336 | 337 | clips = [clip[i::batch_size] for i in range(batch_size)] + [noise[i::batch_size] for i in range(batch_size)] 338 | 339 | outputs = [clips[0].std.FrameEval(lambda n: clips[0].std.ModifyFrame(clips, inference), clip_src=clips)] 340 | for i in range(1, batch_size): 341 | outputs.append(outputs[0].std.PropToClip(f"vsdpir_batch_frame{i}")) 342 | 343 | output = vs.core.std.Interleave(outputs) 344 | if pad > 0: 345 | output = output[:-pad] 346 | return output 347 | 348 | 349 | def 
init_module(model_name: str, in_nc: int, device: torch.device, dtype: torch.dtype) -> torch.nn.Module: 350 | state_dict = torch.load(os.path.join(model_dir, model_name), map_location="cpu") 351 | 352 | with torch.device("meta"): 353 | module = UNetRes(in_nc=in_nc, out_nc=in_nc - 1) 354 | module.load_state_dict(state_dict, assign=True) 355 | return module.eval().to(device, dtype) 356 | 357 | 358 | def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor: 359 | return torch.stack( 360 | [ 361 | torch.from_numpy(np.asarray(frame[plane])).to(device, non_blocking=True) 362 | for plane in range(frame.format.num_planes) 363 | ] 364 | ) 365 | 366 | 367 | def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame, stream: torch.cuda.Stream) -> vs.VideoFrame: 368 | tensor = tensor.detach() 369 | tensors = [tensor[plane].to("cpu", non_blocking=True) for plane in range(frame.format.num_planes)] 370 | 371 | stream.synchronize() 372 | 373 | for plane in range(frame.format.num_planes): 374 | np.copyto(np.asarray(frame[plane]), tensors[plane].numpy()) 375 | return frame 376 | 377 | 378 | def tile_process( 379 | img: torch.Tensor, 380 | tile: list[int], 381 | tile_pad: int, 382 | pad_w: int, 383 | pad_h: int, 384 | backend: Backend.Torch | Backend.TensorRT, 385 | index: int, 386 | ) -> torch.Tensor: 387 | batch, channel, height, width = img.shape 388 | output_shape = (batch, channel - 1, height, width) 389 | 390 | # start with black image 391 | output = img.new_zeros(output_shape) 392 | 393 | tiles_x = math.ceil(width / tile[0]) 394 | tiles_y = math.ceil(height / tile[1]) 395 | 396 | # loop over all tiles 397 | for y in range(tiles_y): 398 | for x in range(tiles_x): 399 | # extract tile from input image 400 | ofs_x = x * tile[0] 401 | ofs_y = y * tile[1] 402 | 403 | # input tile area on total image 404 | input_start_x = ofs_x 405 | input_end_x = min(ofs_x + tile[0], width) 406 | input_start_y = ofs_y 407 | input_end_y = min(ofs_y + tile[1], height) 408 | 409 | # input tile area on total image with padding 410 | input_start_x_pad = max(input_start_x - tile_pad, 0) 411 | input_end_x_pad = min(input_end_x + tile_pad, width) 412 | input_start_y_pad = max(input_start_y - tile_pad, 0) 413 | input_end_y_pad = min(input_end_y + tile_pad, height) 414 | 415 | # input tile dimensions 416 | input_tile_width = input_end_x - input_start_x 417 | input_tile_height = input_end_y - input_start_y 418 | 419 | input_tile = img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] 420 | 421 | h, w = input_tile.shape[2:] 422 | if need_pad := pad_w - w > 0 or pad_h - h > 0: 423 | input_tile = F.pad(input_tile, (0, pad_w - w, 0, pad_h - h), "replicate") 424 | 425 | # process tile 426 | if isinstance(backend, Backend.TensorRT): 427 | output_tile = backend.module[index](input_tile) 428 | else: 429 | output_tile = backend.module(input_tile) 430 | 431 | if need_pad: 432 | output_tile = output_tile[:, :, :h, :w] 433 | 434 | # output tile area on total image 435 | output_start_x = input_start_x 436 | output_end_x = input_end_x 437 | output_start_y = input_start_y 438 | output_end_y = input_end_y 439 | 440 | # output tile area without padding 441 | output_start_x_tile = input_start_x - input_start_x_pad 442 | output_end_x_tile = output_start_x_tile + input_tile_width 443 | output_start_y_tile = input_start_y - input_start_y_pad 444 | output_end_y_tile = output_start_y_tile + input_tile_height 445 | 446 | # put tile into output image 447 | output[:, :, output_start_y:output_end_y, 
output_start_x:output_end_x] = output_tile[ 448 | :, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile 449 | ] 450 | 451 | return output 452 | -------------------------------------------------------------------------------- /vsdpir/basicblock.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | """ 8 | # -------------------------------------------- 9 | # Advanced nn.Sequential 10 | # https://github.com/xinntao/BasicSR 11 | # -------------------------------------------- 12 | """ 13 | 14 | 15 | def sequential(*args): 16 | """Advanced nn.Sequential. 17 | 18 | Args: 19 | nn.Sequential, nn.Module 20 | 21 | Returns: 22 | nn.Sequential 23 | """ 24 | if len(args) == 1: 25 | if isinstance(args[0], OrderedDict): 26 | raise NotImplementedError("sequential does not support OrderedDict input.") 27 | return args[0] # No sequential is needed. 28 | modules = [] 29 | for module in args: 30 | if isinstance(module, nn.Sequential): 31 | for submodule in module.children(): 32 | modules.append(submodule) 33 | elif isinstance(module, nn.Module): 34 | modules.append(module) 35 | return nn.Sequential(*modules) 36 | 37 | 38 | """ 39 | # -------------------------------------------- 40 | # Useful blocks 41 | # https://github.com/xinntao/BasicSR 42 | # -------------------------------------------- 43 | # conv + normalization + relu (conv) 44 | # (PixelUnShuffle) 45 | # (ConditionalBatchNorm2d) 46 | # concat (ConcatBlock) 47 | # sum (ShortcutBlock) 48 | # resblock (ResBlock) 49 | # Channel Attention (CA) Layer (CALayer) 50 | # Residual Channel Attention Block (RCABlock) 51 | # Residual Channel Attention Group (RCAGroup) 52 | # Residual Dense Block (ResidualDenseBlock_5C) 53 | # Residual in Residual Dense Block (RRDB) 54 | # -------------------------------------------- 55 | """ 56 | 57 | 58 | # -------------------------------------------- 59 | # return nn.Sequential of (Conv + BN + ReLU) 60 | # -------------------------------------------- 61 | def conv( 62 | in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode="CBR", negative_slope=0.2 63 | ): 64 | L = [] 65 | for t in mode: 66 | if t == "C": 67 | L.append( 68 | nn.Conv2d( 69 | in_channels=in_channels, 70 | out_channels=out_channels, 71 | kernel_size=kernel_size, 72 | stride=stride, 73 | padding=padding, 74 | bias=bias, 75 | ) 76 | ) 77 | elif t == "T": 78 | L.append( 79 | nn.ConvTranspose2d( 80 | in_channels=in_channels, 81 | out_channels=out_channels, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | bias=bias, 86 | ) 87 | ) 88 | elif t == "B": 89 | L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True)) 90 | elif t == "I": 91 | L.append(nn.InstanceNorm2d(out_channels, affine=True)) 92 | elif t == "R": 93 | L.append(nn.ReLU(inplace=True)) 94 | elif t == "r": 95 | L.append(nn.ReLU(inplace=False)) 96 | elif t == "L": 97 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True)) 98 | elif t == "l": 99 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False)) 100 | elif t == "2": 101 | L.append(nn.PixelShuffle(upscale_factor=2)) 102 | elif t == "3": 103 | L.append(nn.PixelShuffle(upscale_factor=3)) 104 | elif t == "4": 105 | L.append(nn.PixelShuffle(upscale_factor=4)) 106 | elif t == "U": 107 | L.append(nn.Upsample(scale_factor=2, mode="nearest")) 108 | elif t == "u": 109 |
L.append(nn.Upsample(scale_factor=3, mode="nearest")) 110 | elif t == "v": 111 | L.append(nn.Upsample(scale_factor=4, mode="nearest")) 112 | elif t == "M": 113 | L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 114 | elif t == "A": 115 | L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 116 | else: 117 | raise NotImplementedError("Undefined type: {:s}".format(t)) 118 | return sequential(*L) 119 | 120 | 121 | # -------------------------------------------- 122 | # inverse of pixel_shuffle 123 | # -------------------------------------------- 124 | def pixel_unshuffle(input, upscale_factor): 125 | r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a 126 | tensor of shape :math:`(*, r^2C, H, W)`. 127 | 128 | Authors: 129 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan 130 | Kai Zhang, https://github.com/cszn/FFDNet 131 | 132 | Date: 133 | 01/Jan/2019 134 | """ 135 | batch_size, channels, in_height, in_width = input.size() 136 | 137 | out_height = in_height // upscale_factor 138 | out_width = in_width // upscale_factor 139 | 140 | input_view = input.contiguous().view(batch_size, channels, out_height, upscale_factor, out_width, upscale_factor) 141 | 142 | channels *= upscale_factor**2 143 | unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 144 | return unshuffle_out.view(batch_size, channels, out_height, out_width) 145 | 146 | 147 | class PixelUnShuffle(nn.Module): 148 | r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a 149 | tensor of shape :math:`(*, r^2C, H, W)`. 150 | 151 | Authors: 152 | Zhaoyi Yan, https://github.com/Zhaoyi-Yan 153 | Kai Zhang, https://github.com/cszn/FFDNet 154 | 155 | Date: 156 | 01/Jan/2019 157 | """ 158 | 159 | def __init__(self, upscale_factor): 160 | super(PixelUnShuffle, self).__init__() 161 | self.upscale_factor = upscale_factor 162 | 163 | def forward(self, input): 164 | return pixel_unshuffle(input, self.upscale_factor) 165 | 166 | def extra_repr(self): 167 | return "upscale_factor={}".format(self.upscale_factor) 168 | 169 | 170 | # -------------------------------------------- 171 | # conditional batch norm 172 | # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775 173 | # -------------------------------------------- 174 | class ConditionalBatchNorm2d(nn.Module): 175 | def __init__(self, num_features, num_classes): 176 | super().__init__() 177 | self.num_features = num_features 178 | self.bn = nn.BatchNorm2d(num_features, affine=False) 179 | self.embed = nn.Embedding(num_classes, num_features * 2) 180 | self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 181 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 182 | 183 | def forward(self, x, y): 184 | out = self.bn(x) 185 | gamma, beta = self.embed(y).chunk(2, 1) 186 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 187 | return out 188 | 189 | 190 | # -------------------------------------------- 191 | # Concat the output of a submodule to its input 192 | # -------------------------------------------- 193 | class ConcatBlock(nn.Module): 194 | def __init__(self, submodule): 195 | super(ConcatBlock, self).__init__() 196 | self.sub = submodule 197 | 198 | def forward(self, x): 199 | output = torch.cat((x, self.sub(x)), dim=1) 200 | return output 201 | 202 | def __repr__(self): 203 | return self.sub.__repr__() + "concat" 204 | 205 | 206 | # -------------------------------------------- 207 | # sum the
output of a submodule to its input 208 | # -------------------------------------------- 209 | class ShortcutBlock(nn.Module): 210 | def __init__(self, submodule): 211 | super(ShortcutBlock, self).__init__() 212 | 213 | self.sub = submodule 214 | 215 | def forward(self, x): 216 | output = x + self.sub(x) 217 | return output 218 | 219 | def __repr__(self): 220 | tmpstr = "Identity + \n|" 221 | modstr = self.sub.__repr__().replace("\n", "\n|") 222 | tmpstr = tmpstr + modstr 223 | return tmpstr 224 | 225 | 226 | # -------------------------------------------- 227 | # Res Block: x + conv(relu(conv(x))) 228 | # -------------------------------------------- 229 | class ResBlock(nn.Module): 230 | def __init__( 231 | self, 232 | in_channels=64, 233 | out_channels=64, 234 | kernel_size=3, 235 | stride=1, 236 | padding=1, 237 | bias=True, 238 | mode="CRC", 239 | negative_slope=0.2, 240 | ): 241 | super(ResBlock, self).__init__() 242 | 243 | assert in_channels == out_channels, "Only support in_channels==out_channels." 244 | if mode[0] in ["R", "L"]: 245 | mode = mode[0].lower() + mode[1:] 246 | 247 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 248 | 249 | def forward(self, x): 250 | # res = self.res(x) 251 | return x + self.res(x) 252 | 253 | 254 | # -------------------------------------------- 255 | # simplified information multi-distillation block (IMDB) 256 | # x + conv1(concat(split(relu(conv(x)))x3)) 257 | # -------------------------------------------- 258 | class IMDBlock(nn.Module): 259 | """ 260 | @inproceedings{hui2019lightweight, 261 | title={Lightweight Image Super-Resolution with Information Multi-distillation Network}, 262 | author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei}, 263 | booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)}, 264 | pages={2024--2032}, 265 | year={2019} 266 | } 267 | @inproceedings{zhang2019aim, 268 | title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results}, 269 | author={Kai Zhang and Shuhang Gu and Radu Timofte and others}, 270 | booktitle={IEEE International Conference on Computer Vision Workshops}, 271 | year={2019} 272 | } 273 | """ 274 | 275 | def __init__( 276 | self, 277 | in_channels=64, 278 | out_channels=64, 279 | kernel_size=3, 280 | stride=1, 281 | padding=1, 282 | bias=True, 283 | mode="CL", 284 | d_rate=0.25, 285 | negative_slope=0.05, 286 | ): 287 | super(IMDBlock, self).__init__() 288 | self.d_nc = int(in_channels * d_rate) 289 | self.r_nc = int(in_channels - self.d_nc) 290 | 291 | assert mode[0] == "C", "convolutional layer first" 292 | 293 | self.conv1 = conv(in_channels, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 294 | self.conv2 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 295 | self.conv3 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope) 296 | self.conv4 = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias, mode[0], negative_slope) 297 | self.conv1x1 = conv( 298 | self.d_nc * 4, 299 | out_channels, 300 | kernel_size=1, 301 | stride=1, 302 | padding=0, 303 | bias=bias, 304 | mode=mode[0], 305 | negative_slope=negative_slope, 306 | ) 307 | 308 | def forward(self, x): 309 | d1, r = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1) 310 | d2, r = torch.split(self.conv2(r), (self.d_nc, self.r_nc), dim=1) 311 | d3, r = torch.split(self.conv3(r), (self.d_nc, self.r_nc), dim=1) 312 | r = 
self.conv4(r) 313 | res = self.conv1x1(torch.cat((d1, d2, d3, r), dim=1)) 314 | return x + res 315 | 316 | 317 | # d1, r1 = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1) 318 | # d2, r2 = torch.split(self.conv2(r1), (self.d_nc, self.r_nc), dim=1) 319 | # d3, r3 = torch.split(self.conv3(r2), (self.d_nc, self.r_nc), dim=1) 320 | # d4 = self.conv4(r3) 321 | # -------------------------------------------- 322 | # Channel Attention (CA) Layer 323 | # -------------------------------------------- 324 | class CALayer(nn.Module): 325 | def __init__(self, channel=64, reduction=16): 326 | super(CALayer, self).__init__() 327 | 328 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 329 | self.conv_fc = nn.Sequential( 330 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 331 | nn.ReLU(inplace=True), 332 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 333 | nn.Sigmoid(), 334 | ) 335 | 336 | def forward(self, x): 337 | y = self.avg_pool(x) 338 | y = self.conv_fc(y) 339 | return x * y 340 | 341 | 342 | # -------------------------------------------- 343 | # Residual Channel Attention Block (RCAB) 344 | # -------------------------------------------- 345 | class RCABlock(nn.Module): 346 | def __init__( 347 | self, 348 | in_channels=64, 349 | out_channels=64, 350 | kernel_size=3, 351 | stride=1, 352 | padding=1, 353 | bias=True, 354 | mode="CRC", 355 | reduction=16, 356 | negative_slope=0.2, 357 | ): 358 | super(RCABlock, self).__init__() 359 | assert in_channels == out_channels, "Only support in_channels==out_channels." 360 | if mode[0] in ["R", "L"]: 361 | mode = mode[0].lower() + mode[1:] 362 | 363 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 364 | self.ca = CALayer(out_channels, reduction) 365 | 366 | def forward(self, x): 367 | res = self.res(x) 368 | res = self.ca(res) 369 | return res + x 370 | 371 | 372 | # -------------------------------------------- 373 | # Residual Channel Attention Group (RG) 374 | # -------------------------------------------- 375 | class RCAGroup(nn.Module): 376 | def __init__( 377 | self, 378 | in_channels=64, 379 | out_channels=64, 380 | kernel_size=3, 381 | stride=1, 382 | padding=1, 383 | bias=True, 384 | mode="CRC", 385 | reduction=16, 386 | nb=12, 387 | negative_slope=0.2, 388 | ): 389 | super(RCAGroup, self).__init__() 390 | assert in_channels == out_channels, "Only support in_channels==out_channels." 
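# a leading "R"/"L" is lowered to its non-inplace variant ("r"/"l") so the residual input x is not modified in place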
391 | if mode[0] in ["R", "L"]: 392 | mode = mode[0].lower() + mode[1:] 393 | 394 | RG = [ 395 | RCABlock(in_channels, out_channels, kernel_size, stride, padding, bias, mode, reduction, negative_slope) 396 | for _ in range(nb) 397 | ] 398 | RG.append(conv(out_channels, out_channels, mode="C")) 399 | self.rg = nn.Sequential(*RG) # self.rg = ShortcutBlock(nn.Sequential(*RG)) 400 | 401 | def forward(self, x): 402 | res = self.rg(x) 403 | return res + x 404 | 405 | 406 | # -------------------------------------------- 407 | # Residual Dense Block 408 | # style: 5 convs 409 | # -------------------------------------------- 410 | class ResidualDenseBlock_5C(nn.Module): 411 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode="CR", negative_slope=0.2): 412 | super(ResidualDenseBlock_5C, self).__init__() 413 | # gc: growth channel 414 | self.conv1 = conv(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 415 | self.conv2 = conv(nc + gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 416 | self.conv3 = conv(nc + 2 * gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 417 | self.conv4 = conv(nc + 3 * gc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 418 | self.conv5 = conv(nc + 4 * gc, nc, kernel_size, stride, padding, bias, mode[:-1], negative_slope) 419 | 420 | def forward(self, x): 421 | x1 = self.conv1(x) 422 | x2 = self.conv2(torch.cat((x, x1), 1)) 423 | x3 = self.conv3(torch.cat((x, x1, x2), 1)) 424 | x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) 425 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 426 | return x5.mul_(0.2) + x 427 | 428 | 429 | # -------------------------------------------- 430 | # Residual in Residual Dense Block 431 | # 3x5c 432 | # -------------------------------------------- 433 | class RRDB(nn.Module): 434 | def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode="CR", negative_slope=0.2): 435 | super(RRDB, self).__init__() 436 | 437 | self.RDB1 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 438 | self.RDB2 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 439 | self.RDB3 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope) 440 | 441 | def forward(self, x): 442 | out = self.RDB1(x) 443 | out = self.RDB2(out) 444 | out = self.RDB3(out) 445 | return out.mul_(0.2) + x 446 | 447 | 448 | """ 449 | # -------------------------------------------- 450 | # Upsampler 451 | # Kai Zhang, https://github.com/cszn/KAIR 452 | # -------------------------------------------- 453 | # upsample_pixelshuffle 454 | # upsample_upconv 455 | # upsample_convtranspose 456 | # -------------------------------------------- 457 | """ 458 | 459 | 460 | # -------------------------------------------- 461 | # conv + subp (+ relu) 462 | # -------------------------------------------- 463 | def upsample_pixelshuffle( 464 | in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode="2R", negative_slope=0.2 465 | ): 466 | assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." 
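# the conv expands channels by r^2, then the digit in mode appends PixelShuffle(r) to rearrange them into an r-times larger image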
467 | up1 = conv( 468 | in_channels, 469 | out_channels * (int(mode[0]) ** 2), 470 | kernel_size, 471 | stride, 472 | padding, 473 | bias, 474 | mode="C" + mode, 475 | negative_slope=negative_slope, 476 | ) 477 | return up1 478 | 479 | 480 | # -------------------------------------------- 481 | # nearest_upsample + conv (+ R) 482 | # -------------------------------------------- 483 | def upsample_upconv( 484 | in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode="2R", negative_slope=0.2 485 | ): 486 | assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR" 487 | if mode[0] == "2": 488 | uc = "UC" 489 | elif mode[0] == "3": 490 | uc = "uC" 491 | elif mode[0] == "4": 492 | uc = "vC" 493 | mode = mode.replace(mode[0], uc) 494 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope) 495 | return up1 496 | 497 | 498 | # -------------------------------------------- 499 | # convTranspose (+ relu) 500 | # -------------------------------------------- 501 | def upsample_convtranspose( 502 | in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode="2R", negative_slope=0.2 503 | ): 504 | assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." 505 | kernel_size = int(mode[0]) 506 | stride = int(mode[0]) 507 | mode = mode.replace(mode[0], "T") 508 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 509 | return up1 510 | 511 | 512 | """ 513 | # -------------------------------------------- 514 | # Downsampler 515 | # Kai Zhang, https://github.com/cszn/KAIR 516 | # -------------------------------------------- 517 | # downsample_strideconv 518 | # downsample_maxpool 519 | # downsample_avgpool 520 | # -------------------------------------------- 521 | """ 522 | 523 | 524 | # -------------------------------------------- 525 | # strideconv (+ relu) 526 | # -------------------------------------------- 527 | def downsample_strideconv( 528 | in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode="2R", negative_slope=0.2 529 | ): 530 | assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." 531 | kernel_size = int(mode[0]) 532 | stride = int(mode[0]) 533 | mode = mode.replace(mode[0], "C") 534 | down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 535 | return down1 536 | 537 | 538 | # -------------------------------------------- 539 | # maxpooling + conv (+ relu) 540 | # -------------------------------------------- 541 | def downsample_maxpool( 542 | in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode="2R", negative_slope=0.2 543 | ): 544 | assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." 
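# the leading digit sets the pooling size: e.g. mode "2R" becomes "MCR", i.e. MaxPool2d(2) followed by conv + ReLU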
545 | kernel_size_pool = int(mode[0]) 546 | stride_pool = int(mode[0]) 547 | mode = mode.replace(mode[0], "MC") 548 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 549 | pool_tail = conv( 550 | in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope 551 | ) 552 | return sequential(pool, pool_tail) 553 | 554 | 555 | # -------------------------------------------- 556 | # averagepooling + conv (+ relu) 557 | # -------------------------------------------- 558 | def downsample_avgpool( 559 | in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode="2R", negative_slope=0.2 560 | ): 561 | assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." 562 | kernel_size_pool = int(mode[0]) 563 | stride_pool = int(mode[0]) 564 | mode = mode.replace(mode[0], "AC") 565 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 566 | pool_tail = conv( 567 | in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope 568 | ) 569 | return sequential(pool, pool_tail) 570 | 571 | 572 | """ 573 | # -------------------------------------------- 574 | # NonLocalBlock2D: 575 | # embedded_gaussian 576 | # +W(softmax(thetaXphi)Xg) 577 | # -------------------------------------------- 578 | """ 579 | 580 | 581 | # -------------------------------------------- 582 | # non-local block with embedded_gaussian 583 | # https://github.com/AlexHex7/Non-local_pytorch 584 | # -------------------------------------------- 585 | class NonLocalBlock2D(nn.Module): 586 | def __init__( 587 | self, 588 | nc=64, 589 | kernel_size=1, 590 | stride=1, 591 | padding=0, 592 | bias=True, 593 | act_mode="B", 594 | downsample=False, 595 | downsample_mode="maxpool", 596 | negative_slope=0.2, 597 | ): 598 | 599 | super(NonLocalBlock2D, self).__init__() 600 | 601 | inter_nc = nc // 2 602 | self.inter_nc = inter_nc 603 | self.W = conv(inter_nc, nc, kernel_size, stride, padding, bias, mode="C" + act_mode) 604 | self.theta = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode="C") 605 | 606 | if downsample: 607 | if downsample_mode == "avgpool": 608 | downsample_block = downsample_avgpool 609 | elif downsample_mode == "maxpool": 610 | downsample_block = downsample_maxpool 611 | elif downsample_mode == "strideconv": 612 | downsample_block = downsample_strideconv 613 | else: 614 | raise NotImplementedError("downsample mode [{:s}] is not found".format(downsample_mode)) 615 | self.phi = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode="2") 616 | self.g = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode="2") 617 | else: 618 | self.phi = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode="C") 619 | self.g = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode="C") 620 | 621 | def forward(self, x): 622 | """ 623 | :param x: (b, c, h, w) 624 | :return: (b, c, h, w) 625 | """ 626 | 627 | batch_size = x.size(0) 628 | 629 | g_x = self.g(x).view(batch_size, self.inter_nc, -1) 630 | g_x = g_x.permute(0, 2, 1) 631 | 632 | theta_x = self.theta(x).view(batch_size, self.inter_nc, -1) 633 | theta_x = theta_x.permute(0, 2, 1) 634 | phi_x = self.phi(x).view(batch_size, self.inter_nc, -1) 635 | f = torch.matmul(theta_x, phi_x) 636 | f_div_C = F.softmax(f, dim=-1) 637 | 638 | y = torch.matmul(f_div_C, g_x) 639 | y = y.permute(0, 2, 1).contiguous()
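# y is now (b, c/2, h*w) after the permute; restore the spatial layout before the final embedding conv W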
640 | y = y.view(batch_size, self.inter_nc, *x.size()[2:]) 641 | W_y = self.W(y) 642 | z = W_y + x 643 | 644 | return z 645 | --------------------------------------------------------------------------------
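A slightly fuller usage sketch than the README's minimal example, based on the parameters documented in `vsdpir/__init__.py`. The `BlankClip` input and the chosen strength/stream values are illustrative assumptions for this sketch, not recommendations from the repository; in practice you would load a real source clip:

```python
import vapoursynth as vs
from vsdpir import dpir

core = vs.core

# dpir() only accepts RGBH/RGBS/GRAYH/GRAYS input; RGBH/GRAYH run inference in FP16
clip = core.std.BlankClip(format=vs.RGBS, width=1920, height=1080, length=120)

# Denoise with an explicit strength; auto_download fetches the model on first use
denoised = dpir(clip, task="denoise", strength=5.0, auto_download=True)

# Deblock through TensorRT with two CUDA streams; the engine built on the first
# run is cached under trt_cache_dir (the package's models directory by default)
deblocked = dpir(clip, task="deblock", trt=True, num_streams=2)

deblocked.set_output()
```

Per the docstrings, `num_streams` and `batch_size` trade GPU memory for throughput, and `tile`/`tile_pad` can be used instead when a frame is too large to process in one pass.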