├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── basicsr ├── VERSION ├── __init__.py ├── archs │ ├── __init__.py │ ├── arcface_arch.py │ ├── arch_util.py │ ├── codeformer_arch.py │ ├── rrdbnet_arch.py │ ├── vgg_arch.py │ └── vqgan_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── prefetch_dataloader.py │ └── transforms.py ├── losses │ ├── __init__.py │ ├── loss_util.py │ └── losses.py ├── metrics │ ├── __init__.py │ ├── metric_util.py │ └── psnr_ssim.py ├── models │ └── __init__.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── setup.py ├── train.py ├── utils │ ├── __init__.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── realesrgan_utils.py │ └── registry.py └── version.py ├── codeformer_wrapper.py ├── demo.jpg ├── facelib ├── detection │ ├── __init__.py │ ├── align_trans.py │ ├── matlab_cp2tform.py │ ├── retinaface │ │ ├── retinaface.py │ │ ├── retinaface_net.py │ │ └── retinaface_utils.py │ └── yolov5face │ │ ├── __init__.py │ │ ├── face_detector.py │ │ ├── models │ │ ├── __init__.py │ │ ├── common.py │ │ ├── experimental.py │ │ ├── yolo.py │ │ ├── yolov5l.yaml │ │ └── yolov5n.yaml │ │ └── utils │ │ ├── __init__.py │ │ ├── autoanchor.py │ │ ├── datasets.py │ │ ├── extract_ckpt.py │ │ ├── general.py │ │ └── torch_utils.py ├── parsing │ ├── __init__.py │ ├── bisenet.py │ ├── parsenet.py │ └── resnet.py └── utils │ ├── __init__.py │ ├── face_restoration_helper.py │ ├── face_utils.py │ └── misc.py ├── icon.png ├── output └── .gitkeep ├── recognition ├── arcface_onnx.py ├── face_align.py ├── main.py └── scrfd.py ├── refacer.py ├── refacer_bulk.py ├── requirements-COREML.txt ├── requirements-CPU.txt ├── requirements-GPU.txt └── weights ├── CodeFormer └── .gitkeep ├── README.md ├── facelib └── .gitkeep └── inswapper └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | out/* 163 | !out/.gitkeep 164 | media 165 | tests 166 | *.onnx 167 | 168 | aaa.md 169 | 170 | *_test.py 171 | img.jpg 172 | test_data 173 | testsrc.mp4 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | IMPORTANT NOTICE 2 | 3 | This project is licensed under a custom MIT License, **except** for the optional 4 | `codeformer` component, which is licensed under the Creative Commons 5 | Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0). 6 | If you require commercial use, you **must remove** `codeformer`. 7 | See the bottom of this file for full details and removal instructions. 8 | 9 | --- 10 | 11 | Custom MIT License 12 | 13 | Copyright (c) 2023 xaviviro 14 | Copyright (c) 2025 Felipe Daragon 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | - The above copyright notice and this permission notice shall be included in 24 | all copies or substantial portions of the Software. 25 | 26 | - You may only use this Software with content (such as images and videos) 27 | for which you have the necessary rights and permissions. Unauthorized use of 28 | third-party content is strictly prohibited. 29 | 30 | - This Software is intended for educational and research purposes only. Use 31 | of this Software for malicious purposes, including but not limited to identity 32 | theft, invasion of privacy, or defamation, is strictly prohibited. 33 | 34 | - By using this Software, you agree to comply with all applicable laws and 35 | to respect the rights and privacy of others. You agree to use the Software 36 | responsibly and ethically. 37 | 38 | - The Software may contain protective mechanisms intended to prevent its use 39 | with illegal or unauthorized media. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 45 | WHETHER IN AN ACTION OF CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, 46 | OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 47 | 48 | --- 49 | 50 | ## Additional License Notice: Optional `codeformer` Component 51 | 52 | This project distribution optionally includes an old version of a component 53 | named `codeformer` (https://github.com/felipedaragon/codeformer), 54 | developed by Shangchen Zhou. 55 | The `codeformer` component is **NOT** licensed under the MIT License. 56 | Instead, it is licensed under: 57 | 58 | **Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)** 59 | License details: https://creativecommons.org/licenses/by-nc-sa/4.0/ 60 | 61 | Key points about this license: 62 | - **Non-commercial use only**: You may not use `codeformer` for commercial purposes. 63 | - **Attribution required**: You must credit the original creators. 
64 | - **ShareAlike**: If you modify and share `codeformer`, you must do so under the same license. 65 | 66 | ### How to Use This Project as MIT Only 67 | 68 | If you wish to use this project solely under the MIT License (for example, 69 | for commercial purposes), you **must remove** the `codeformer` component. 70 | Please follow the instructions provided below: 71 | 72 | - Remove the subdirectories basicsr and facelib 73 | - Within the weights subdirectory, remove CodeFormer and facelib. 74 | - Remove codeformer_wrapper.py 75 | - Edit refacer.py and remove the import: codeformer_wrapper 76 | - Adjust the code so that it no longer calls the enhance functions from the removed wrapper (an illustrative sketch is provided at the end of this notice) 77 | - That's all! 78 | 79 | Failure to remove `codeformer` when required may violate the terms of its license. 80 | 81 | About User Outputs: 82 | The outputs generated by this software (such as refaced images or videos) are not 83 | subject to the CC BY-NC-SA license and may be used freely, including for commercial 84 | purposes, regardless of whether the optional codeformer component is used. 85 | 86 | Explanation: 87 | 88 | Codeformer is a model that processes images for face enhancement. 89 | It does not embed its own visible content into the output. 90 | 91 | This is very different from a case where licensed assets (such as textures, 92 | overlays, characters, backgrounds, or artwork) appear visibly in the output. 93 | 94 | Codeformer simply modifies the input image without adding original copyrighted 95 | material. 96 | 97 | Therefore, output images are not derivative works of Codeformer 98 | and are not bound by the NonCommercial restrictions of its license. 99 | 100 | Only the Codeformer model code and weights themselves are under CC BY-NC-SA 4.0 101 | — not the results produced through their use.
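
Illustrative sketch (for clarity only; not part of the license terms): one possible way to carry out the last two removal steps above. The `enhance` call name is a placeholder, not necessarily the wrapper's real function name; adapt it to the actual call sites in refacer.py.

```python
# Hypothetical sketch only; placeholder names, not the project's actual API.
# After deleting codeformer_wrapper.py, refacer.py can guard the import and
# skip the enhancement step entirely when the component is absent.
try:
    import codeformer_wrapper  # present only in the CC BY-NC-SA distribution
except ImportError:
    codeformer_wrapper = None  # MIT-only distribution: component removed

def maybe_enhance(frame):
    """Return the CodeFormer-enhanced frame if the wrapper is available, else the frame unchanged."""
    if codeformer_wrapper is None:
        return frame
    return codeformer_wrapper.enhance(frame)  # placeholder call name
```
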
102 | -------------------------------------------------------------------------------- /basicsr/VERSION: -------------------------------------------------------------------------------- 1 | 1.3.2 2 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .train import * 10 | from .utils import * 11 | from .version import __gitsha__, __version__ 12 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/archs/arcface_arch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from basicsr.utils.registry import ARCH_REGISTRY 3 | 4 | 5 | def conv3x3(inplanes, outplanes, stride=1): 6 | """A simple wrapper for 3x3 convolution with padding. 7 | 8 | Args: 9 | inplanes (int): Channel number of inputs. 10 | outplanes (int): Channel number of outputs. 11 | stride (int): Stride in convolution. Default: 1. 12 | """ 13 | return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | """Basic residual block used in the ResNetArcFace architecture. 18 | 19 | Args: 20 | inplanes (int): Channel number of inputs. 21 | planes (int): Channel number of outputs. 22 | stride (int): Stride in convolution. Default: 1. 23 | downsample (nn.Module): The downsample module. Default: None. 
24 | """ 25 | expansion = 1 # output channel expansion ratio 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class IRBlock(nn.Module): 57 | """Improved residual block (IR Block) used in the ResNetArcFace architecture. 58 | 59 | Args: 60 | inplanes (int): Channel number of inputs. 61 | planes (int): Channel number of outputs. 62 | stride (int): Stride in convolution. Default: 1. 63 | downsample (nn.Module): The downsample module. Default: None. 64 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 65 | """ 66 | expansion = 1 # output channel expansion ratio 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 69 | super(IRBlock, self).__init__() 70 | self.bn0 = nn.BatchNorm2d(inplanes) 71 | self.conv1 = conv3x3(inplanes, inplanes) 72 | self.bn1 = nn.BatchNorm2d(inplanes) 73 | self.prelu = nn.PReLU() 74 | self.conv2 = conv3x3(inplanes, planes, stride) 75 | self.bn2 = nn.BatchNorm2d(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | self.use_se = use_se 79 | if self.use_se: 80 | self.se = SEBlock(planes) 81 | 82 | def forward(self, x): 83 | residual = x 84 | out = self.bn0(x) 85 | out = self.conv1(out) 86 | out = self.bn1(out) 87 | out = self.prelu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | if self.use_se: 92 | out = self.se(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.prelu(out) 99 | 100 | return out 101 | 102 | 103 | class Bottleneck(nn.Module): 104 | """Bottleneck block used in the ResNetArcFace architecture. 105 | 106 | Args: 107 | inplanes (int): Channel number of inputs. 108 | planes (int): Channel number of outputs. 109 | stride (int): Stride in convolution. Default: 1. 110 | downsample (nn.Module): The downsample module. Default: None. 
111 | """ 112 | expansion = 4 # output channel expansion ratio 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None): 115 | super(Bottleneck, self).__init__() 116 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(planes) 118 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 119 | self.bn2 = nn.BatchNorm2d(planes) 120 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 121 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.downsample = downsample 124 | self.stride = stride 125 | 126 | def forward(self, x): 127 | residual = x 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv2(out) 134 | out = self.bn2(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv3(out) 138 | out = self.bn3(out) 139 | 140 | if self.downsample is not None: 141 | residual = self.downsample(x) 142 | 143 | out += residual 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class SEBlock(nn.Module): 150 | """The squeeze-and-excitation block (SEBlock) used in the IRBlock. 151 | 152 | Args: 153 | channel (int): Channel number of inputs. 154 | reduction (int): Channel reduction ration. Default: 16. 155 | """ 156 | 157 | def __init__(self, channel, reduction=16): 158 | super(SEBlock, self).__init__() 159 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information 160 | self.fc = nn.Sequential( 161 | nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), 162 | nn.Sigmoid()) 163 | 164 | def forward(self, x): 165 | b, c, _, _ = x.size() 166 | y = self.avg_pool(x).view(b, c) 167 | y = self.fc(y).view(b, c, 1, 1) 168 | return x * y 169 | 170 | 171 | @ARCH_REGISTRY.register() 172 | class ResNetArcFace(nn.Module): 173 | """ArcFace with ResNet architectures. 174 | 175 | Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. 176 | 177 | Args: 178 | block (str): Block used in the ArcFace architecture. 179 | layers (tuple(int)): Block numbers in each layer. 180 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
181 | """ 182 | 183 | def __init__(self, block, layers, use_se=True): 184 | if block == 'IRBlock': 185 | block = IRBlock 186 | self.inplanes = 64 187 | self.use_se = use_se 188 | super(ResNetArcFace, self).__init__() 189 | 190 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) 191 | self.bn1 = nn.BatchNorm2d(64) 192 | self.prelu = nn.PReLU() 193 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 194 | self.layer1 = self._make_layer(block, 64, layers[0]) 195 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 196 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 197 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 198 | self.bn4 = nn.BatchNorm2d(512) 199 | self.dropout = nn.Dropout() 200 | self.fc5 = nn.Linear(512 * 8 * 8, 512) 201 | self.bn5 = nn.BatchNorm1d(512) 202 | 203 | # initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.xavier_normal_(m.weight) 207 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 208 | nn.init.constant_(m.weight, 1) 209 | nn.init.constant_(m.bias, 0) 210 | elif isinstance(m, nn.Linear): 211 | nn.init.xavier_normal_(m.weight) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | def _make_layer(self, block, planes, num_blocks, stride=1): 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 219 | nn.BatchNorm2d(planes * block.expansion), 220 | ) 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 223 | self.inplanes = planes 224 | for _ in range(1, num_blocks): 225 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.prelu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | x = self.bn4(x) 240 | x = self.dropout(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc5(x) 243 | x = self.bn5(x) 244 | 245 | return x -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 
17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Emperically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Emperically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 
85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 
44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether use normalization 59 | in the input feature and the type of vgg network. Note that the pretrained 60 | path must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 
146 | 147 | Returns: 148 | Tensor: Forward results. 149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | output = {} 155 | 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must constain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | 84 | prefetch_mode = dataset_opt.get('prefetch_mode') 85 | if prefetch_mode == 'cpu': # CPUPrefetcher 86 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 87 | logger = get_root_logger() 88 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') 89 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 90 | else: 91 | # prefetch_mode=None: Normal dataloader 92 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 93 | return torch.utils.data.DataLoader(**dataloader_args) 94 | 95 | 96 | def worker_init_fn(worker_id, num_workers, rank, seed): 97 | # Set the worker seed to num_workers * rank + worker_id + seed 98 | worker_seed = num_workers * rank + worker_id + seed 99 | np.random.seed(worker_seed) 100 | random.seed(worker_seed) 101 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA (or MPS/CPU) prefetcher. 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 
92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | 99 | # Cross-platform device detection 100 | if opt['num_gpu'] != 0 and torch.cuda.is_available(): 101 | self.device = torch.device('cuda') 102 | self.stream = torch.cuda.Stream() 103 | elif torch.backends.mps.is_available(): 104 | self.device = torch.device('mps') 105 | self.stream = None 106 | else: 107 | self.device = torch.device('cpu') 108 | self.stream = None 109 | 110 | self.preload() 111 | 112 | def preload(self): 113 | try: 114 | self.batch = next(self.loader) # self.batch is a dict 115 | except StopIteration: 116 | self.batch = None 117 | return None 118 | 119 | if self.stream is not None: 120 | with torch.cuda.stream(self.stream): 121 | for k, v in self.batch.items(): 122 | if torch.is_tensor(v): 123 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 124 | else: 125 | for k, v in self.batch.items(): 126 | if torch.is_tensor(v): 127 | self.batch[k] = self.batch[k].to(device=self.device) 128 | 129 | def next(self): 130 | if self.stream is not None: 131 | torch.cuda.current_stream().wait_stream(self.stream) 132 | batch = self.batch 133 | self.preload() 134 | return batch 135 | 136 | def reset(self): 137 | self.loader = iter(self.ori_loader) 138 | self.preload() -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | 4 | 5 | def mod_crop(img, scale): 6 | """Mod crop images, used during testing. 7 | 8 | Args: 9 | img (ndarray): Input image. 10 | scale (int): Scale factor. 11 | 12 | Returns: 13 | ndarray: Result image. 14 | """ 15 | img = img.copy() 16 | if img.ndim in (2, 3): 17 | h, w = img.shape[0], img.shape[1] 18 | h_remainder, w_remainder = h % scale, w % scale 19 | img = img[:h - h_remainder, :w - w_remainder, ...] 20 | else: 21 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 22 | return img 23 | 24 | 25 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 26 | """Paired random crop. 27 | 28 | It crops lists of lq and gt images with corresponding locations. 29 | 30 | Args: 31 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 32 | should have the same shape. If the input is an ndarray, it will 33 | be transformed to a list containing itself. 34 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 35 | should have the same shape. If the input is an ndarray, it will 36 | be transformed to a list containing itself. 37 | gt_patch_size (int): GT patch size. 38 | scale (int): Scale factor. 39 | gt_path (str): Path to ground-truth. 40 | 41 | Returns: 42 | list[ndarray] | ndarray: GT images and LQ images. If returned results 43 | only have one element, just return ndarray. 44 | """ 45 | 46 | if not isinstance(img_gts, list): 47 | img_gts = [img_gts] 48 | if not isinstance(img_lqs, list): 49 | img_lqs = [img_lqs] 50 | 51 | h_lq, w_lq, _ = img_lqs[0].shape 52 | h_gt, w_gt, _ = img_gts[0].shape 53 | lq_patch_size = gt_patch_size // scale 54 | 55 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 56 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 57 | f'multiplication of LQ ({h_lq}, {w_lq}).') 58 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 60 | f'({lq_patch_size}, {lq_patch_size}). 
' 61 | f'Please remove {gt_path}.') 62 | 63 | # randomly choose top and left coordinates for lq patch 64 | top = random.randint(0, h_lq - lq_patch_size) 65 | left = random.randint(0, w_lq - lq_patch_size) 66 | 67 | # crop lq patch 68 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 69 | 70 | # crop corresponding gt patch 71 | top_gt, left_gt = int(top * scale), int(left * scale) 72 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 73 | if len(img_gts) == 1: 74 | img_gts = img_gts[0] 75 | if len(img_lqs) == 1: 76 | img_lqs = img_lqs[0] 77 | return img_gts, img_lqs 78 | 79 | 80 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 81 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 82 | 83 | We use vertical flip and transpose for rotation implementation. 84 | All the images in the list use the same augmentation. 85 | 86 | Args: 87 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 88 | is an ndarray, it will be transformed to a list. 89 | hflip (bool): Horizontal flip. Default: True. 90 | rotation (bool): Ratotation. Default: True. 91 | flows (list[ndarray]: Flows to be augmented. If the input is an 92 | ndarray, it will be transformed to a list. 93 | Dimension is (h, w, 2). Default: None. 94 | return_status (bool): Return the status of flip and rotation. 95 | Default: False. 96 | 97 | Returns: 98 | list[ndarray] | ndarray: Augmented images and flows. If returned 99 | results only have one element, just return ndarray. 100 | 101 | """ 102 | hflip = hflip and random.random() < 0.5 103 | vflip = rotation and random.random() < 0.5 104 | rot90 = rotation and random.random() < 0.5 105 | 106 | def _augment(img): 107 | if hflip: # horizontal 108 | cv2.flip(img, 1, img) 109 | if vflip: # vertical 110 | cv2.flip(img, 0, img) 111 | if rot90: 112 | img = img.transpose(1, 0, 2) 113 | return img 114 | 115 | def _augment_flow(flow): 116 | if hflip: # horizontal 117 | cv2.flip(flow, 1, flow) 118 | flow[:, :, 0] *= -1 119 | if vflip: # vertical 120 | cv2.flip(flow, 0, flow) 121 | flow[:, :, 1] *= -1 122 | if rot90: 123 | flow = flow.transpose(1, 0, 2) 124 | flow = flow[:, :, [1, 0]] 125 | return flow 126 | 127 | if not isinstance(imgs, list): 128 | imgs = [imgs] 129 | imgs = [_augment(img) for img in imgs] 130 | if len(imgs) == 1: 131 | imgs = imgs[0] 132 | 133 | if flows is not None: 134 | if not isinstance(flows, list): 135 | flows = [flows] 136 | flows = [_augment_flow(flow) for flow in flows] 137 | if len(flows) == 1: 138 | flows = flows[0] 139 | return imgs, flows 140 | else: 141 | if return_status: 142 | return imgs, (hflip, vflip, rot90) 143 | else: 144 | return imgs 145 | 146 | 147 | def img_rotate(img, angle, center=None, scale=1.0): 148 | """Rotate image. 149 | 150 | Args: 151 | img (ndarray): Image to be rotated. 152 | angle (float): Rotation angle in degrees. Positive values mean 153 | counter-clockwise rotation. 154 | center (tuple[int]): Rotation center. If the center is None, 155 | initialize it as the center of the image. Default: None. 156 | scale (float): Isotropic scale factor. Default: 1.0. 
157 | """ 158 | (h, w) = img.shape[:2] 159 | 160 | if center is None: 161 | center = (w // 2, h // 2) 162 | 163 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 164 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 165 | return rotated_img 166 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils import get_root_logger 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must constain: 19 | type (str): Model type. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. 
The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must constain: 14 | type (str): Model type. 15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 
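# --- Illustrative usage sketch (editorial addition, not part of BasicSR) ---
# A minimal, hedged example of the metric registry shown in
# basicsr/metrics/__init__.py above: `calculate_metric` pops the 'type' key,
# looks the metric up in METRIC_REGISTRY, and forwards the remaining keys as
# keyword arguments. The arrays below are made-up test data.
#
#   import numpy as np
#   from basicsr.metrics import calculate_metric
#
#   img1 = np.zeros((64, 64, 3), dtype=np.float64)       # range [0, 255]
#   img2 = np.full((64, 64, 3), 10.0, dtype=np.float64)
#   psnr = calculate_metric(data={'img1': img1, 'img2': img2},
#                           opt={'type': 'calculate_psnr', 'crop_border': 0})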
46 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 5 | from basicsr.utils.registry import METRIC_REGISTRY 6 | 7 | 8 | @METRIC_REGISTRY.register() 9 | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img1 (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result. 25 | """ 26 | 27 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 28 | if input_order not in ['HWC', 'CHW']: 29 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 30 | img1 = reorder_image(img1, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img1 - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | return 20. * np.log10(255. / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img1, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by func:`calculate_ssim`. 53 | 54 | Args: 55 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | 58 | Returns: 59 | float: ssim result. 60 | """ 61 | 62 | C1 = (0.01 * 255)**2 63 | C2 = (0.03 * 255)**2 64 | 65 | img1 = img1.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5) 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1**2 73 | mu2_sq = mu2**2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 80 | return ssim_map.mean() 81 | 82 | 83 | @METRIC_REGISTRY.register() 84 | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 85 | """Calculate SSIM (structural similarity). 86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 
92 | 93 | For three-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img1 (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 110 | if input_order not in ['HWC', 'CHW']: 111 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 112 | img1 = reorder_image(img1, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img1 = img1.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 120 | 121 | if test_y_channel: 122 | img1 = to_y_channel(img1) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img1.shape[2]): 127 | ssims.append(_ssim(img1[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must constain: 24 | model_type (str): Model type. 
25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/src/deform_conv_ext.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define WITH_CUDA // always use cuda 11 | #ifdef WITH_CUDA 12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 13 | at::Tensor offset, at::Tensor output, 14 | at::Tensor columns, at::Tensor ones, int kW, 15 | int kH, int dW, int dH, int padW, int padH, 16 | int dilationW, int dilationH, int group, 17 | int deformable_group, int im2col_step); 18 | 19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 20 | at::Tensor gradOutput, at::Tensor gradInput, 21 | at::Tensor gradOffset, at::Tensor weight, 22 | at::Tensor columns, int kW, int kH, int dW, 23 | int dH, int padW, int padH, int dilationW, 24 | int dilationH, int group, 25 | int deformable_group, int im2col_step); 26 | 27 | int deform_conv_backward_parameters_cuda( 28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 29 | at::Tensor gradWeight, // at::Tensor gradBias, 30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 31 | int padW, int padH, int dilationW, int dilationH, int group, 32 | int deformable_group, float scale, int im2col_step); 33 | 34 | void modulated_deform_conv_cuda_forward( 35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 38 | const int pad_h, const int pad_w, const int dilation_h, 39 | const int dilation_w, const int group, const int deformable_group, 40 | const bool with_bias); 41 | 42 | void modulated_deform_conv_cuda_backward( 43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 44 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 48 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 49 | const bool with_bias); 50 | #endif 51 | 52 | int deform_conv_forward(at::Tensor input, at::Tensor 
weight, 53 | at::Tensor offset, at::Tensor output, 54 | at::Tensor columns, at::Tensor ones, int kW, 55 | int kH, int dW, int dH, int padW, int padH, 56 | int dilationW, int dilationH, int group, 57 | int deformable_group, int im2col_step) { 58 | if (input.device().is_cuda()) { 59 | #ifdef WITH_CUDA 60 | return deform_conv_forward_cuda(input, weight, offset, output, columns, 61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, 62 | deformable_group, im2col_step); 63 | #else 64 | AT_ERROR("deform conv is not compiled with GPU support"); 65 | #endif 66 | } 67 | AT_ERROR("deform conv is not implemented on CPU"); 68 | } 69 | 70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset, 71 | at::Tensor gradOutput, at::Tensor gradInput, 72 | at::Tensor gradOffset, at::Tensor weight, 73 | at::Tensor columns, int kW, int kH, int dW, 74 | int dH, int padW, int padH, int dilationW, 75 | int dilationH, int group, 76 | int deformable_group, int im2col_step) { 77 | if (input.device().is_cuda()) { 78 | #ifdef WITH_CUDA 79 | return deform_conv_backward_input_cuda(input, offset, gradOutput, 80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, 81 | dilationW, dilationH, group, deformable_group, im2col_step); 82 | #else 83 | AT_ERROR("deform conv is not compiled with GPU support"); 84 | #endif 85 | } 86 | AT_ERROR("deform conv is not implemented on CPU"); 87 | } 88 | 89 | int deform_conv_backward_parameters( 90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 91 | at::Tensor gradWeight, // at::Tensor gradBias, 92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 93 | int padW, int padH, int dilationW, int dilationH, int group, 94 | int deformable_group, float scale, int im2col_step) { 95 | if (input.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput, 98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, 99 | dilationH, group, deformable_group, scale, im2col_step); 100 | #else 101 | AT_ERROR("deform conv is not compiled with GPU support"); 102 | #endif 103 | } 104 | AT_ERROR("deform conv is not implemented on CPU"); 105 | } 106 | 107 | void modulated_deform_conv_forward( 108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 111 | const int pad_h, const int pad_w, const int dilation_h, 112 | const int dilation_w, const int group, const int deformable_group, 113 | const bool with_bias) { 114 | if (input.device().is_cuda()) { 115 | #ifdef WITH_CUDA 116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones, 117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h, 118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 119 | deformable_group, with_bias); 120 | #else 121 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 122 | #endif 123 | } 124 | AT_ERROR("modulated deform conv is not implemented on CPU"); 125 | } 126 | 127 | void modulated_deform_conv_backward( 128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 129 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 131 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 133 | int 
pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 134 | const bool with_bias) { 135 | if (input.device().is_cuda()) { 136 | #ifdef WITH_CUDA 137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones, 138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, 139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, 140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, 141 | with_bias); 142 | #else 143 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 144 | #endif 145 | } 146 | AT_ERROR("modulated deform conv is not implemented on CPU"); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("deform_conv_forward", &deform_conv_forward, 152 | "deform forward"); 153 | m.def("deform_conv_backward_input", &deform_conv_backward_input, 154 | "deform_conv_backward_input"); 155 | m.def("deform_conv_backward_parameters", 156 | &deform_conv_backward_parameters, 157 | "deform_conv_backward_parameters"); 158 | m.def("modulated_deform_conv_forward", 159 | &modulated_deform_conv_forward, 160 | "modulated deform conv forward"); 161 | m.def("modulated_deform_conv_backward", 162 | &modulated_deform_conv_backward, 163 | "modulated deform conv backward"); 164 | } 165 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | try: 8 | from . 
import fused_act_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | fused_act_ext = load( 16 | 'fused', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 19 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class FusedLeakyReLUFunctionBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, out, negative_slope, scale): 28 | ctx.save_for_backward(out) 29 | ctx.negative_slope = negative_slope 30 | ctx.scale = scale 31 | 32 | empty = grad_output.new_empty(0) 33 | 34 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 35 | 36 | dim = [0] 37 | 38 | if grad_input.ndim > 2: 39 | dim += list(range(2, grad_input.ndim)) 40 | 41 | grad_bias = grad_input.sum(dim).detach() 42 | 43 | return grad_input, grad_bias 44 | 45 | @staticmethod 46 | def backward(ctx, gradgrad_input, gradgrad_bias): 47 | out, = ctx.saved_tensors 48 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 49 | ctx.scale) 50 | 51 | return gradgrad_out, None, None, None 52 | 53 | 54 | class FusedLeakyReLUFunction(Function): 55 | 56 | @staticmethod 57 | def forward(ctx, input, bias, negative_slope, scale): 58 | empty = input.new_empty(0) 59 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 60 | ctx.save_for_backward(out) 61 | ctx.negative_slope = negative_slope 62 | ctx.scale = scale 63 | 64 | return out 65 | 66 | @staticmethod 67 | def backward(ctx, grad_output): 68 | out, = ctx.saved_tensors 69 | 70 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | 77 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 78 | super().__init__() 79 | 80 | self.bias = nn.Parameter(torch.zeros(channel)) 81 | self.negative_slope = negative_slope 82 | self.scale = scale 83 | 84 | def forward(self, input): 85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 86 | 87 | 88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 90 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 
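// Illustrative (assumed) call from the Python wrapper in fused_act.py, once the
// extension has been built (e.g. via the BASICSR_JIT load shown above, or a
// setup.py build with the --cuda_ext flag):
//
//   out = fused_act_ext.fused_bias_act(x, bias, x.new_empty(0), 3, 0, 0.2, 2 ** 0.5)
//
// Here act=3 selects the leaky-ReLU branch of the CUDA kernel, grad=0 the forward
// pass, alpha=0.2 the negative slope, and scale=sqrt(2) the output multiplier.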
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | 7 | try: 8 | from . 
import upfirdn2d_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | upfirdn2d_ext = load( 16 | 'upfirdn2d', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 19 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class UpFirDn2dBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 28 | 29 | up_x, up_y = up 30 | down_x, down_y = down 31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 32 | 33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 34 | 35 | grad_input = upfirdn2d_ext.upfirdn2d( 36 | grad_output, 37 | grad_kernel, 38 | down_x, 39 | down_y, 40 | up_x, 41 | up_y, 42 | g_pad_x0, 43 | g_pad_x1, 44 | g_pad_y0, 45 | g_pad_y1, 46 | ) 47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 48 | 49 | ctx.save_for_backward(kernel) 50 | 51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 52 | 53 | ctx.up_x = up_x 54 | ctx.up_y = up_y 55 | ctx.down_x = down_x 56 | ctx.down_y = down_y 57 | ctx.pad_x0 = pad_x0 58 | ctx.pad_x1 = pad_x1 59 | ctx.pad_y0 = pad_y0 60 | ctx.pad_y1 = pad_y1 61 | ctx.in_size = in_size 62 | ctx.out_size = out_size 63 | 64 | return grad_input 65 | 66 | @staticmethod 67 | def backward(ctx, gradgrad_input): 68 | kernel, = ctx.saved_tensors 69 | 70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 71 | 72 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 73 | gradgrad_input, 74 | kernel, 75 | ctx.up_x, 76 | ctx.up_y, 77 | ctx.down_x, 78 | ctx.down_y, 79 | ctx.pad_x0, 80 | ctx.pad_x1, 81 | ctx.pad_y0, 82 | ctx.pad_y1, 83 | ) 84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 85 | # ctx.out_size[1], ctx.in_size[3]) 86 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 87 | 88 | return gradgrad_out, None, None, None, None, None, None, None, None 89 | 90 | 91 | class UpFirDn2d(Function): 92 | 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = 
UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | if input.device.type == 'cpu': 149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 150 | else: 151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 152 | 153 | return out 154 | 155 | 156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 157 | _, channel, in_h, in_w = input.shape 158 | input = input.reshape(-1, in_h, in_w, 1) 159 | 160 | _, in_h, in_w, minor = input.shape 161 | kernel_h, kernel_w = kernel.shape 162 | 163 | out = input.view(-1, in_h, 1, in_w, 1, minor) 164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 166 | 167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 168 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 172 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 173 | out = F.conv2d(out, w) 174 | out = out.reshape( 175 | -1, 176 | minor, 177 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 178 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 179 | ) 180 | out = out.permute(0, 2, 3, 1) 181 | out = out[:, ::down_y, ::down_x, :] 182 | 183 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 184 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 185 | 186 | return out.view(-1, channel, out_h, out_w) 187 | -------------------------------------------------------------------------------- /basicsr/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 11 | 12 | version_file = './basicsr/version.py' 13 | 14 | 15 | def readme(): 16 | with open('README.md', encoding='utf-8') as f: 17 | content = f.read() 18 | return content 19 | 20 | 21 | def get_git_hash(): 22 | 23 | def _minimal_ext_cmd(cmd): 24 | # construct minimal environment 25 | env = {} 26 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 27 | v = os.environ.get(k) 28 | if v is not None: 29 | env[k] = v 30 | # LANGUAGE is used on win32 31 | env['LANGUAGE'] = 'C' 32 | env['LANG'] = 'C' 33 | env['LC_ALL'] = 'C' 34 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 35 | return out 36 | 37 | try: 38 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 39 | sha = out.strip().decode('ascii') 40 | except OSError: 41 | sha = 'unknown' 42 | 43 | return sha 44 | 45 | 46 | def get_hash(): 47 | if os.path.exists('.git'): 48 | sha = get_git_hash()[:7] 49 | elif os.path.exists(version_file): 50 | try: 51 | from version import __version__ 52 | sha = __version__.split('+')[-1] 53 | except ImportError: 54 | raise ImportError('Unable to get git version') 55 | else: 56 | sha = 'unknown' 57 | 58 | return sha 59 | 60 | 61 | def 
write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open('./basicsr/VERSION', 'r') as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, 'w') as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, 'r') as f: 80 | exec(compile(f.read(), version_file, 'exec')) 81 | return locals()['__version__'] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args = {'cxx': []} 89 | 90 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | define_macros += [('WITH_CUDA', None)] 92 | extension = CUDAExtension 93 | extra_compile_args['nvcc'] = [ 94 | '-D__CUDA_NO_HALF_OPERATORS__', 95 | '-D__CUDA_NO_HALF_CONVERSIONS__', 96 | '-D__CUDA_NO_HALF2_OPERATORS__', 97 | ] 98 | sources += sources_cuda 99 | else: 100 | print(f'Compiling {name} without CUDA') 101 | extension = CppExtension 102 | 103 | return extension( 104 | name=f'{module}.{name}', 105 | sources=[os.path.join(*module.split('.'), p) for p in sources], 106 | define_macros=define_macros, 107 | extra_compile_args=extra_compile_args) 108 | 109 | 110 | def get_requirements(filename='requirements.txt'): 111 | with open(os.path.join('.', filename), 'r') as f: 112 | requires = [line.replace('\n', '') for line in f.readlines()] 113 | return requires 114 | 115 | 116 | if __name__ == '__main__': 117 | if '--cuda_ext' in sys.argv: 118 | ext_modules = [ 119 | make_cuda_ext( 120 | name='deform_conv_ext', 121 | module='ops.dcn', 122 | sources=['src/deform_conv_ext.cpp'], 123 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 124 | make_cuda_ext( 125 | name='fused_act_ext', 126 | module='ops.fused_act', 127 | sources=['src/fused_bias_act.cpp'], 128 | sources_cuda=['src/fused_bias_act_kernel.cu']), 129 | make_cuda_ext( 130 | name='upfirdn2d_ext', 131 | module='ops.upfirdn2d', 132 | sources=['src/upfirdn2d.cpp'], 133 | sources_cuda=['src/upfirdn2d_kernel.cu']), 134 | ] 135 | sys.argv.remove('--cuda_ext') 136 | else: 137 | ext_modules = [] 138 | 139 | write_version_py() 140 | setup( 141 | name='basicsr', 142 | version=get_version(), 143 | description='Open Source Image and Video Super-Resolution Toolbox', 144 | long_description=readme(), 145 | long_description_content_type='text/markdown', 146 | author='Xintao Wang', 147 | author_email='xintao.wang@outlook.com', 148 | keywords='computer vision, restoration, super resolution', 149 | url='https://github.com/xinntao/BasicSR', 150 | include_package_data=True, 151 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 152 | classifiers=[ 153 | 'Development Status :: 4 - Beta', 154 | 'License :: OSI Approved :: Apache Software License', 155 | 'Operating System :: OS Independent', 156 | 'Programming Language :: Python :: 3', 157 | 'Programming Language :: Python :: 3.7', 158 | 'Programming Language :: Python :: 3.8', 159 | ], 160 | license='Apache License 2.0', 161 | setup_requires=['cython', 'numpy'], 162 | install_requires=get_requirements(), 163 | ext_modules=ext_modules, 164 | cmdclass={'build_ext': BuildExtension}, 
165 | zip_safe=False) 166 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'init_tb_logger', 18 | 'init_wandb_logger', 19 | 'get_root_logger', 20 | 'get_env_info', 21 | # misc.py 22 | 'set_random_seed', 23 | 'get_time_str', 24 | 'mkdir_and_rename', 25 | 'make_exp_dirs', 26 | 'scandir', 27 | 'check_resume', 28 | 'sizeof_fmt' 29 | ] 30 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | Ref: 14 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 15 | Args: 16 | file_id (str): File id. 17 | save_path (str): Save path. 
18 | """ 19 | 20 | session = requests.Session() 21 | URL = 'https://docs.google.com/uc?export=download' 22 | params = {'id': file_id} 23 | 24 | response = session.get(URL, params=params, stream=True) 25 | token = get_confirm_token(response) 26 | if token: 27 | params['confirm'] = token 28 | response = session.get(URL, params=params, stream=True) 29 | 30 | # get file size 31 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | print(response_file_size) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 72 | Args: 73 | url (str): URL to be downloaded. 74 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 75 | Default: None. 76 | progress (bool): Whether to show the download progress. Default: True. 77 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 78 | Returns: 79 | str: The path to the downloaded file. 80 | """ 81 | if model_dir is None: # use the pytorch hub_dir 82 | hub_dir = get_dir() 83 | model_dir = os.path.join(hub_dir, 'checkpoints') 84 | 85 | os.makedirs(model_dir, exist_ok=True) 86 | 87 | parts = urlparse(url) 88 | filename = os.path.basename(parts.path) 89 | if file_name is not None: 90 | filename = file_name 91 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 92 | if not os.path.exists(cached_file): 93 | print(f'Downloading: "{url}" to {cached_file}\n') 94 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 95 | return cached_file -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing differnet lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 
49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 
141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | return cv2.imwrite(file_path, img, params) 152 | 153 | 154 | def crop_border(imgs, crop_border): 155 | """Crop borders of images. 156 | 157 | Args: 158 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 159 | crop_border (int): Crop border for each end of height and weight. 160 | 161 | Returns: 162 | list[ndarray]: Cropped images. 163 | """ 164 | if crop_border == 0: 165 | return imgs 166 | else: 167 | if isinstance(imgs, list): 168 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 169 | else: 170 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 171 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 49 | img_path_list (str): Image path list. 50 | keys (str): Used for lmdb keys. 51 | batch (int): After processing batch images, lmdb commits. 52 | Default: 5000. 53 | compress_level (int): Compress level when encoding images. Default: 1. 54 | multiprocessing_read (bool): Whether use multiprocessing to read all 55 | the images to memory. Default: False. 56 | n_thread (int): For multiprocessing. 57 | map_size (int | None): Map size for lmdb env. If None, use the 58 | estimated size from images. 
Default: None 59 | """ 60 | 61 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 62 | f'but got {len(img_path_list)} and {len(keys)}') 63 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 64 | print(f'Totoal images: {len(img_path_list)}') 65 | if not lmdb_path.endswith('.lmdb'): 66 | raise ValueError("lmdb_path must end with '.lmdb'.") 67 | if osp.exists(lmdb_path): 68 | print(f'Folder {lmdb_path} already exists. Exit.') 69 | sys.exit(1) 70 | 71 | if multiprocessing_read: 72 | # read all the images to memory (multiprocessing) 73 | dataset = {} # use dict to keep the order for multiprocessing 74 | shapes = {} 75 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 76 | pbar = tqdm(total=len(img_path_list), unit='image') 77 | 78 | def callback(arg): 79 | """get the image data and update pbar.""" 80 | key, dataset[key], shapes[key] = arg 81 | pbar.update(1) 82 | pbar.set_description(f'Read {key}') 83 | 84 | pool = Pool(n_thread) 85 | for path, key in zip(img_path_list, keys): 86 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 87 | pool.close() 88 | pool.join() 89 | pbar.close() 90 | print(f'Finish reading {len(img_path_list)} images.') 91 | 92 | # create lmdb environment 93 | if map_size is None: 94 | # obtain data size for one image 95 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 96 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 97 | data_size_per_img = img_byte.nbytes 98 | print('Data size per image is: ', data_size_per_img) 99 | data_size = data_size_per_img * len(img_path_list) 100 | map_size = data_size * 10 101 | 102 | env = lmdb.open(lmdb_path, map_size=map_size) 103 | 104 | # write data to lmdb 105 | pbar = tqdm(total=len(img_path_list), unit='chunk') 106 | txn = env.begin(write=True) 107 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 108 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 109 | pbar.update(1) 110 | pbar.set_description(f'Write {key}') 111 | key_byte = key.encode('ascii') 112 | if multiprocessing_read: 113 | img_byte = dataset[key] 114 | h, w, c = shapes[key] 115 | else: 116 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 117 | h, w, c = img_shape 118 | 119 | txn.put(key_byte, img_byte) 120 | # write meta information 121 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 122 | if idx % batch == 0: 123 | txn.commit() 124 | txn = env.begin(write=True) 125 | pbar.close() 126 | txn.commit() 127 | env.close() 128 | txt_file.close() 129 | print('\nFinish writing lmdb.') 130 | 131 | 132 | def read_img_worker(path, key, compress_level): 133 | """Read image worker. 134 | 135 | Args: 136 | path (str): Image path. 137 | key (str): Image key. 138 | compress_level (int): Compress level when encoding images. 139 | 140 | Returns: 141 | str: Image key. 142 | byte: Image byte. 143 | tuple[int]: Image shape. 144 | """ 145 | 146 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 147 | if img.ndim == 2: 148 | h, w = img.shape 149 | c = 1 150 | else: 151 | h, w, c = img.shape 152 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 153 | return (key, img_byte, (h, w, c)) 154 | 155 | 156 | class LmdbMaker(): 157 | """LMDB Maker. 158 | 159 | Args: 160 | lmdb_path (str): Lmdb save path. 161 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 
162 | batch (int): After processing batch images, lmdb commits. 163 | Default: 5000. 164 | compress_level (int): Compress level when encoding images. Default: 1. 165 | """ 166 | 167 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 168 | if not lmdb_path.endswith('.lmdb'): 169 | raise ValueError("lmdb_path must end with '.lmdb'.") 170 | if osp.exists(lmdb_path): 171 | print(f'Folder {lmdb_path} already exists. Exit.') 172 | sys.exit(1) 173 | 174 | self.lmdb_path = lmdb_path 175 | self.batch = batch 176 | self.compress_level = compress_level 177 | self.env = lmdb.open(lmdb_path, map_size=map_size) 178 | self.txn = self.env.begin(write=True) 179 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 180 | self.counter = 0 181 | 182 | def put(self, img_byte, key, img_shape): 183 | self.counter += 1 184 | key_byte = key.encode('ascii') 185 | self.txn.put(key_byte, img_byte) 186 | # write meta information 187 | h, w, c = img_shape 188 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 189 | if self.counter % self.batch == 0: 190 | self.txn.commit() 191 | self.txn = self.env.begin(write=True) 192 | 193 | def close(self): 194 | self.txn.commit() 195 | self.env.close() 196 | self.txt_file.close() 197 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | Args: 13 | opt (dict): Config. It contains the following keys: 14 | name (str): Exp name. 15 | logger (dict): Contains 'print_freq' (str) for logger interval. 16 | train (dict): Contains 'total_iter' (int) for total iters. 17 | use_tb_logger (bool): Use tensorboard logger. 18 | start_iter (int): Start iter. Default: 1. 19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 20 | """ 21 | 22 | def __init__(self, opt, start_iter=1, tb_logger=None): 23 | self.exp_name = opt['name'] 24 | self.interval = opt['logger']['print_freq'] 25 | self.start_iter = start_iter 26 | self.max_iters = opt['train']['total_iter'] 27 | self.use_tb_logger = opt['logger']['use_tb_logger'] 28 | self.tb_logger = tb_logger 29 | self.start_time = time.time() 30 | self.logger = get_root_logger() 31 | 32 | @master_only 33 | def __call__(self, log_vars): 34 | """Format logging message. 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | time (float): Iter time. 41 | data_time (float): Data time for each iter. 
42 | """ 43 | # epoch, iter, learning rates 44 | epoch = log_vars.pop('epoch') 45 | current_iter = log_vars.pop('iter') 46 | lrs = log_vars.pop('lrs') 47 | 48 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 49 | for v in lrs: 50 | message += f'{v:.3e},' 51 | message += ')] ' 52 | 53 | # time and estimated time 54 | if 'time' in log_vars.keys(): 55 | iter_time = log_vars.pop('time') 56 | data_time = log_vars.pop('data_time') 57 | 58 | total_time = time.time() - self.start_time 59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | message += f'[eta: {eta_str}, ' 63 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 64 | 65 | # other items, especially losses 66 | for k, v in log_vars.items(): 67 | message += f'{k}: {v:.4e} ' 68 | # tensorboard logger 69 | if self.use_tb_logger: 70 | if k.startswith('l_'): 71 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 72 | else: 73 | self.tb_logger.add_scalar(k, v, current_iter) 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | tb_logger = SummaryWriter(log_dir=log_dir) 81 | return tb_logger 82 | 83 | 84 | @master_only 85 | def init_wandb_logger(opt): 86 | """We now only use wandb to sync tensorboard log.""" 87 | import wandb 88 | logger = logging.getLogger('basicsr') 89 | 90 | project = opt['logger']['wandb']['project'] 91 | resume_id = opt['logger']['wandb'].get('resume_id') 92 | if resume_id: 93 | wandb_id = resume_id 94 | resume = 'allow' 95 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 96 | else: 97 | wandb_id = wandb.util.generate_id() 98 | resume = 'never' 99 | 100 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 101 | 102 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 103 | 104 | 105 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 106 | """Get the root logger. 107 | The logger will be initialized if it has not been initialized. By default a 108 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 109 | also be added. 110 | Args: 111 | logger_name (str): root logger name. Default: 'basicsr'. 112 | log_file (str | None): The log filename. If specified, a FileHandler 113 | will be added to the root logger. 114 | log_level (int): The root logger level. Note that only the process of 115 | rank 0 is affected, while other processes will set the level to 116 | "Error" and be silent most of the time. 117 | Returns: 118 | logging.Logger: The root logger. 
119 | """ 120 | logger = logging.getLogger(logger_name) 121 | # if the logger has been initialized, just return it 122 | if logger_name in initialized_logger: 123 | return logger 124 | 125 | format_str = '%(asctime)s %(levelname)s: %(message)s' 126 | stream_handler = logging.StreamHandler() 127 | stream_handler.setFormatter(logging.Formatter(format_str)) 128 | logger.addHandler(stream_handler) 129 | logger.propagate = False 130 | rank, _ = get_dist_info() 131 | if rank != 0: 132 | logger.setLevel('ERROR') 133 | elif log_file is not None: 134 | logger.setLevel(log_level) 135 | # add file handler 136 | # file_handler = logging.FileHandler(log_file, 'w') 137 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 138 | file_handler.setFormatter(logging.Formatter(format_str)) 139 | file_handler.setLevel(log_level) 140 | logger.addHandler(file_handler) 141 | initialized_logger[logger_name] = True 142 | return logger 143 | 144 | 145 | def get_env_info(): 146 | """Get environment information. 147 | Currently, only log the software version. 148 | """ 149 | import torch 150 | import torchvision 151 | 152 | from basicsr.version import __version__ 153 | msg = r""" 154 | ____ _ _____ ____ 155 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 156 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 157 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 158 | /_____/ \__,_//____//_/ \___//____//_/ |_| 159 | ______ __ __ __ __ 160 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 161 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 162 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 163 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 164 | """ 165 | msg += ('\nVersion Information: ' 166 | f'\n\tBasicSR: {__version__}' 167 | f'\n\tPyTorch: {torch.__version__}' 168 | f'\n\tTorchVision: {torchvision.__version__}') 169 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. 
Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): 48 | os.makedirs(path, exist_ok=True) 49 | 50 | 51 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 52 | """Scan a directory to find the files of interest. 53 | 54 | Args: 55 | dir_path (str): Path of the directory. 56 | suffix (str | tuple(str), optional): File suffix that we are 57 | interested in. Default: None. 58 | recursive (bool, optional): If set to True, recursively scan the 59 | directory. Default: False. 60 | full_path (bool, optional): If set to True, include the dir_path. 61 | Default: False. 62 | 63 | Returns: 64 | A generator for all the files of interest with relative paths. 65 | """ 66 | 67 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 68 | raise TypeError('"suffix" must be a string or tuple of strings') 69 | 70 | root = dir_path 71 | 72 | def _scandir(dir_path, suffix, recursive): 73 | for entry in os.scandir(dir_path): 74 | if not entry.name.startswith('.') and entry.is_file(): 75 | if full_path: 76 | return_path = entry.path 77 | else: 78 | return_path = osp.relpath(entry.path, root) 79 | 80 | if suffix is None: 81 | yield return_path 82 | elif return_path.endswith(suffix): 83 | yield return_path 84 | else: 85 | if recursive: 86 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 87 | else: 88 | continue 89 | 90 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 91 | 92 | 93 | def check_resume(opt, resume_iter): 94 | """Check resume states and pretrain_network paths. 95 | 96 | Args: 97 | opt (dict): Options. 98 | resume_iter (int): Resume iteration. 99 | """ 100 | logger = get_root_logger() 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | logger.warning('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (basename 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | logger.info(f"Set {name} to {opt['path'][name]}") 118 | 119 | 120 | def sizeof_fmt(size, suffix='B'): 121 | """Get human readable file size. 122 | 123 | Args: 124 | size (int): File size. 125 | suffix (str): Suffix. Default: 'B'. 126 | 127 | Return: 128 | str: Formatted file size.
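    Example (illustrative):
        >>> sizeof_fmt(1024)
        '1.0 KB'
        >>> sizeof_fmt(3_500_000_000)
        '3.3 GB'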
129 | """ 130 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 131 | if abs(size) < 1024.0: 132 | return f'{size:3.1f} {unit}{suffix}' 133 | size /= 1024.0 134 | return f'{size:3.1f} Y{suffix}' 135 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import time 3 | from collections import OrderedDict 4 | from os import path as osp 5 | from basicsr.utils.misc import get_time_str 6 | 7 | def ordered_yaml(): 8 | """Support OrderedDict for yaml. 9 | 10 | Returns: 11 | yaml Loader and Dumper. 12 | """ 13 | try: 14 | from yaml import CDumper as Dumper 15 | from yaml import CLoader as Loader 16 | except ImportError: 17 | from yaml import Dumper, Loader 18 | 19 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 20 | 21 | def dict_representer(dumper, data): 22 | return dumper.represent_dict(data.items()) 23 | 24 | def dict_constructor(loader, node): 25 | return OrderedDict(loader.construct_pairs(node)) 26 | 27 | Dumper.add_representer(OrderedDict, dict_representer) 28 | Loader.add_constructor(_mapping_tag, dict_constructor) 29 | return Loader, Dumper 30 | 31 | 32 | def parse(opt_path, root_path, is_train=True): 33 | """Parse option file. 34 | 35 | Args: 36 | opt_path (str): Option file path. 37 | is_train (str): Indicate whether in training or not. Default: True. 38 | 39 | Returns: 40 | (dict): Options. 41 | """ 42 | with open(opt_path, mode='r') as f: 43 | Loader, _ = ordered_yaml() 44 | opt = yaml.load(f, Loader=Loader) 45 | 46 | opt['is_train'] = is_train 47 | 48 | # opt['name'] = f"{get_time_str()}_{opt['name']}" 49 | if opt['path'].get('resume_state', None): # Shangchen added 50 | resume_state_path = opt['path'].get('resume_state') 51 | opt['name'] = resume_state_path.split("/")[-3] 52 | else: 53 | opt['name'] = f"{get_time_str()}_{opt['name']}" 54 | 55 | 56 | # datasets 57 | for phase, dataset in opt['datasets'].items(): 58 | # for several datasets, e.g., test_1, test_2 59 | phase = phase.split('_')[0] 60 | dataset['phase'] = phase 61 | if 'scale' in opt: 62 | dataset['scale'] = opt['scale'] 63 | if dataset.get('dataroot_gt') is not None: 64 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 65 | if dataset.get('dataroot_lq') is not None: 66 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 67 | 68 | # paths 69 | for key, val in opt['path'].items(): 70 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 71 | opt['path'][key] = osp.expanduser(val) 72 | 73 | if is_train: 74 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 75 | opt['path']['experiments_root'] = experiments_root 76 | opt['path']['models'] = osp.join(experiments_root, 'models') 77 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 78 | opt['path']['log'] = experiments_root 79 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 80 | 81 | else: # test 82 | results_root = osp.join(root_path, 'results', opt['name']) 83 | opt['path']['results_root'] = results_root 84 | opt['path']['log'] = results_root 85 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 86 | 87 | return opt 88 | 89 | 90 | def dict2str(opt, indent_level=1): 91 | """dict to string for printing options. 92 | 93 | Args: 94 | opt (dict): Option dict. 95 | indent_level (int): Indent level. Default: 1. 
96 | 97 | Return: 98 | (str): Option string for printing. 99 | """ 100 | msg = '\n' 101 | for k, v in opt.items(): 102 | if isinstance(v, dict): 103 | msg += ' ' * (indent_level * 2) + k + ':[' 104 | msg += dict2str(v, indent_level + 1) 105 | msg += ' ' * (indent_level * 2) + ']\n' 106 | else: 107 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 108 | return msg 109 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the name `obj.__name__`. 46 | Can be used either as a decorator or as a function call. 47 | See docstring of this class for usage.
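        A minimal round trip, assuming a hypothetical class name:

        .. code-block:: python

            @ARCH_REGISTRY.register()
            class MyArch():
                ...

            arch_cls = ARCH_REGISTRY.get('MyArch')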
48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.2' 2 | __gitsha__ = '' 3 | version_info = (1, 3, 2) 4 | -------------------------------------------------------------------------------- /codeformer_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from pathlib import Path 6 | from torchvision.transforms.functional import normalize 7 | from basicsr.utils import img2tensor, tensor2img 8 | from basicsr.utils.download_util import load_file_from_url 9 | from facelib.utils.face_restoration_helper import FaceRestoreHelper 10 | from basicsr.utils.registry import ARCH_REGISTRY 11 | 12 | # Cross-platform device selection: CUDA > MPS > CPU 13 | if torch.cuda.is_available(): 14 | device = torch.device("cuda") 15 | elif torch.backends.mps.is_available(): 16 | device = torch.device("mps") 17 | else: 18 | device = torch.device("cpu") 19 | 20 | # Download and load model 21 | pretrain_model_url = { 22 | 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', 23 | } 24 | 25 | net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 26 | connect_list=['32', '64', '128', '256']).to(device) 27 | 28 | ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], 29 | model_dir='weights/CodeFormer', progress=True, file_name=None) 30 | checkpoint = torch.load(ckpt_path, map_location=device)['params_ema'] 31 | net.load_state_dict(checkpoint) 32 | net.eval() 33 | 34 | face_helper = FaceRestoreHelper( 35 | upscale_factor=1, 36 | face_size=512, 37 | crop_ratio=(1, 1), 38 | det_model='retinaface_resnet50', 39 | save_ext='jpg', 40 | use_parse=True, 41 | device=device 42 | ) 43 | 44 | def _enhance_img(img: np.ndarray, w: float = 0.5) -> np.ndarray: 45 | """ 46 | Internal helper to enhance a numpy image with CodeFormer. 
47 | """ 48 | face_helper.clean_all() 49 | face_helper.read_image(img) 50 | num_faces = face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) 51 | if num_faces == 0: 52 | return img # Return original if no faces detected 53 | 54 | face_helper.align_warp_face() 55 | 56 | for cropped_face in face_helper.cropped_faces: 57 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True).to(device) 58 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 59 | cropped_face_t = cropped_face_t.unsqueeze(0) # (1, 3, H, W), already on correct device 60 | 61 | with torch.no_grad(): 62 | output = net(cropped_face_t, w=w, adain=True)[0] 63 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) 64 | 65 | restored_face = restored_face.astype('uint8') 66 | face_helper.add_restored_face(restored_face) 67 | 68 | face_helper.get_inverse_affine(None) 69 | restored_img = face_helper.paste_faces_to_input_image() 70 | return restored_img 71 | 72 | def enhance_image(input_image_path: str, w: float = 0.5) -> str: 73 | """ 74 | Enhances an input image using CodeFormer and saves it with a '.enhanced.jpg' suffix. 75 | """ 76 | input_path = Path(input_image_path) 77 | output_path = input_path.with_name(f"{input_path.stem}.enhanced.jpg") 78 | 79 | img = cv2.imread(str(input_path), cv2.IMREAD_COLOR) 80 | if img is None: 81 | raise ValueError(f"Cannot read image: {input_image_path}") 82 | 83 | restored_img = _enhance_img(img, w=w) 84 | 85 | os.makedirs(output_path.parent, exist_ok=True) 86 | cv2.imwrite(str(output_path), restored_img) 87 | print(f"Enhanced image saved to: {output_path}") 88 | return str(output_path) 89 | 90 | def enhance_image_memory(img: np.ndarray, w: float = 0.5) -> np.ndarray: 91 | """ 92 | Enhances an input image entirely in memory and returns the enhanced image. 
93 | """ 94 | return _enhance_img(img, w=w) -------------------------------------------------------------------------------- /demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/demo.jpg -------------------------------------------------------------------------------- /facelib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from copy import deepcopy 5 | 6 | from facelib.utils import load_file_from_url 7 | from facelib.utils import download_pretrained_models 8 | from facelib.detection.yolov5face.models.common import Conv 9 | 10 | from .retinaface.retinaface import RetinaFace 11 | from .yolov5face.face_detector import YoloDetector 12 | 13 | 14 | def init_detection_model(model_name, half=False, device='cuda'): 15 | if 'retinaface' in model_name: 16 | model = init_retinaface_model(model_name, half, device) 17 | elif 'YOLOv5' in model_name: 18 | model = init_yolov5face_model(model_name, device) 19 | else: 20 | raise NotImplementedError(f'{model_name} is not implemented.') 21 | 22 | return model 23 | 24 | 25 | def init_retinaface_model(model_name, half=False, device='cuda'): 26 | if model_name == 'retinaface_resnet50': 27 | model = RetinaFace(network_name='resnet50', half=half) 28 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' 29 | elif model_name == 'retinaface_mobile0.25': 30 | model = RetinaFace(network_name='mobile0.25', half=half) 31 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' 32 | else: 33 | raise NotImplementedError(f'{model_name} is not implemented.') 34 | 35 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 36 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 37 | # remove unnecessary 'module.' 
38 | for k, v in deepcopy(load_net).items(): 39 | if k.startswith('module.'): 40 | load_net[k[7:]] = v 41 | load_net.pop(k) 42 | model.load_state_dict(load_net, strict=True) 43 | model.eval() 44 | model = model.to(device) 45 | 46 | return model 47 | 48 | 49 | def init_yolov5face_model(model_name, device='cuda'): 50 | if model_name == 'YOLOv5l': 51 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 52 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' 53 | elif model_name == 'YOLOv5n': 54 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 55 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' 56 | else: 57 | raise NotImplementedError(f'{model_name} is not implemented.') 58 | 59 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 60 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 61 | model.detector.load_state_dict(load_net, strict=True) 62 | model.detector.eval() 63 | model.detector = model.detector.to(device).float() 64 | 65 | for m in model.detector.modules(): 66 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 67 | m.inplace = True # pytorch 1.7.0 compatibility 68 | elif isinstance(m, Conv): 69 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 70 | 71 | return model 72 | 73 | 74 | # Download from Google Drive 75 | # def init_yolov5face_model(model_name, device='cuda'): 76 | # if model_name == 'YOLOv5l': 77 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 78 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} 79 | # elif model_name == 'YOLOv5n': 80 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 81 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} 82 | # else: 83 | # raise NotImplementedError(f'{model_name} is not implemented.') 84 | 85 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) 86 | # if not os.path.exists(model_path): 87 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') 88 | 89 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 90 | # model.detector.load_state_dict(load_net, strict=True) 91 | # model.detector.eval() 92 | # model.detector = model.detector.to(device).float() 93 | 94 | # for m in model.detector.modules(): 95 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 96 | # m.inplace = True # pytorch 1.7.0 compatibility 97 | # elif isinstance(m, Conv): 98 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 99 | 100 | # return model -------------------------------------------------------------------------------- /facelib/detection/align_trans.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from .matlab_cp2tform import get_similarity_transform_for_cv2 5 | 6 | # reference facial points, a list of coordinates (x,y) 7 | REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], 8 | [33.54930115, 92.3655014], [62.72990036, 92.20410156]] 9 | 10 | DEFAULT_CROP_SIZE = (96, 112) 11 | 12 | 13 | class FaceWarpException(Exception): 14 | 15 | def __str__(self): 16 | return 
'In File {}:{}'.format(__file__, super.__str__(self)) 17 | 18 | 19 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): 20 | """ 21 | Function: 22 | ---------- 23 | get reference 5 key points according to crop settings: 24 | 0. Set default crop_size: 25 | if default_square: 26 | crop_size = (112, 112) 27 | else: 28 | crop_size = (96, 112) 29 | 1. Pad the crop_size by inner_padding_factor in each side; 30 | 2. Resize crop_size into (output_size - outer_padding*2), 31 | pad into output_size with outer_padding; 32 | 3. Output reference_5point; 33 | Parameters: 34 | ---------- 35 | @output_size: (w, h) or None 36 | size of aligned face image 37 | @inner_padding_factor: (w_factor, h_factor) 38 | padding factor for inner (w, h) 39 | @outer_padding: (w_pad, h_pad) 40 | each row is a pair of coordinates (x, y) 41 | @default_square: True or False 42 | if True: 43 | default crop_size = (112, 112) 44 | else: 45 | default crop_size = (96, 112); 46 | !!! make sure, if output_size is not None: 47 | (output_size - outer_padding) 48 | = some_scale * (default crop_size * (1.0 + 49 | inner_padding_factor)) 50 | Returns: 51 | ---------- 52 | @reference_5point: 5x2 np.array 53 | each row is a pair of transformed coordinates (x, y) 54 | """ 55 | 56 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 57 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 58 | 59 | # 0) make the inner region a square 60 | if default_square: 61 | size_diff = max(tmp_crop_size) - tmp_crop_size 62 | tmp_5pts += size_diff / 2 63 | tmp_crop_size += size_diff 64 | 65 | if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): 66 | 67 | return tmp_5pts 68 | 69 | if (inner_padding_factor == 0 and outer_padding == (0, 0)): 70 | if output_size is None: 71 | return tmp_5pts 72 | else: 73 | raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 74 | 75 | # check output size 76 | if not (0 <= inner_padding_factor <= 1.0): 77 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 78 | 79 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): 80 | output_size = tmp_crop_size * \ 81 | (1 + inner_padding_factor * 2).astype(np.int32) 82 | output_size += np.array(outer_padding) 83 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): 84 | raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') 85 | 86 | # 1) pad the inner region according inner_padding_factor 87 | if inner_padding_factor > 0: 88 | size_diff = tmp_crop_size * inner_padding_factor * 2 89 | tmp_5pts += size_diff / 2 90 | tmp_crop_size += np.round(size_diff).astype(np.int32) 91 | 92 | # 2) resize the padded inner region 93 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 94 | 95 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 96 | raise FaceWarpException('Must have (output_size - outer_padding)' 97 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 98 | 99 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 100 | tmp_5pts = tmp_5pts * scale_factor 101 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 102 | # tmp_5pts = tmp_5pts + size_diff / 2 103 | tmp_crop_size = size_bf_outer_pad 104 | 105 | # 3) add outer_padding to make output_size 106 | reference_5point = tmp_5pts + 
np.array(outer_padding) 107 | tmp_crop_size = output_size 108 | 109 | return reference_5point 110 | 111 | 112 | def get_affine_transform_matrix(src_pts, dst_pts): 113 | """ 114 | Function: 115 | ---------- 116 | get affine transform matrix 'tfm' from src_pts to dst_pts 117 | Parameters: 118 | ---------- 119 | @src_pts: Kx2 np.array 120 | source points matrix, each row is a pair of coordinates (x, y) 121 | @dst_pts: Kx2 np.array 122 | destination points matrix, each row is a pair of coordinates (x, y) 123 | Returns: 124 | ---------- 125 | @tfm: 2x3 np.array 126 | transform matrix from src_pts to dst_pts 127 | """ 128 | 129 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 130 | n_pts = src_pts.shape[0] 131 | ones = np.ones((n_pts, 1), src_pts.dtype) 132 | src_pts_ = np.hstack([src_pts, ones]) 133 | dst_pts_ = np.hstack([dst_pts, ones]) 134 | 135 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 136 | 137 | if rank == 3: 138 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) 139 | elif rank == 2: 140 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 141 | 142 | return tfm 143 | 144 | 145 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): 146 | """ 147 | Function: 148 | ---------- 149 | apply affine transform 'trans' to uv 150 | Parameters: 151 | ---------- 152 | @src_img: 3x3 np.array 153 | input image 154 | @facial_pts: could be 155 | 1)a list of K coordinates (x,y) 156 | or 157 | 2) Kx2 or 2xK np.array 158 | each row or col is a pair of coordinates (x, y) 159 | @reference_pts: could be 160 | 1) a list of K coordinates (x,y) 161 | or 162 | 2) Kx2 or 2xK np.array 163 | each row or col is a pair of coordinates (x, y) 164 | or 165 | 3) None 166 | if None, use default reference facial points 167 | @crop_size: (w, h) 168 | output face image size 169 | @align_type: transform type, could be one of 170 | 1) 'similarity': use similarity transform 171 | 2) 'cv2_affine': use the first 3 points to do affine transform, 172 | by calling cv2.getAffineTransform() 173 | 3) 'affine': use all points to do affine transform 174 | Returns: 175 | ---------- 176 | @face_img: output face image with size (w, h) = @crop_size 177 | """ 178 | 179 | if reference_pts is None: 180 | if crop_size[0] == 96 and crop_size[1] == 112: 181 | reference_pts = REFERENCE_FACIAL_POINTS 182 | else: 183 | default_square = False 184 | inner_padding_factor = 0 185 | outer_padding = (0, 0) 186 | output_size = crop_size 187 | 188 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, 189 | default_square) 190 | 191 | ref_pts = np.float32(reference_pts) 192 | ref_pts_shp = ref_pts.shape 193 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 194 | raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') 195 | 196 | if ref_pts_shp[0] == 2: 197 | ref_pts = ref_pts.T 198 | 199 | src_pts = np.float32(facial_pts) 200 | src_pts_shp = src_pts.shape 201 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 202 | raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') 203 | 204 | if src_pts_shp[0] == 2: 205 | src_pts = src_pts.T 206 | 207 | if src_pts.shape != ref_pts.shape: 208 | raise FaceWarpException('facial_pts and reference_pts must have the same shape') 209 | 210 | if align_type == 'cv2_affine': 211 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 212 | elif align_type == 'affine': 213 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 214 | else: 
215 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) 216 | 217 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) 218 | 219 | return face_img 220 | -------------------------------------------------------------------------------- /facelib/detection/retinaface/retinaface_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), 9 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 10 | 11 | 12 | def conv_bn_no_relu(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | ) 17 | 18 | 19 | def conv_bn1X1(inp, oup, stride, leaky=0): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), 22 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 23 | 24 | 25 | def conv_dw(inp, oup, stride, leaky=0.1): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 33 | ) 34 | 35 | 36 | class SSH(nn.Module): 37 | 38 | def __init__(self, in_channel, out_channel): 39 | super(SSH, self).__init__() 40 | assert out_channel % 4 == 0 41 | leaky = 0 42 | if (out_channel <= 64): 43 | leaky = 0.1 44 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 45 | 46 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 47 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 48 | 49 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 50 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 51 | 52 | def forward(self, input): 53 | conv3X3 = self.conv3X3(input) 54 | 55 | conv5X5_1 = self.conv5X5_1(input) 56 | conv5X5 = self.conv5X5_2(conv5X5_1) 57 | 58 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 59 | conv7X7 = self.conv7x7_3(conv7X7_2) 60 | 61 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class FPN(nn.Module): 67 | 68 | def __init__(self, in_channels_list, out_channels): 69 | super(FPN, self).__init__() 70 | leaky = 0 71 | if (out_channels <= 64): 72 | leaky = 0.1 73 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 74 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 75 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 76 | 77 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) 78 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 79 | 80 | def forward(self, input): 81 | # names = list(input.keys()) 82 | # input = list(input.values()) 83 | 84 | output1 = self.output1(input[0]) 85 | output2 = self.output2(input[1]) 86 | output3 = self.output3(input[2]) 87 | 88 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') 89 | output2 = output2 + up3 90 | output2 = self.merge2(output2) 91 | 92 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') 93 | output1 = output1 + up2 94 | output1 = 
self.merge1(output1) 95 | 96 | out = [output1, output2, output3] 97 | return out 98 | 99 | 100 | class MobileNetV1(nn.Module): 101 | 102 | def __init__(self): 103 | super(MobileNetV1, self).__init__() 104 | self.stage1 = nn.Sequential( 105 | conv_bn(3, 8, 2, leaky=0.1), # 3 106 | conv_dw(8, 16, 1), # 7 107 | conv_dw(16, 32, 2), # 11 108 | conv_dw(32, 32, 1), # 19 109 | conv_dw(32, 64, 2), # 27 110 | conv_dw(64, 64, 1), # 43 111 | ) 112 | self.stage2 = nn.Sequential( 113 | conv_dw(64, 128, 2), # 43 + 16 = 59 114 | conv_dw(128, 128, 1), # 59 + 32 = 91 115 | conv_dw(128, 128, 1), # 91 + 32 = 123 116 | conv_dw(128, 128, 1), # 123 + 32 = 155 117 | conv_dw(128, 128, 1), # 155 + 32 = 187 118 | conv_dw(128, 128, 1), # 187 + 32 = 219 119 | ) 120 | self.stage3 = nn.Sequential( 121 | conv_dw(128, 256, 2), # 219 +3 2 = 241 122 | conv_dw(256, 256, 1), # 241 + 64 = 301 123 | ) 124 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.fc = nn.Linear(256, 1000) 126 | 127 | def forward(self, x): 128 | x = self.stage1(x) 129 | x = self.stage2(x) 130 | x = self.stage3(x) 131 | x = self.avg(x) 132 | # x = self.model(x) 133 | x = x.view(-1, 256) 134 | x = self.fc(x) 135 | return x 136 | 137 | 138 | class ClassHead(nn.Module): 139 | 140 | def __init__(self, inchannels=512, num_anchors=3): 141 | super(ClassHead, self).__init__() 142 | self.num_anchors = num_anchors 143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 144 | 145 | def forward(self, x): 146 | out = self.conv1x1(x) 147 | out = out.permute(0, 2, 3, 1).contiguous() 148 | 149 | return out.view(out.shape[0], -1, 2) 150 | 151 | 152 | class BboxHead(nn.Module): 153 | 154 | def __init__(self, inchannels=512, num_anchors=3): 155 | super(BboxHead, self).__init__() 156 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 157 | 158 | def forward(self, x): 159 | out = self.conv1x1(x) 160 | out = out.permute(0, 2, 3, 1).contiguous() 161 | 162 | return out.view(out.shape[0], -1, 4) 163 | 164 | 165 | class LandmarkHead(nn.Module): 166 | 167 | def __init__(self, inchannels=512, num_anchors=3): 168 | super(LandmarkHead, self).__init__() 169 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 170 | 171 | def forward(self, x): 172 | out = self.conv1x1(x) 173 | out = out.permute(0, 2, 3, 1).contiguous() 174 | 175 | return out.view(out.shape[0], -1, 10) 176 | 177 | 178 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): 179 | classhead = nn.ModuleList() 180 | for i in range(fpn_num): 181 | classhead.append(ClassHead(inchannels, anchor_num)) 182 | return classhead 183 | 184 | 185 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): 186 | bboxhead = nn.ModuleList() 187 | for i in range(fpn_num): 188 | bboxhead.append(BboxHead(inchannels, anchor_num)) 189 | return bboxhead 190 | 191 | 192 | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): 193 | landmarkhead = nn.ModuleList() 194 | for i in range(fpn_num): 195 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 196 | return landmarkhead 197 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/facelib/detection/yolov5face/__init__.py 
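The RetinaFace head builders defined in facelib/detection/retinaface/retinaface_net.py above each return one small module per FPN level. A minimal sketch of how they turn feature maps into per-anchor predictions (batch size, channel count and feature-map sizes are illustrative, not taken from the repo's configs):

import torch
from facelib.detection.retinaface.retinaface_net import (
    make_class_head, make_bbox_head, make_landmark_head)

class_heads = make_class_head(fpn_num=3, inchannels=64, anchor_num=2)
bbox_heads = make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2)
ldm_heads = make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2)

# stand-ins for the three SSH/FPN feature maps (strides 8/16/32 of a 640px input)
feats = [torch.randn(1, 64, s, s) for s in (80, 40, 20)]

cls = torch.cat([class_heads[i](f) for i, f in enumerate(feats)], dim=1)  # (1, N, 2) face scores
box = torch.cat([bbox_heads[i](f) for i, f in enumerate(feats)], dim=1)   # (1, N, 4) box regressions
ldm = torch.cat([ldm_heads[i](f) for i, f in enumerate(feats)], dim=1)    # (1, N, 10) five-point landmarks

Here N is the total anchor count over all three levels; the actual model concatenates the per-level head outputs in exactly this way before decoding them against its anchor priors.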
-------------------------------------------------------------------------------- /facelib/detection/yolov5face/face_detector.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from facelib.detection.yolov5face.models.common import Conv 11 | from facelib.detection.yolov5face.models.yolo import Model 12 | from facelib.detection.yolov5face.utils.datasets import letterbox 13 | from facelib.detection.yolov5face.utils.general import ( 14 | check_img_size, 15 | non_max_suppression_face, 16 | scale_coords, 17 | scale_coords_landmarks, 18 | ) 19 | 20 | IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) >= (1, 9, 0) 21 | 22 | 23 | def isListempty(inList): 24 | if isinstance(inList, list): # Is a list 25 | return all(map(isListempty, inList)) 26 | return False # Not a list 27 | 28 | class YoloDetector: 29 | def __init__( 30 | self, 31 | config_name, 32 | min_face=10, 33 | target_size=None, 34 | device='cuda', 35 | ): 36 | """ 37 | config_name: name of .yaml config with network configuration from models/ folder. 38 | min_face : minimal face size in pixels. 39 | target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. 40 | None for original resolution. 41 | """ 42 | self._class_path = Path(__file__).parent.absolute() 43 | self.target_size = target_size 44 | self.min_face = min_face 45 | self.detector = Model(cfg=config_name) 46 | self.device = device 47 | 48 | 49 | def _preprocess(self, imgs): 50 | """ 51 | Preprocessing image before passing through the network. Resize and conversion to torch tensor. 52 | """ 53 | pp_imgs = [] 54 | for img in imgs: 55 | h0, w0 = img.shape[:2] # orig hw 56 | if self.target_size: 57 | r = self.target_size / min(h0, w0) # resize image to img_size 58 | if r < 1: 59 | img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) 60 | 61 | imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size 62 | img = letterbox(img, new_shape=imgsz)[0] 63 | pp_imgs.append(img) 64 | pp_imgs = np.array(pp_imgs) 65 | pp_imgs = pp_imgs.transpose(0, 3, 1, 2) 66 | pp_imgs = torch.from_numpy(pp_imgs).to(self.device) 67 | pp_imgs = pp_imgs.float() # uint8 to fp16/32 68 | return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 69 | 70 | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): 71 | """ 72 | Postprocessing of raw pytorch model output. 73 | Returns: 74 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 75 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
76 | """ 77 | bboxes = [[] for _ in range(len(origimgs))] 78 | landmarks = [[] for _ in range(len(origimgs))] 79 | 80 | pred = non_max_suppression_face(pred, conf_thres, iou_thres) 81 | 82 | for image_id, origimg in enumerate(origimgs): 83 | img_shape = origimg.shape 84 | image_height, image_width = img_shape[:2] 85 | gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh 86 | gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks 87 | det = pred[image_id].cpu() 88 | scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() 89 | scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() 90 | 91 | for j in range(det.size()[0]): 92 | box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() 93 | box = list( 94 | map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) 95 | ) 96 | if box[3] - box[1] < self.min_face: 97 | continue 98 | lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() 99 | lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) 100 | lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] 101 | bboxes[image_id].append(box) 102 | landmarks[image_id].append(lm) 103 | return bboxes, landmarks 104 | 105 | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): 106 | """ 107 | Get bbox coordinates and keypoints of faces on original image. 108 | Params: 109 | imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) 110 | conf_thres: confidence threshold for each prediction 111 | iou_thres: threshold for NMS (filter of intersecting bboxes) 112 | Returns: 113 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 114 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
115 | """ 116 | # Pass input images through face detector 117 | images = imgs if isinstance(imgs, list) else [imgs] 118 | images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] 119 | origimgs = copy.deepcopy(images) 120 | 121 | images = self._preprocess(images) 122 | 123 | if IS_HIGH_VERSION: 124 | with torch.inference_mode(): # for pytorch>=1.9 125 | pred = self.detector(images)[0] 126 | else: 127 | with torch.no_grad(): # for pytorch<1.9 128 | pred = self.detector(images)[0] 129 | 130 | bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) 131 | 132 | # return bboxes, points 133 | if not isListempty(points): 134 | bboxes = np.array(bboxes).reshape(-1,4) 135 | points = np.array(points).reshape(-1,10) 136 | padding = bboxes[:,0].reshape(-1,1) 137 | return np.concatenate((bboxes, padding, points), axis=1) 138 | else: 139 | return None 140 | 141 | def __call__(self, *args): 142 | return self.predict(*args) -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/facelib/detection/yolov5face/models/__init__.py -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/experimental.py: -------------------------------------------------------------------------------- 1 | # # This file contains experimental modules 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from facelib.detection.yolov5face.models.common import Conv 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super().__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class MixConv2d(nn.Module): 25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 27 | super().__init__() 28 | groups = len(k) 29 | if equal_ch: # equal c_ per group 30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices 31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 32 | else: # equal weight.numel() per group 33 | b = [c2] + [0] * groups 34 | a = np.eye(groups + 1, groups, k=-1) 35 | a -= np.roll(a, 1, axis=1) 36 | a *= np.array(k) ** 2 37 | a[0] = 1 38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 39 | 40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 41 | self.bn = nn.BatchNorm2d(c2) 42 | self.act = nn.LeakyReLU(0.1, inplace=True) 43 | 44 | def forward(self, x): 45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5l.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # 
layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 16 | [-1, 3, C3, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 18 | [-1, 9, C3, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 20 | [-1, 9, C3, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 22 | [-1, 1, SPP, [1024, [3,5,7]]], 23 | [-1, 3, C3, [1024, False]], # 8 24 | ] 25 | 26 | # YOLOv5 head 27 | head: 28 | [[-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4 31 | [-1, 3, C3, [512, False]], # 12 32 | 33 | [-1, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3 36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small) 37 | 38 | [-1, 1, Conv, [256, 3, 2]], 39 | [[-1, 13], 1, Concat, [1]], # cat head P4 40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium) 41 | 42 | [-1, 1, Conv, [512, 3, 2]], 43 | [[-1, 9], 1, Concat, [1]], # cat head P5 44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large) 45 | 46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 47 | ] -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5n.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2 18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4 20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6 22 | ] 23 | 24 | # YOLOv5 head 25 | head: 26 | [[-1, 1, Conv, [128, 1, 1]], 27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4 29 | [-1, 1, C3, [128, False]], # 10 30 | 31 | [-1, 1, Conv, [128, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3 34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small) 35 | 36 | [-1, 1, Conv, [128, 3, 2]], 37 | [[-1, 11], 1, Concat, [1]], # cat head P4 38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium) 39 | 40 | [-1, 1, Conv, [128, 3, 2]], 41 | [[-1, 7], 1, Concat, [1]], # cat head P5 42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large) 43 | 44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 45 | ] 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/facelib/detection/yolov5face/utils/__init__.py -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # 
Auto-anchor utils 2 | 3 | 4 | def check_anchor_order(m): 5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 7 | da = a[-1] - a[0] # delta a 8 | ds = m.stride[-1] - m.stride[0] # delta s 9 | if da.sign() != ds.sign(): # same order 10 | print("Reversing anchor order") 11 | m.anchors[:] = m.anchors.flip(0) 12 | m.anchor_grid[:] = m.anchor_grid.flip(0) 13 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): 6 | # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 7 | shape = img.shape[:2] # current shape [height, width] 8 | if isinstance(new_shape, int): 9 | new_shape = (new_shape, new_shape) 10 | 11 | # Scale ratio (new / old) 12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 13 | if not scaleup: # only scale down, do not scale up (for better test mAP) 14 | r = min(r, 1.0) 15 | 16 | # Compute padding 17 | ratio = r, r # width, height ratios 18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 20 | if auto: # minimum rectangle 21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding 22 | elif scale_fill: # stretch 23 | dw, dh = 0.0, 0.0 24 | new_unpad = (new_shape[1], new_shape[0]) 25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 26 | 27 | dw /= 2 # divide padding into 2 sides 28 | dh /= 2 29 | 30 | if shape[::-1] != new_unpad: # resize 31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 35 | return img, ratio, (dw, dh) 36 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/extract_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | 5 | # Setup dynamic device selection 6 | if torch.cuda.is_available(): 7 | device = torch.device("cuda") 8 | elif torch.backends.mps.is_available(): 9 | device = torch.device("mps") 10 | else: 11 | device = torch.device("cpu") 12 | 13 | sys.path.insert(0, './facelib/detection/yolov5face') 14 | 15 | # Load the model to the selected device 16 | ckpt = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location=device) 17 | model = ckpt['model'].to(device) 18 | 19 | # Save only the weights 20 | os.makedirs('weights/facelib', exist_ok=True) 21 | torch.save(model.state_dict(), 'weights/facelib/yolov5n-face.pth') -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fuse_conv_and_bn(conv, bn): 6 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 7 | fusedconv = ( 8 | nn.Conv2d( 9 | conv.in_channels, 10 
| conv.out_channels, 11 | kernel_size=conv.kernel_size, 12 | stride=conv.stride, 13 | padding=conv.padding, 14 | groups=conv.groups, 15 | bias=True, 16 | ) 17 | .requires_grad_(False) 18 | .to(conv.weight.device) 19 | ) 20 | 21 | # prepare filters 22 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 25 | 26 | # prepare spatial bias 27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 30 | 31 | return fusedconv 32 | 33 | 34 | def copy_attr(a, b, include=(), exclude=()): 35 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 36 | for k, v in b.__dict__.items(): 37 | if (include and k not in include) or k.startswith("_") or k in exclude: 38 | continue 39 | 40 | setattr(a, k, v) 41 | -------------------------------------------------------------------------------- /facelib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facelib.utils import load_file_from_url 4 | from .bisenet import BiSeNet 5 | from .parsenet import ParseNet 6 | 7 | 8 | def init_parsing_model(model_name='bisenet', half=False, device='cuda'): 9 | if model_name == 'bisenet': 10 | model = BiSeNet(num_class=19) 11 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' 12 | elif model_name == 'parsenet': 13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 14 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' 15 | else: 16 | raise NotImplementedError(f'{model_name} is not implemented.') 17 | 18 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 19 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 20 | model.load_state_dict(load_net, strict=True) 21 | model.eval() 22 | model = model.to(device) 23 | return model 24 | -------------------------------------------------------------------------------- /facelib/parsing/bisenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnet import ResNet18 6 | 7 | 8 | class ConvBNReLU(nn.Module): 9 | 10 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): 11 | super(ConvBNReLU, self).__init__() 12 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 13 | self.bn = nn.BatchNorm2d(out_chan) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = F.relu(self.bn(x)) 18 | return x 19 | 20 | 21 | class BiSeNetOutput(nn.Module): 22 | 23 | def __init__(self, in_chan, mid_chan, num_class): 24 | super(BiSeNetOutput, self).__init__() 25 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 26 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) 27 | 28 | def forward(self, x): 29 | feat = self.conv(x) 30 | out = self.conv_out(feat) 31 | return out, feat 32 | 33 | 34 | class AttentionRefinementModule(nn.Module): 35 | 36 | def __init__(self, in_chan, out_chan): 37 | 
super(AttentionRefinementModule, self).__init__() 38 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 39 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 40 | self.bn_atten = nn.BatchNorm2d(out_chan) 41 | self.sigmoid_atten = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | feat = self.conv(x) 45 | atten = F.avg_pool2d(feat, feat.size()[2:]) 46 | atten = self.conv_atten(atten) 47 | atten = self.bn_atten(atten) 48 | atten = self.sigmoid_atten(atten) 49 | out = torch.mul(feat, atten) 50 | return out 51 | 52 | 53 | class ContextPath(nn.Module): 54 | 55 | def __init__(self): 56 | super(ContextPath, self).__init__() 57 | self.resnet = ResNet18() 58 | self.arm16 = AttentionRefinementModule(256, 128) 59 | self.arm32 = AttentionRefinementModule(512, 128) 60 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 61 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 62 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 63 | 64 | def forward(self, x): 65 | feat8, feat16, feat32 = self.resnet(x) 66 | h8, w8 = feat8.size()[2:] 67 | h16, w16 = feat16.size()[2:] 68 | h32, w32 = feat32.size()[2:] 69 | 70 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 71 | avg = self.conv_avg(avg) 72 | avg_up = F.interpolate(avg, (h32, w32), mode='nearest') 73 | 74 | feat32_arm = self.arm32(feat32) 75 | feat32_sum = feat32_arm + avg_up 76 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') 77 | feat32_up = self.conv_head32(feat32_up) 78 | 79 | feat16_arm = self.arm16(feat16) 80 | feat16_sum = feat16_arm + feat32_up 81 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') 82 | feat16_up = self.conv_head16(feat16_up) 83 | 84 | return feat8, feat16_up, feat32_up # x8, x8, x16 85 | 86 | 87 | class FeatureFusionModule(nn.Module): 88 | 89 | def __init__(self, in_chan, out_chan): 90 | super(FeatureFusionModule, self).__init__() 91 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 92 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 93 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.sigmoid = nn.Sigmoid() 96 | 97 | def forward(self, fsp, fcp): 98 | fcat = torch.cat([fsp, fcp], dim=1) 99 | feat = self.convblk(fcat) 100 | atten = F.avg_pool2d(feat, feat.size()[2:]) 101 | atten = self.conv1(atten) 102 | atten = self.relu(atten) 103 | atten = self.conv2(atten) 104 | atten = self.sigmoid(atten) 105 | feat_atten = torch.mul(feat, atten) 106 | feat_out = feat_atten + feat 107 | return feat_out 108 | 109 | 110 | class BiSeNet(nn.Module): 111 | 112 | def __init__(self, num_class): 113 | super(BiSeNet, self).__init__() 114 | self.cp = ContextPath() 115 | self.ffm = FeatureFusionModule(256, 256) 116 | self.conv_out = BiSeNetOutput(256, 256, num_class) 117 | self.conv_out16 = BiSeNetOutput(128, 64, num_class) 118 | self.conv_out32 = BiSeNetOutput(128, 64, num_class) 119 | 120 | def forward(self, x, return_feat=False): 121 | h, w = x.size()[2:] 122 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature 123 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature 124 | feat_fuse = self.ffm(feat_sp, feat_cp8) 125 | 126 | out, feat = self.conv_out(feat_fuse) 127 | out16, feat16 = self.conv_out16(feat_cp8) 128 | out32, feat32 = self.conv_out32(feat_cp16) 129 | 130 | out = F.interpolate(out, (h, w), mode='bilinear', 
align_corners=True) 131 | out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) 132 | out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) 133 | 134 | if return_feat: 135 | feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) 136 | feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) 137 | feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) 138 | return out, out16, out32, feat, feat16, feat32 139 | else: 140 | return out, out16, out32 141 | -------------------------------------------------------------------------------- /facelib/parsing/parsenet.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/chaofengc/PSFRGAN 2 | """ 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class NormLayer(nn.Module): 9 | """Normalization Layers. 10 | 11 | Args: 12 | channels: input channels, for batch norm and instance norm. 13 | input_size: input shape without batch size, for layer norm. 14 | """ 15 | 16 | def __init__(self, channels, normalize_shape=None, norm_type='bn'): 17 | super(NormLayer, self).__init__() 18 | norm_type = norm_type.lower() 19 | self.norm_type = norm_type 20 | if norm_type == 'bn': 21 | self.norm = nn.BatchNorm2d(channels, affine=True) 22 | elif norm_type == 'in': 23 | self.norm = nn.InstanceNorm2d(channels, affine=False) 24 | elif norm_type == 'gn': 25 | self.norm = nn.GroupNorm(32, channels, affine=True) 26 | elif norm_type == 'pixel': 27 | self.norm = lambda x: F.normalize(x, p=2, dim=1) 28 | elif norm_type == 'layer': 29 | self.norm = nn.LayerNorm(normalize_shape) 30 | elif norm_type == 'none': 31 | self.norm = lambda x: x * 1.0 32 | else: 33 | assert 1 == 0, f'Norm type {norm_type} not support.' 34 | 35 | def forward(self, x, ref=None): 36 | if self.norm_type == 'spade': 37 | return self.norm(x, ref) 38 | else: 39 | return self.norm(x) 40 | 41 | 42 | class ReluLayer(nn.Module): 43 | """Relu Layer. 44 | 45 | Args: 46 | relu type: type of relu layer, candidates are 47 | - ReLU 48 | - LeakyReLU: default relu slope 0.2 49 | - PRelu 50 | - SELU 51 | - none: direct pass 52 | """ 53 | 54 | def __init__(self, channels, relu_type='relu'): 55 | super(ReluLayer, self).__init__() 56 | relu_type = relu_type.lower() 57 | if relu_type == 'relu': 58 | self.func = nn.ReLU(True) 59 | elif relu_type == 'leakyrelu': 60 | self.func = nn.LeakyReLU(0.2, inplace=True) 61 | elif relu_type == 'prelu': 62 | self.func = nn.PReLU(channels) 63 | elif relu_type == 'selu': 64 | self.func = nn.SELU(True) 65 | elif relu_type == 'none': 66 | self.func = lambda x: x * 1.0 67 | else: 68 | assert 1 == 0, f'Relu type {relu_type} not support.' 69 | 70 | def forward(self, x): 71 | return self.func(x) 72 | 73 | 74 | class ConvLayer(nn.Module): 75 | 76 | def __init__(self, 77 | in_channels, 78 | out_channels, 79 | kernel_size=3, 80 | scale='none', 81 | norm_type='none', 82 | relu_type='none', 83 | use_pad=True, 84 | bias=True): 85 | super(ConvLayer, self).__init__() 86 | self.use_pad = use_pad 87 | self.norm_type = norm_type 88 | if norm_type in ['bn']: 89 | bias = False 90 | 91 | stride = 2 if scale == 'down' else 1 92 | 93 | self.scale_func = lambda x: x 94 | if scale == 'up': 95 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') 96 | 97 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) 98 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 99 | 100 | self.relu = ReluLayer(out_channels, relu_type) 101 | self.norm = NormLayer(out_channels, norm_type=norm_type) 102 | 103 | def forward(self, x): 104 | out = self.scale_func(x) 105 | if self.use_pad: 106 | out = self.reflection_pad(out) 107 | out = self.conv2d(out) 108 | out = self.norm(out) 109 | out = self.relu(out) 110 | return out 111 | 112 | 113 | class ResidualBlock(nn.Module): 114 | """ 115 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html 116 | """ 117 | 118 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): 119 | super(ResidualBlock, self).__init__() 120 | 121 | if scale == 'none' and c_in == c_out: 122 | self.shortcut_func = lambda x: x 123 | else: 124 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) 125 | 126 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} 127 | scale_conf = scale_config_dict[scale] 128 | 129 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) 130 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') 131 | 132 | def forward(self, x): 133 | identity = self.shortcut_func(x) 134 | 135 | res = self.conv1(x) 136 | res = self.conv2(res) 137 | return identity + res 138 | 139 | 140 | class ParseNet(nn.Module): 141 | 142 | def __init__(self, 143 | in_size=128, 144 | out_size=128, 145 | min_feat_size=32, 146 | base_ch=64, 147 | parsing_ch=19, 148 | res_depth=10, 149 | relu_type='LeakyReLU', 150 | norm_type='bn', 151 | ch_range=[32, 256]): 152 | super().__init__() 153 | self.res_depth = res_depth 154 | act_args = {'norm_type': norm_type, 'relu_type': relu_type} 155 | min_ch, max_ch = ch_range 156 | 157 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 158 | min_feat_size = min(in_size, min_feat_size) 159 | 160 | down_steps = int(np.log2(in_size // min_feat_size)) 161 | up_steps = int(np.log2(out_size // min_feat_size)) 162 | 163 | # =============== define encoder-body-decoder ==================== 164 | self.encoder = [] 165 | self.encoder.append(ConvLayer(3, base_ch, 3, 1)) 166 | head_ch = base_ch 167 | for i in range(down_steps): 168 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) 169 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) 170 | head_ch = head_ch * 2 171 | 172 | self.body = [] 173 | for i in range(res_depth): 174 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) 175 | 176 | self.decoder = [] 177 | for i in range(up_steps): 178 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) 179 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) 180 | head_ch = head_ch // 2 181 | 182 | self.encoder = nn.Sequential(*self.encoder) 183 | self.body = nn.Sequential(*self.body) 184 | self.decoder = nn.Sequential(*self.decoder) 185 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) 186 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) 187 | 188 | def forward(self, x): 189 | feat = self.encoder(x) 190 | x = feat + self.body(feat) 191 | x = self.decoder(x) 192 | out_img = self.out_img_conv(x) 193 | out_mask = self.out_mask_conv(x) 194 | return out_mask, out_img 195 | -------------------------------------------------------------------------------- /facelib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import 
torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_chan, out_chan, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_chan, out_chan, stride) 15 | self.bn1 = nn.BatchNorm2d(out_chan) 16 | self.conv2 = conv3x3(out_chan, out_chan) 17 | self.bn2 = nn.BatchNorm2d(out_chan) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = None 20 | if in_chan != out_chan or stride != 1: 21 | self.downsample = nn.Sequential( 22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_chan), 24 | ) 25 | 26 | def forward(self, x): 27 | residual = self.conv1(x) 28 | residual = F.relu(self.bn1(residual)) 29 | residual = self.conv2(residual) 30 | residual = self.bn2(residual) 31 | 32 | shortcut = x 33 | if self.downsample is not None: 34 | shortcut = self.downsample(x) 35 | 36 | out = shortcut + residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 43 | for i in range(bnum - 1): 44 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class ResNet18(nn.Module): 49 | 50 | def __init__(self): 51 | super(ResNet18, self).__init__() 52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.relu(self.bn1(x)) 63 | x = self.maxpool(x) 64 | 65 | x = self.layer1(x) 66 | feat8 = self.layer2(x) # 1/8 67 | feat16 = self.layer3(feat8) # 1/16 68 | feat32 = self.layer4(feat16) # 1/32 69 | return feat8, feat16, feat32 70 | -------------------------------------------------------------------------------- /facelib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir 3 | 4 | __all__ = [ 5 | 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 6 | 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' 7 | ] 8 | -------------------------------------------------------------------------------- /facelib/utils/misc.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import os.path as osp 4 | import torch 5 | from torch.hub import download_url_to_file, get_dir 6 | from urllib.parse import urlparse 7 | # from basicsr.utils.download_util import download_file_from_google_drive 8 | import gdown 9 | 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 12 | 13 | 14 | def download_pretrained_models(file_ids, save_path_root): 15 | os.makedirs(save_path_root, 
exist_ok=True) 16 | 17 | for file_name, file_id in file_ids.items(): 18 | file_url = 'https://drive.google.com/uc?id='+file_id 19 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 20 | if osp.exists(save_path): 21 | user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') 22 | if user_response.lower() == 'y': 23 | print(f'Covering {file_name} to {save_path}') 24 | gdown.download(file_url, save_path, quiet=False) 25 | # download_file_from_google_drive(file_id, save_path) 26 | elif user_response.lower() == 'n': 27 | print(f'Skipping {file_name}') 28 | else: 29 | raise ValueError('Wrong input. Only accepts Y/N.') 30 | else: 31 | print(f'Downloading {file_name} to {save_path}') 32 | gdown.download(file_url, save_path, quiet=False) 33 | # download_file_from_google_drive(file_id, save_path) 34 | 35 | 36 | def imwrite(img, file_path, params=None, auto_mkdir=True): 37 | """Write image to file. 38 | 39 | Args: 40 | img (ndarray): Image array to be written. 41 | file_path (str): Image file path. 42 | params (None or list): Same as opencv's :func:`imwrite` interface. 43 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 44 | whether to create it automatically. 45 | 46 | Returns: 47 | bool: Successful or not. 48 | """ 49 | if auto_mkdir: 50 | dir_name = os.path.abspath(os.path.dirname(file_path)) 51 | os.makedirs(dir_name, exist_ok=True) 52 | return cv2.imwrite(file_path, img, params) 53 | 54 | 55 | def img2tensor(imgs, bgr2rgb=True, float32=True): 56 | """Numpy array to tensor. 57 | 58 | Args: 59 | imgs (list[ndarray] | ndarray): Input images. 60 | bgr2rgb (bool): Whether to change bgr to rgb. 61 | float32 (bool): Whether to change to float32. 62 | 63 | Returns: 64 | list[tensor] | tensor: Tensor images. If returned results only have 65 | one element, just return tensor. 66 | """ 67 | 68 | def _totensor(img, bgr2rgb, float32): 69 | if img.shape[2] == 3 and bgr2rgb: 70 | if img.dtype == 'float64': 71 | img = img.astype('float32') 72 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 73 | img = torch.from_numpy(img.transpose(2, 0, 1)) 74 | if float32: 75 | img = img.float() 76 | return img 77 | 78 | if isinstance(imgs, list): 79 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 80 | else: 81 | return _totensor(imgs, bgr2rgb, float32) 82 | 83 | 84 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 85 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 86 | """ 87 | if model_dir is None: 88 | hub_dir = get_dir() 89 | model_dir = os.path.join(hub_dir, 'checkpoints') 90 | 91 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) 92 | 93 | parts = urlparse(url) 94 | filename = os.path.basename(parts.path) 95 | if file_name is not None: 96 | filename = file_name 97 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) 98 | if not os.path.exists(cached_file): 99 | print(f'Downloading: "{url}" to {cached_file}\n') 100 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 101 | return cached_file 102 | 103 | 104 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 105 | """Scan a directory to find the interested files. 106 | Args: 107 | dir_path (str): Path of the directory. 108 | suffix (str | tuple(str), optional): File suffix that we are 109 | interested in. Default: None. 110 | recursive (bool, optional): If set to True, recursively scan the 111 | directory. Default: False. 
112 | full_path (bool, optional): If set to True, include the dir_path. 113 | Default: False. 114 | Returns: 115 | A generator for all the interested files with relative paths. 116 | """ 117 | 118 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 119 | raise TypeError('"suffix" must be a string or tuple of strings') 120 | 121 | root = dir_path 122 | 123 | def _scandir(dir_path, suffix, recursive): 124 | for entry in os.scandir(dir_path): 125 | if not entry.name.startswith('.') and entry.is_file(): 126 | if full_path: 127 | return_path = entry.path 128 | else: 129 | return_path = osp.relpath(entry.path, root) 130 | 131 | if suffix is None: 132 | yield return_path 133 | elif return_path.endswith(suffix): 134 | yield return_path 135 | else: 136 | if recursive: 137 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 138 | else: 139 | continue 140 | 141 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 142 | -------------------------------------------------------------------------------- /icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/icon.png -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/output/.gitkeep -------------------------------------------------------------------------------- /recognition/arcface_onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | import numpy as np 8 | import cv2 9 | import onnx 10 | import onnxruntime 11 | import face_align 12 | 13 | __all__ = [ 14 | 'ArcFaceONNX', 15 | ] 16 | 17 | 18 | class ArcFaceONNX: 19 | def __init__(self, model_file=None, session=None): 20 | assert model_file is not None 21 | self.model_file = model_file 22 | self.session = session 23 | self.taskname = 'recognition' 24 | find_sub = False 25 | find_mul = False 26 | model = onnx.load(self.model_file) 27 | graph = model.graph 28 | for nid, node in enumerate(graph.node[:8]): 29 | #print(nid, node.name) 30 | if node.name.startswith('Sub') or node.name.startswith('_minus'): 31 | find_sub = True 32 | if node.name.startswith('Mul') or node.name.startswith('_mul'): 33 | find_mul = True 34 | if find_sub and find_mul: 35 | #mxnet arcface model 36 | input_mean = 0.0 37 | input_std = 1.0 38 | else: 39 | input_mean = 127.5 40 | input_std = 127.5 41 | self.input_mean = input_mean 42 | self.input_std = input_std 43 | #print('input mean and std:', self.input_mean, self.input_std) 44 | if self.session is None: 45 | self.session = onnxruntime.InferenceSession(self.model_file, providers=['CoreMLExecutionProvider','CUDAExecutionProvider']) 46 | input_cfg = self.session.get_inputs()[0] 47 | input_shape = input_cfg.shape 48 | input_name = input_cfg.name 49 | self.input_size = tuple(input_shape[2:4][::-1]) 50 | self.input_shape = input_shape 51 | outputs = self.session.get_outputs() 52 | output_names = [] 53 | for out in outputs: 54 | output_names.append(out.name) 55 | self.input_name = input_name 56 | self.output_names = output_names 57 | assert len(self.output_names)==1 58 | self.output_shape = outputs[0].shape 59 | 60 | def 
prepare(self, ctx_id, **kwargs): 61 | if ctx_id<0: 62 | self.session.set_providers(['CPUExecutionProvider']) 63 | 64 | def get(self, img, kps): 65 | aimg = face_align.norm_crop(img, landmark=kps, image_size=self.input_size[0]) 66 | embedding = self.get_feat(aimg).flatten() 67 | return embedding 68 | 69 | def compute_sim(self, feat1, feat2): 70 | from numpy.linalg import norm 71 | feat1 = feat1.ravel() 72 | feat2 = feat2.ravel() 73 | sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) 74 | return sim 75 | 76 | def get_feat(self, imgs): 77 | if not isinstance(imgs, list): 78 | imgs = [imgs] 79 | input_size = self.input_size 80 | 81 | blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, 82 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 83 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0] 84 | return net_out 85 | 86 | def forward(self, batch_data): 87 | blob = (batch_data - self.input_mean) / self.input_std 88 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0] 89 | return net_out 90 | 91 | 92 | -------------------------------------------------------------------------------- /recognition/face_align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from skimage import transform as trans 4 | 5 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007], 6 | [51.157, 89.050], [57.025, 89.702]], 7 | dtype=np.float32) 8 | #<--left 9 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111], 10 | [45.177, 86.190], [64.246, 86.758]], 11 | dtype=np.float32) 12 | 13 | #---frontal 14 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493], 15 | [42.463, 87.010], [69.537, 87.010]], 16 | dtype=np.float32) 17 | 18 | #-->right 19 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111], 20 | [48.167, 86.758], [67.236, 86.190]], 21 | dtype=np.float32) 22 | 23 | #-->right profile 24 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007], 25 | [55.388, 89.702], [61.257, 89.050]], 26 | dtype=np.float32) 27 | 28 | src = np.array([src1, src2, src3, src4, src5]) 29 | src_map = {112: src, 224: src * 2} 30 | 31 | arcface_src = np.array( 32 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 33 | [41.5493, 92.3655], [70.7299, 92.2041]], 34 | dtype=np.float32) 35 | 36 | arcface_src = np.expand_dims(arcface_src, axis=0) 37 | 38 | # In[66]: 39 | 40 | 41 | # lmk is prediction; src is template 42 | def estimate_norm(lmk, image_size=112, mode='arcface'): 43 | assert lmk.shape == (5, 2) 44 | tform = trans.SimilarityTransform() 45 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 46 | min_M = [] 47 | min_index = [] 48 | min_error = float('inf') 49 | if mode == 'arcface': 50 | if image_size == 112: 51 | src = arcface_src 52 | else: 53 | src = float(image_size) / 112 * arcface_src 54 | else: 55 | src = src_map[image_size] 56 | for i in np.arange(src.shape[0]): 57 | tform.estimate(lmk, src[i]) 58 | M = tform.params[0:2, :] 59 | results = np.dot(M, lmk_tran.T) 60 | results = results.T 61 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1))) 62 | # print(error) 63 | if error < min_error: 64 | min_error = error 65 | min_M = M 66 | min_index = i 67 | return min_M, min_index 68 | 69 | 70 | def norm_crop(img, landmark, image_size=112, mode='arcface'): 71 | M, pose_index = estimate_norm(landmark, image_size, mode) 72 | warped = cv2.warpAffine(img, M, (image_size, 
image_size), borderValue=0.0) 73 | return warped 74 | 75 | def square_crop(im, S): 76 | if im.shape[0] > im.shape[1]: 77 | height = S 78 | width = int(float(im.shape[1]) / im.shape[0] * S) 79 | scale = float(S) / im.shape[0] 80 | else: 81 | width = S 82 | height = int(float(im.shape[0]) / im.shape[1] * S) 83 | scale = float(S) / im.shape[1] 84 | resized_im = cv2.resize(im, (width, height)) 85 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 86 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 87 | return det_im, scale 88 | 89 | 90 | def transform(data, center, output_size, scale, rotation): 91 | scale_ratio = scale 92 | rot = float(rotation) * np.pi / 180.0 93 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 94 | t1 = trans.SimilarityTransform(scale=scale_ratio) 95 | cx = center[0] * scale_ratio 96 | cy = center[1] * scale_ratio 97 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 98 | t3 = trans.SimilarityTransform(rotation=rot) 99 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 100 | output_size / 2)) 101 | t = t1 + t2 + t3 + t4 102 | M = t.params[0:2] 103 | cropped = cv2.warpAffine(data, 104 | M, (output_size, output_size), 105 | borderValue=0.0) 106 | return cropped, M 107 | 108 | 109 | def trans_points2d(pts, M): 110 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 111 | for i in range(pts.shape[0]): 112 | pt = pts[i] 113 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 114 | new_pt = np.dot(M, new_pt) 115 | #print('new_pt', new_pt.shape, new_pt) 116 | new_pts[i] = new_pt[0:2] 117 | 118 | return new_pts 119 | 120 | 121 | def trans_points3d(pts, M): 122 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 123 | #print(scale) 124 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 125 | for i in range(pts.shape[0]): 126 | pt = pts[i] 127 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 128 | new_pt = np.dot(M, new_pt) 129 | #print('new_pt', new_pt.shape, new_pt) 130 | new_pts[i][0:2] = new_pt[0:2] 131 | new_pts[i][2] = pts[i][2] * scale 132 | 133 | return new_pts 134 | 135 | 136 | def trans_points(pts, M): 137 | if pts.shape[1] == 2: 138 | return trans_points2d(pts, M) 139 | else: 140 | return trans_points3d(pts, M) 141 | 142 | -------------------------------------------------------------------------------- /recognition/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import os.path as osp 5 | import argparse 6 | import cv2 7 | import numpy as np 8 | import onnxruntime 9 | from scrfd import SCRFD 10 | from arcface_onnx import ArcFaceONNX 11 | 12 | onnxruntime.set_default_logger_severity(5) 13 | 14 | assets_dir = osp.expanduser('~/.insightface/models/buffalo_l') 15 | 16 | detector = SCRFD(os.path.join(assets_dir, 'det_10g.onnx')) 17 | detector.prepare(0) 18 | model_path = os.path.join(assets_dir, 'w600k_r50.onnx') 19 | rec = ArcFaceONNX(model_path) 20 | rec.prepare(0) 21 | 22 | def parse_args() -> argparse.Namespace: 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('img1', type=str) 25 | parser.add_argument('img2', type=str) 26 | return parser.parse_args() 27 | 28 | 29 | def func(args): 30 | image1 = cv2.imread(args.img1) 31 | image2 = cv2.imread(args.img2) 32 | bboxes1, kpss1 = detector.autodetect(image1, max_num=1) 33 | if bboxes1.shape[0]==0: 34 | return -1.0, "Face not found in Image-1" 35 | bboxes2, kpss2 = detector.autodetect(image2, max_num=1) 36 | if 
bboxes2.shape[0]==0: 37 | return -1.0, "Face not found in Image-2" 38 | kps1 = kpss1[0] 39 | kps2 = kpss2[0] 40 | feat1 = rec.get(image1, kps1) 41 | feat2 = rec.get(image2, kps2) 42 | sim = rec.compute_sim(feat1, feat2) 43 | if sim<0.2: 44 | conclu = 'They are NOT the same person' 45 | elif sim>=0.2 and sim<0.28: 46 | conclu = 'They are LIKELY TO be the same person' 47 | else: 48 | conclu = 'They ARE the same person' 49 | return sim, conclu 50 | 51 | 52 | 53 | if __name__ == '__main__': 54 | args = parse_args() 55 | output = func(args) 56 | print('sim: %.4f, message: %s'%(output[0], output[1])) 57 | 58 | -------------------------------------------------------------------------------- /refacer_bulk.py: -------------------------------------------------------------------------------- 1 | # refacer_bulk.py 2 | # 3 | # Example usage: 4 | # python refacer_bulk.py --input_path ./input --dest_face myface.jpg --facetoreplace face1.jpg --threshold 0.3 5 | # 6 | # Or, to disable similarity check (i.e., just apply the destination face to all detected faces): 7 | # python refacer_bulk.py --input_path ./input --dest_face myface.jpg 8 | 9 | import argparse 10 | import os 11 | import cv2 12 | from pathlib import Path 13 | from refacer import Refacer 14 | from PIL import Image 15 | import time 16 | import pyfiglet 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Bulk Image Refacer") 20 | parser.add_argument("--input_path", type=str, required=True, help="Directory containing input images") 21 | parser.add_argument("--dest_face", type=str, required=True, help="Path to destination face image") 22 | parser.add_argument("--facetoreplace", type=str, default=None, help="Path to face to replace (origin face)") 23 | parser.add_argument("--threshold", type=float, default=0.2, help="Similarity threshold (default: 0.2)") 24 | parser.add_argument("--force_cpu", action="store_true", help="Force CPU mode") 25 | parser.add_argument("--colab_performance", action="store_true", help="Enable Colab performance tweaks") 26 | return parser.parse_args() 27 | 28 | def main(): 29 | print("\033[94m" + pyfiglet.Figlet(font='slant').renderText("NeoRefacer") + "\033[0m") 30 | 31 | args = parse_args() 32 | 33 | input_dir = Path(args.input_path) 34 | 35 | refacer = Refacer(force_cpu=args.force_cpu, colab_performance=args.colab_performance) 36 | 37 | # Load destination and origin face 38 | dest_img = cv2.imread(args.dest_face) 39 | if dest_img is None: 40 | raise ValueError(f"Destination face image not found: {args.dest_face}") 41 | 42 | origin_img = None 43 | if args.facetoreplace: 44 | origin_img = cv2.imread(args.facetoreplace) 45 | if origin_img is None: 46 | raise ValueError(f"Face to replace image not found: {args.facetoreplace}") 47 | 48 | disable_similarity = origin_img is None 49 | 50 | faces_config = [{ 51 | 'origin': origin_img, 52 | 'destination': dest_img, 53 | 'threshold': args.threshold 54 | }] 55 | 56 | refacer.prepare_faces(faces_config, disable_similarity=disable_similarity) 57 | 58 | print(f"Processing images from: {input_dir}") 59 | image_files = list(input_dir.glob("*")) 60 | supported_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'} 61 | 62 | for image_path in image_files: 63 | if image_path.suffix.lower() not in supported_exts: 64 | print(f"Skipping non-image file: {image_path}") 65 | continue 66 | 67 | print(f"Refacing: {image_path}") 68 | try: 69 | refaced_path = refacer.reface_image(str(image_path), faces_config, disable_similarity=disable_similarity) 70 | print(f"Saved to: 
{refaced_path}") 71 | except Exception as e: 72 | print(f"Failed to process {image_path}: {e}") 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /requirements-COREML.txt: -------------------------------------------------------------------------------- 1 | ffmpeg_python==0.2.0 2 | imageio[ffmpeg]==2.37.0 3 | imagecodecs==2025.3.30 4 | gradio==5.22.0 5 | insightface==0.7.3 6 | numpy==1.24.3 7 | onnx==1.14.0 8 | onnxruntime-silicon==1.16.3 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | scikit-image==0.20.0 12 | tqdm==4.67.1 13 | psutil==7.0.0 14 | ngrok==1.4.0 15 | pyfiglet==1.0.2 16 | # codeformer dependencies 17 | torch==2.6.0 18 | torchvision==0.21.0 19 | gdown==5.2.0 20 | lpips==0.1.4 -------------------------------------------------------------------------------- /requirements-CPU.txt: -------------------------------------------------------------------------------- 1 | ffmpeg_python==0.2.0 2 | imageio[ffmpeg]==2.37.0 3 | imagecodecs==2025.3.30 4 | gradio==5.22.0 5 | insightface==0.7.3 6 | numpy==1.24.3 7 | onnx==1.14.0 8 | onnxruntime==1.21.0 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | scikit-image==0.20.0 12 | tqdm==4.67.1 13 | psutil==7.0.0 14 | ngrok==1.4.0 15 | pyfiglet==1.0.2 16 | # codeformer dependencies 17 | torch==2.6.0 18 | torchvision==0.21.0 19 | gdown==5.2.0 20 | lpips==0.1.4 -------------------------------------------------------------------------------- /requirements-GPU.txt: -------------------------------------------------------------------------------- 1 | ffmpeg_python==0.2.0 2 | imageio[ffmpeg]==2.37.0 3 | imagecodecs==2025.3.30 4 | gradio==5.22.0 5 | insightface==0.7.3 6 | numpy==1.24.3 7 | onnx==1.14.0 8 | onnxruntime_gpu==1.21.0 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | scikit-image==0.20.0 12 | tqdm==4.67.1 13 | psutil==7.0.0 14 | ngrok==1.4.0 15 | pyfiglet==1.0.2 16 | # codeformer dependencies 17 | # torch==2.6.0 18 | # torchvision==0.21.0 19 | gdown==5.2.0 20 | lpips==0.1.4 -------------------------------------------------------------------------------- /weights/CodeFormer/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/weights/CodeFormer/.gitkeep -------------------------------------------------------------------------------- /weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded pre-trained models to this folder. -------------------------------------------------------------------------------- /weights/facelib/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/weights/facelib/.gitkeep -------------------------------------------------------------------------------- /weights/inswapper/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MechasAI/NeoRefacer/ad4dccfe49a1b1c803a880762dd682738a160dcc/weights/inswapper/.gitkeep --------------------------------------------------------------------------------