├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── codeformer ├── __init__.py ├── app.py ├── basicsr │ ├── VERSION │ ├── __init__.py │ ├── archs │ │ ├── __init__.py │ │ ├── arcface_arch.py │ │ ├── arch_util.py │ │ ├── codeformer_arch.py │ │ ├── rrdbnet_arch.py │ │ ├── vgg_arch.py │ │ └── vqgan_arch.py │ ├── data │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── data_util.py │ │ ├── prefetch_dataloader.py │ │ └── transforms.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ ├── metrics │ │ ├── __init__.py │ │ ├── metric_util.py │ │ └── psnr_ssim.py │ ├── models │ │ └── __init__.py │ ├── ops │ │ ├── __init__.py │ │ ├── dcn │ │ │ ├── __init__.py │ │ │ ├── deform_conv.py │ │ │ └── src │ │ │ │ ├── deform_conv_cuda.cpp │ │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ │ └── deform_conv_ext.cpp │ │ ├── fused_act │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ └── src │ │ │ │ ├── fused_bias_act.cpp │ │ │ │ └── fused_bias_act_kernel.cu │ │ └── upfirdn2d │ │ │ ├── __init__.py │ │ │ ├── src │ │ │ ├── upfirdn2d.cpp │ │ │ └── upfirdn2d_kernel.cu │ │ │ └── upfirdn2d.py │ ├── setup.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── dist_util.py │ │ ├── download_util.py │ │ ├── file_client.py │ │ ├── img_util.py │ │ ├── lmdb_util.py │ │ ├── logger.py │ │ ├── matlab_functions.py │ │ ├── misc.py │ │ ├── options.py │ │ ├── realesrgan_utils.py │ │ ├── registry.py │ │ └── video_util.py ├── facelib │ ├── __init__.py │ ├── detection │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── matlab_cp2tform.py │ │ ├── retinaface │ │ │ ├── __init__.py │ │ │ ├── retinaface.py │ │ │ ├── retinaface_net.py │ │ │ └── retinaface_utils.py │ │ └── yolov5face │ │ │ ├── __init__.py │ │ │ ├── face_detector.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── experimental.py │ │ │ ├── yolo.py │ │ │ ├── yolov5l.yaml │ │ │ └── yolov5n.yaml │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── autoanchor.py │ │ │ ├── datasets.py │ │ │ ├── extract_ckpt.py │ │ │ ├── general.py │ │ │ └── torch_utils.py │ ├── parsing │ │ ├── __init__.py │ │ ├── bisenet.py │ │ ├── parsenet.py │ │ └── resnet.py │ └── utils │ │ ├── __init__.py │ │ ├── face_restoration_helper.py │ │ ├── face_utils.py │ │ └── misc.py ├── inference_codeformer.py └── scripts │ ├── __init__.py │ ├── code_format.sh │ ├── crop_align_face.py │ ├── download_pretrained_models.py │ ├── download_pretrained_models_from_gdrive.py │ └── package.sh ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Towards Robust Blind Face Restoration with Codebook Lookup Transformer 3 |

4 | 5 | pypi version 6 | HuggingFace Spaces 7 | 8 | This repo is a PyTorch implementation of the paper [CodeFormer](https://arxiv.org/abs/2206.11253). 9 | 10 | ### Installation 11 | ```bash 12 | pip install codeformer-pip 13 | ``` 14 | ### Usage 15 | ```python 16 | from codeformer.app import inference_app 17 | 18 | inference_app( 19 | image="test.jpg", 20 | background_enhance=True, 21 | face_upsample=True, 22 | upscale=2, 23 | codeformer_fidelity=0.5, 24 | ) 25 | ``` 26 | ### Citation 27 | ```bibtex 28 | @inproceedings{zhou2022codeformer, 29 | author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change}, 30 | title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer}, 31 | booktitle = {NeurIPS}, 32 | year = {2022} 33 | } 34 | ``` 35 | 36 | ### License 37 | 38 | This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license. -------------------------------------------------------------------------------- /codeformer/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.4" 2 | -------------------------------------------------------------------------------- /codeformer/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from torchvision.transforms.functional import normalize 6 | 7 | from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet 8 | from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img 9 | from codeformer.basicsr.utils.download_util import load_file_from_url 10 | from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer 11 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY 12 | from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper 13 | from codeformer.facelib.utils.misc import is_gray 14 | 15 | pretrain_model_url = { 16 | "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth", 17 | "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth", 18 | "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth", 19 | "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth", 20 | } 21 | 22 | # download weights 23 | if not os.path.exists("CodeFormer/weights/CodeFormer/codeformer.pth"): 24 | load_file_from_url( 25 | url=pretrain_model_url["codeformer"], model_dir="CodeFormer/weights/CodeFormer", progress=True, file_name=None 26 | ) 27 | if not os.path.exists("CodeFormer/weights/facelib/detection_Resnet50_Final.pth"): 28 | load_file_from_url( 29 | url=pretrain_model_url["detection"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None 30 | ) 31 | if not os.path.exists("CodeFormer/weights/facelib/parsing_parsenet.pth"): 32 | load_file_from_url( 33 | url=pretrain_model_url["parsing"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None 34 | ) 35 | if not os.path.exists("CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth"): 36 | load_file_from_url( 37 | url=pretrain_model_url["realesrgan"], model_dir="CodeFormer/weights/realesrgan", progress=True, file_name=None 38 | ) 39 | 40 | 41 | def imread(img_path): 42 | img = cv2.imread(img_path) 43 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 44 | return img 45 | 46 | 47 | # set enhancer with RealESRGAN 48 | def set_realesrgan(): 49 | half = 
True if torch.cuda.is_available() else False 50 | model = RRDBNet( 51 | num_in_ch=3, 52 | num_out_ch=3, 53 | num_feat=64, 54 | num_block=23, 55 | num_grow_ch=32, 56 | scale=2, 57 | ) 58 | upsampler = RealESRGANer( 59 | scale=2, 60 | model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth", 61 | model=model, 62 | tile=400, 63 | tile_pad=40, 64 | pre_pad=0, 65 | half=half, 66 | ) 67 | return upsampler 68 | 69 | 70 | upsampler = set_realesrgan() 71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 72 | codeformer_net = ARCH_REGISTRY.get("CodeFormer")( 73 | dim_embd=512, 74 | codebook_size=1024, 75 | n_head=8, 76 | n_layers=9, 77 | connect_list=["32", "64", "128", "256"], 78 | ).to(device) 79 | ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth" 80 | checkpoint = torch.load(ckpt_path)["params_ema"] 81 | codeformer_net.load_state_dict(checkpoint) 82 | codeformer_net.eval() 83 | 84 | os.makedirs("output", exist_ok=True) 85 | 86 | 87 | def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity): 88 | # take the default setting for the demo 89 | has_aligned = False 90 | only_center_face = False 91 | draw_box = False 92 | detection_model = "retinaface_resnet50" 93 | print("Inp:", type(image), background_enhance, face_upsample, upscale, codeformer_fidelity) 94 | if isinstance(image, str): 95 | img = cv2.imread(str(image), cv2.IMREAD_COLOR) 96 | if isinstance(image, np.ndarray): 97 | img = image 98 | print("\timage size:", img.shape) 99 | 100 | upscale = int(upscale) # convert type to int 101 | if upscale > 4: # avoid memory exceeded due to too large upscale 102 | upscale = 4 103 | if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution 104 | upscale = 2 105 | if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution 106 | upscale = 1 107 | background_enhance = False 108 | face_upsample = False 109 | 110 | face_helper = FaceRestoreHelper( 111 | upscale, 112 | face_size=512, 113 | crop_ratio=(1, 1), 114 | det_model=detection_model, 115 | save_ext="png", 116 | use_parse=True, 117 | device=device, 118 | ) 119 | bg_upsampler = upsampler if background_enhance else None 120 | face_upsampler = upsampler if face_upsample else None 121 | 122 | if has_aligned: 123 | # the input faces are already cropped and aligned 124 | img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) 125 | face_helper.is_gray = is_gray(img, threshold=5) 126 | if face_helper.is_gray: 127 | print("\tgrayscale input: True") 128 | face_helper.cropped_faces = [img] 129 | else: 130 | face_helper.read_image(img) 131 | # get face landmarks for each face 132 | num_det_faces = face_helper.get_face_landmarks_5( 133 | only_center_face=only_center_face, resize=640, eye_dist_threshold=5 134 | ) 135 | print(f"\tdetect {num_det_faces} faces") 136 | # align and warp each face 137 | face_helper.align_warp_face() 138 | 139 | # face restoration for each cropped face 140 | for idx, cropped_face in enumerate(face_helper.cropped_faces): 141 | # prepare data 142 | cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) 143 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 144 | cropped_face_t = cropped_face_t.unsqueeze(0).to(device) 145 | 146 | try: 147 | with torch.no_grad(): 148 | output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0] 149 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) 150 | del output 151 | 
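            # drop the restored tensor and clear the CUDA cache so GPU memory is released between faces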
torch.cuda.empty_cache() 152 | except RuntimeError as error: 153 | print(f"Failed inference for CodeFormer: {error}") 154 | restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) 155 | 156 | restored_face = restored_face.astype("uint8") 157 | face_helper.add_restored_face(restored_face) 158 | 159 | # paste_back 160 | if not has_aligned: 161 | # upsample the background 162 | if bg_upsampler is not None: 163 | # Now only support RealESRGAN for upsampling background 164 | bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] 165 | else: 166 | bg_img = None 167 | face_helper.get_inverse_affine(None) 168 | # paste each restored face to the input image 169 | if face_upsample and face_upsampler is not None: 170 | restored_img = face_helper.paste_faces_to_input_image( 171 | upsample_img=bg_img, 172 | draw_box=draw_box, 173 | face_upsampler=face_upsampler, 174 | ) 175 | else: 176 | restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box) 177 | 178 | return restored_img 179 | -------------------------------------------------------------------------------- /codeformer/basicsr/VERSION: -------------------------------------------------------------------------------- 1 | 1.3.2 2 | -------------------------------------------------------------------------------- /codeformer/basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .train import * 10 | from .utils import * 11 | -------------------------------------------------------------------------------- /codeformer/basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from codeformer.basicsr.utils import get_root_logger, scandir 6 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ["build_network"] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith("_arch.py")] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f"codeformer.basicsr.archs.{file_name}") for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop("type") 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f"Network [{net.__class__.__name__}] is created.") 25 | return net 26 | -------------------------------------------------------------------------------- /codeformer/basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from codeformer.basicsr.archs.arch_util import default_init_weights, make_layer, pixel_unshuffle 6 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 
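    Stacks five densely connected 3x3 convolutions and scales the block output by 0.2 before adding it back to the input (local residual learning).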
13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Emperically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Emperically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 
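        scale (int): Upsampling scale of the network; 1, 2 and 4 are supported. Default: 4.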
85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out 120 | -------------------------------------------------------------------------------- /codeformer/basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.models import vgg as vgg 7 | 8 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY 9 | 10 | VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth" 11 | NAMES = { 12 | "vgg11": [ 13 | "conv1_1", 14 | "relu1_1", 15 | "pool1", 16 | "conv2_1", 17 | "relu2_1", 18 | "pool2", 19 | "conv3_1", 20 | "relu3_1", 21 | "conv3_2", 22 | "relu3_2", 23 | "pool3", 24 | "conv4_1", 25 | "relu4_1", 26 | "conv4_2", 27 | "relu4_2", 28 | "pool4", 29 | "conv5_1", 30 | "relu5_1", 31 | "conv5_2", 32 | "relu5_2", 33 | "pool5", 34 | ], 35 | "vgg13": [ 36 | "conv1_1", 37 | "relu1_1", 38 | "conv1_2", 39 | "relu1_2", 40 | "pool1", 41 | "conv2_1", 42 | "relu2_1", 43 | "conv2_2", 44 | "relu2_2", 45 | "pool2", 46 | "conv3_1", 47 | "relu3_1", 48 | "conv3_2", 49 | "relu3_2", 50 | "pool3", 51 | "conv4_1", 52 | "relu4_1", 53 | "conv4_2", 54 | "relu4_2", 55 | "pool4", 56 | "conv5_1", 57 | "relu5_1", 58 | "conv5_2", 59 | "relu5_2", 60 | "pool5", 61 | ], 62 | "vgg16": [ 63 | "conv1_1", 64 | "relu1_1", 65 | "conv1_2", 66 | "relu1_2", 67 | "pool1", 68 | "conv2_1", 69 | "relu2_1", 70 | "conv2_2", 71 | "relu2_2", 72 | "pool2", 73 | "conv3_1", 74 | "relu3_1", 75 | "conv3_2", 76 | "relu3_2", 77 | "conv3_3", 78 | "relu3_3", 79 | "pool3", 80 | "conv4_1", 81 | "relu4_1", 82 | "conv4_2", 83 | "relu4_2", 84 | "conv4_3", 85 | "relu4_3", 86 | "pool4", 87 | "conv5_1", 88 | "relu5_1", 89 | "conv5_2", 90 | "relu5_2", 91 | "conv5_3", 92 | "relu5_3", 93 | "pool5", 94 | ], 95 | "vgg19": [ 96 | "conv1_1", 97 | "relu1_1", 98 | "conv1_2", 99 | "relu1_2", 100 | "pool1", 101 | "conv2_1", 102 | "relu2_1", 103 | "conv2_2", 104 | "relu2_2", 105 | "pool2", 106 | "conv3_1", 107 | "relu3_1", 108 | "conv3_2", 109 | "relu3_2", 110 | "conv3_3", 111 | "relu3_3", 112 | "conv3_4", 113 | "relu3_4", 114 | "pool3", 115 | 
"conv4_1", 116 | "relu4_1", 117 | "conv4_2", 118 | "relu4_2", 119 | "conv4_3", 120 | "relu4_3", 121 | "conv4_4", 122 | "relu4_4", 123 | "pool4", 124 | "conv5_1", 125 | "relu5_1", 126 | "conv5_2", 127 | "relu5_2", 128 | "conv5_3", 129 | "relu5_3", 130 | "conv5_4", 131 | "relu5_4", 132 | "pool5", 133 | ], 134 | } 135 | 136 | 137 | def insert_bn(names): 138 | """Insert bn layer after each conv. 139 | 140 | Args: 141 | names (list): The list of layer names. 142 | 143 | Returns: 144 | list: The list of layer names with bn layers. 145 | """ 146 | names_bn = [] 147 | for name in names: 148 | names_bn.append(name) 149 | if "conv" in name: 150 | position = name.replace("conv", "") 151 | names_bn.append("bn" + position) 152 | return names_bn 153 | 154 | 155 | @ARCH_REGISTRY.register() 156 | class VGGFeatureExtractor(nn.Module): 157 | """VGG network for feature extraction. 158 | 159 | In this implementation, we allow users to choose whether use normalization 160 | in the input feature and the type of vgg network. Note that the pretrained 161 | path must fit the vgg type. 162 | 163 | Args: 164 | layer_name_list (list[str]): Forward function returns the corresponding 165 | features according to the layer_name_list. 166 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 167 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 168 | use_input_norm (bool): If True, normalize the input image. Importantly, 169 | the input feature must in the range [0, 1]. Default: True. 170 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 171 | Default: False. 172 | requires_grad (bool): If true, the parameters of VGG network will be 173 | optimized. Default: False. 174 | remove_pooling (bool): If true, the max pooling operations in VGG net 175 | will be removed. Default: False. 176 | pooling_stride (int): The stride of max pooling operation. Default: 2. 
177 | """ 178 | 179 | def __init__( 180 | self, 181 | layer_name_list, 182 | vgg_type="vgg19", 183 | use_input_norm=True, 184 | range_norm=False, 185 | requires_grad=False, 186 | remove_pooling=False, 187 | pooling_stride=2, 188 | ): 189 | super(VGGFeatureExtractor, self).__init__() 190 | 191 | self.layer_name_list = layer_name_list 192 | self.use_input_norm = use_input_norm 193 | self.range_norm = range_norm 194 | 195 | self.names = NAMES[vgg_type.replace("_bn", "")] 196 | if "bn" in vgg_type: 197 | self.names = insert_bn(self.names) 198 | 199 | # only borrow layers that will be used to avoid unused params 200 | max_idx = 0 201 | for v in layer_name_list: 202 | idx = self.names.index(v) 203 | if idx > max_idx: 204 | max_idx = idx 205 | 206 | if os.path.exists(VGG_PRETRAIN_PATH): 207 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 208 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 209 | vgg_net.load_state_dict(state_dict) 210 | else: 211 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 212 | 213 | features = vgg_net.features[: max_idx + 1] 214 | 215 | modified_net = OrderedDict() 216 | for k, v in zip(self.names, features): 217 | if "pool" in k: 218 | # if remove_pooling is true, pooling operation will be removed 219 | if remove_pooling: 220 | continue 221 | else: 222 | # in some cases, we may want to change the default stride 223 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 224 | else: 225 | modified_net[k] = v 226 | 227 | self.vgg_net = nn.Sequential(modified_net) 228 | 229 | if not requires_grad: 230 | self.vgg_net.eval() 231 | for param in self.parameters(): 232 | param.requires_grad = False 233 | else: 234 | self.vgg_net.train() 235 | for param in self.parameters(): 236 | param.requires_grad = True 237 | 238 | if self.use_input_norm: 239 | # the mean is for image with range [0, 1] 240 | self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 241 | # the std is for image with range [0, 1] 242 | self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 243 | 244 | def forward(self, x): 245 | """Forward function. 246 | 247 | Args: 248 | x (Tensor): Input tensor with shape (n, c, h, w). 249 | 250 | Returns: 251 | Tensor: Forward results. 
252 | """ 253 | if self.range_norm: 254 | x = (x + 1) / 2 255 | if self.use_input_norm: 256 | x = (x - self.mean) / self.std 257 | output = {} 258 | 259 | for key, layer in self.vgg_net._modules.items(): 260 | x = layer(x) 261 | if key in self.layer_name_list: 262 | output[key] = x.clone() 263 | 264 | return output 265 | -------------------------------------------------------------------------------- /codeformer/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | from copy import deepcopy 4 | from functools import partial 5 | from os import path as osp 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | 11 | from codeformer.basicsr.data.prefetch_dataloader import PrefetchDataLoader 12 | from codeformer.basicsr.utils import get_root_logger, scandir 13 | from codeformer.basicsr.utils.dist_util import get_dist_info 14 | from codeformer.basicsr.utils.registry import DATASET_REGISTRY 15 | 16 | __all__ = ["build_dataset", "build_dataloader"] 17 | 18 | # automatically scan and import dataset modules for registry 19 | # scan all the files under the data folder with '_dataset' in file names 20 | data_folder = osp.dirname(osp.abspath(__file__)) 21 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith("_dataset.py")] 22 | # import all the dataset modules 23 | _dataset_modules = [importlib.import_module(f"codeformer.basicsr.data.{file_name}") for file_name in dataset_filenames] 24 | 25 | 26 | def build_dataset(dataset_opt): 27 | """Build dataset from options. 28 | 29 | Args: 30 | dataset_opt (dict): Configuration for dataset. It must constain: 31 | name (str): Dataset name. 32 | type (str): Dataset type. 33 | """ 34 | dataset_opt = deepcopy(dataset_opt) 35 | dataset = DATASET_REGISTRY.get(dataset_opt["type"])(dataset_opt) 36 | logger = get_root_logger() 37 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' "is built.") 38 | return dataset 39 | 40 | 41 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 42 | """Build dataloader. 43 | 44 | Args: 45 | dataset (torch.utils.data.Dataset): Dataset. 46 | dataset_opt (dict): Dataset options. It contains the following keys: 47 | phase (str): 'train' or 'val'. 48 | num_worker_per_gpu (int): Number of workers for each GPU. 49 | batch_size_per_gpu (int): Training batch size for each GPU. 50 | num_gpu (int): Number of GPUs. Used only in the train phase. 51 | Default: 1. 52 | dist (bool): Whether in distributed training. Used only in the train 53 | phase. Default: False. 54 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 55 | seed (int | None): Seed. 
Default: None 56 | """ 57 | phase = dataset_opt["phase"] 58 | rank, _ = get_dist_info() 59 | if phase == "train": 60 | if dist: # distributed training 61 | batch_size = dataset_opt["batch_size_per_gpu"] 62 | num_workers = dataset_opt["num_worker_per_gpu"] 63 | else: # non-distributed training 64 | multiplier = 1 if num_gpu == 0 else num_gpu 65 | batch_size = dataset_opt["batch_size_per_gpu"] * multiplier 66 | num_workers = dataset_opt["num_worker_per_gpu"] * multiplier 67 | dataloader_args = dict( 68 | dataset=dataset, 69 | batch_size=batch_size, 70 | shuffle=False, 71 | num_workers=num_workers, 72 | sampler=sampler, 73 | drop_last=True, 74 | ) 75 | if sampler is None: 76 | dataloader_args["shuffle"] = True 77 | dataloader_args["worker_init_fn"] = ( 78 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 79 | ) 80 | elif phase in ["val", "test"]: # validation 81 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 82 | else: 83 | raise ValueError(f"Wrong dataset phase: {phase}. " "Supported ones are 'train', 'val' and 'test'.") 84 | 85 | dataloader_args["pin_memory"] = dataset_opt.get("pin_memory", False) 86 | 87 | prefetch_mode = dataset_opt.get("prefetch_mode") 88 | if prefetch_mode == "cpu": # CPUPrefetcher 89 | num_prefetch_queue = dataset_opt.get("num_prefetch_queue", 1) 90 | logger = get_root_logger() 91 | logger.info(f"Use {prefetch_mode} prefetch dataloader: " f"num_prefetch_queue = {num_prefetch_queue}") 92 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 93 | else: 94 | # prefetch_mode=None: Normal dataloader 95 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 96 | return torch.utils.data.DataLoader(**dataloader_args) 97 | 98 | 99 | def worker_init_fn(worker_id, num_workers, rank, seed): 100 | # Set the worker seed to num_workers * rank + worker_id + seed 101 | worker_seed = num_workers * rank + worker_id + seed 102 | np.random.seed(worker_seed) 103 | random.seed(worker_seed) 104 | -------------------------------------------------------------------------------- /codeformer/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | class EnlargedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | Modified from torch.utils.data.distributed.DistributedSampler 11 | Support enlarging the dataset for iteration-based training, for saving 12 | time when restart the dataloader after each epoch 13 | 14 | Args: 15 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 16 | num_replicas (int | None): Number of processes participating in 17 | the training. It is usually the world_size. 18 | rank (int | None): Rank of the current process within num_replicas. 19 | ratio (int): Enlarging ratio. Default: 1. 
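        With ratio > 1 the shuffled indices wrap around the dataset, so one pass through the sampler yields roughly ratio x len(dataset) samples across all replicas without restarting the dataloader.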
20 | """ 21 | 22 | def __init__(self, dataset, num_replicas, rank, ratio=1): 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 28 | self.total_size = self.num_samples * self.num_replicas 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | g = torch.Generator() 33 | g.manual_seed(self.epoch) 34 | indices = torch.randperm(self.total_size, generator=g).tolist() 35 | 36 | dataset_size = len(self.dataset) 37 | indices = [v % dataset_size for v in indices] 38 | 39 | # subsample 40 | indices = indices[self.rank : self.total_size : self.num_replicas] 41 | assert len(indices) == self.num_samples 42 | 43 | return iter(indices) 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | def set_epoch(self, epoch): 49 | self.epoch = epoch 50 | -------------------------------------------------------------------------------- /codeformer/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class PrefetchGenerator(threading.Thread): 9 | """A general prefetch generator. 10 | 11 | Ref: 12 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 13 | 14 | Args: 15 | generator: Python generator. 16 | num_prefetch_queue (int): Number of prefetch queue. 17 | """ 18 | 19 | def __init__(self, generator, num_prefetch_queue): 20 | threading.Thread.__init__(self) 21 | self.queue = Queue.Queue(num_prefetch_queue) 22 | self.generator = generator 23 | self.daemon = True 24 | self.start() 25 | 26 | def run(self): 27 | for item in self.generator: 28 | self.queue.put(item) 29 | self.queue.put(None) 30 | 31 | def __next__(self): 32 | next_item = self.queue.get() 33 | if next_item is None: 34 | raise StopIteration 35 | return next_item 36 | 37 | def __iter__(self): 38 | return self 39 | 40 | 41 | class PrefetchDataLoader(DataLoader): 42 | """Prefetch version of dataloader. 43 | 44 | Ref: 45 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 46 | 47 | TODO: 48 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 49 | ddp. 50 | 51 | Args: 52 | num_prefetch_queue (int): Number of prefetch queue. 53 | kwargs (dict): Other arguments for dataloader. 54 | """ 55 | 56 | def __init__(self, num_prefetch_queue, **kwargs): 57 | self.num_prefetch_queue = num_prefetch_queue 58 | super(PrefetchDataLoader, self).__init__(**kwargs) 59 | 60 | def __iter__(self): 61 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 62 | 63 | 64 | class CPUPrefetcher: 65 | """CPU prefetcher. 66 | 67 | Args: 68 | loader: Dataloader. 69 | """ 70 | 71 | def __init__(self, loader): 72 | self.ori_loader = loader 73 | self.loader = iter(loader) 74 | 75 | def next(self): 76 | try: 77 | return next(self.loader) 78 | except StopIteration: 79 | return None 80 | 81 | def reset(self): 82 | self.loader = iter(self.ori_loader) 83 | 84 | 85 | class CUDAPrefetcher: 86 | """CUDA prefetcher. 87 | 88 | Ref: 89 | https://github.com/NVIDIA/apex/issues/304# 90 | 91 | It may consums more GPU memory. 92 | 93 | Args: 94 | loader: Dataloader. 95 | opt (dict): Options. 
96 | """ 97 | 98 | def __init__(self, loader, opt): 99 | self.ori_loader = loader 100 | self.loader = iter(loader) 101 | self.opt = opt 102 | self.stream = torch.cuda.Stream() 103 | self.device = torch.device("cuda" if opt["num_gpu"] != 0 else "cpu") 104 | self.preload() 105 | 106 | def preload(self): 107 | try: 108 | self.batch = next(self.loader) # self.batch is a dict 109 | except StopIteration: 110 | self.batch = None 111 | return None 112 | # put tensors to gpu 113 | with torch.cuda.stream(self.stream): 114 | for k, v in self.batch.items(): 115 | if torch.is_tensor(v): 116 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /codeformer/basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[: h - h_remainder, : w - w_remainder, ...] 21 | else: 22 | raise ValueError(f"Wrong img ndim: {img.ndim}.") 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 27 | """Paired random crop. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | h_lq, w_lq, _ = img_lqs[0].shape 53 | h_gt, w_gt, _ = img_gts[0].shape 54 | lq_patch_size = gt_patch_size // scale 55 | 56 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 57 | raise ValueError( 58 | f"Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ", f"multiplication of LQ ({h_lq}, {w_lq})." 59 | ) 60 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 61 | raise ValueError( 62 | f"LQ ({h_lq}, {w_lq}) is smaller than patch size " 63 | f"({lq_patch_size}, {lq_patch_size}). " 64 | f"Please remove {gt_path}." 65 | ) 66 | 67 | # randomly choose top and left coordinates for lq patch 68 | top = random.randint(0, h_lq - lq_patch_size) 69 | left = random.randint(0, w_lq - lq_patch_size) 70 | 71 | # crop lq patch 72 | img_lqs = [v[top : top + lq_patch_size, left : left + lq_patch_size, ...] 
for v in img_lqs] 73 | 74 | # crop corresponding gt patch 75 | top_gt, left_gt = int(top * scale), int(left * scale) 76 | img_gts = [v[top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size, ...] for v in img_gts] 77 | if len(img_gts) == 1: 78 | img_gts = img_gts[0] 79 | if len(img_lqs) == 1: 80 | img_lqs = img_lqs[0] 81 | return img_gts, img_lqs 82 | 83 | 84 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 85 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 86 | 87 | We use vertical flip and transpose for rotation implementation. 88 | All the images in the list use the same augmentation. 89 | 90 | Args: 91 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 92 | is an ndarray, it will be transformed to a list. 93 | hflip (bool): Horizontal flip. Default: True. 94 | rotation (bool): Ratotation. Default: True. 95 | flows (list[ndarray]: Flows to be augmented. If the input is an 96 | ndarray, it will be transformed to a list. 97 | Dimension is (h, w, 2). Default: None. 98 | return_status (bool): Return the status of flip and rotation. 99 | Default: False. 100 | 101 | Returns: 102 | list[ndarray] | ndarray: Augmented images and flows. If returned 103 | results only have one element, just return ndarray. 104 | 105 | """ 106 | hflip = hflip and random.random() < 0.5 107 | vflip = rotation and random.random() < 0.5 108 | rot90 = rotation and random.random() < 0.5 109 | 110 | def _augment(img): 111 | if hflip: # horizontal 112 | cv2.flip(img, 1, img) 113 | if vflip: # vertical 114 | cv2.flip(img, 0, img) 115 | if rot90: 116 | img = img.transpose(1, 0, 2) 117 | return img 118 | 119 | def _augment_flow(flow): 120 | if hflip: # horizontal 121 | cv2.flip(flow, 1, flow) 122 | flow[:, :, 0] *= -1 123 | if vflip: # vertical 124 | cv2.flip(flow, 0, flow) 125 | flow[:, :, 1] *= -1 126 | if rot90: 127 | flow = flow.transpose(1, 0, 2) 128 | flow = flow[:, :, [1, 0]] 129 | return flow 130 | 131 | if not isinstance(imgs, list): 132 | imgs = [imgs] 133 | imgs = [_augment(img) for img in imgs] 134 | if len(imgs) == 1: 135 | imgs = imgs[0] 136 | 137 | if flows is not None: 138 | if not isinstance(flows, list): 139 | flows = [flows] 140 | flows = [_augment_flow(flow) for flow in flows] 141 | if len(flows) == 1: 142 | flows = flows[0] 143 | return imgs, flows 144 | else: 145 | if return_status: 146 | return imgs, (hflip, vflip, rot90) 147 | else: 148 | return imgs 149 | 150 | 151 | def img_rotate(img, angle, center=None, scale=1.0): 152 | """Rotate image. 153 | 154 | Args: 155 | img (ndarray): Image to be rotated. 156 | angle (float): Rotation angle in degrees. Positive values mean 157 | counter-clockwise rotation. 158 | center (tuple[int]): Rotation center. If the center is None, 159 | initialize it as the center of the image. Default: None. 160 | scale (float): Isotropic scale factor. Default: 1.0. 
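    Returns:
        ndarray: Rotated image.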
161 | """ 162 | (h, w) = img.shape[:2] 163 | 164 | if center is None: 165 | center = (w // 2, h // 2) 166 | 167 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 168 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 169 | return rotated_img 170 | -------------------------------------------------------------------------------- /codeformer/basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from codeformer.basicsr.losses.losses import ( 4 | CharbonnierLoss, 5 | GANLoss, 6 | L1Loss, 7 | MSELoss, 8 | PerceptualLoss, 9 | WeightedTVLoss, 10 | g_path_regularize, 11 | gradient_penalty_loss, 12 | r1_penalty, 13 | ) 14 | from codeformer.basicsr.utils import get_root_logger 15 | from codeformer.basicsr.utils.registry import LOSS_REGISTRY 16 | 17 | __all__ = [ 18 | "L1Loss", 19 | "MSELoss", 20 | "CharbonnierLoss", 21 | "WeightedTVLoss", 22 | "PerceptualLoss", 23 | "GANLoss", 24 | "gradient_penalty_loss", 25 | "r1_penalty", 26 | "g_path_regularize", 27 | ] 28 | 29 | 30 | def build_loss(opt): 31 | """Build loss from options. 32 | 33 | Args: 34 | opt (dict): Configuration. It must constain: 35 | type (str): Model type. 36 | """ 37 | opt = deepcopy(opt) 38 | loss_type = opt.pop("type") 39 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 40 | logger = get_root_logger() 41 | logger.info(f"Loss [{loss.__class__.__name__}] is created.") 42 | return loss 43 | -------------------------------------------------------------------------------- /codeformer/basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction="mean"): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == "sum": 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == "mean": 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. 
The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction="mean", **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | -------------------------------------------------------------------------------- /codeformer/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from codeformer.basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim 4 | from codeformer.basicsr.utils.registry import METRIC_REGISTRY 5 | 6 | __all__ = ["calculate_psnr", "calculate_ssim"] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must constain: 14 | type (str): Model type. 15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop("type") 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /codeformer/basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from codeformer.basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order="HWC"): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ["HWC", "CHW"]: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == "CHW": 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 
40 | """ 41 | img = img.astype(np.float32) / 255.0 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255.0 46 | -------------------------------------------------------------------------------- /codeformer/basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel 5 | from codeformer.basicsr.utils.registry import METRIC_REGISTRY 6 | 7 | 8 | @METRIC_REGISTRY.register() 9 | def calculate_psnr(img1, img2, crop_border, input_order="HWC", test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img1 (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result. 25 | """ 26 | 27 | assert img1.shape == img2.shape, f"Image shapes are differnet: {img1.shape}, {img2.shape}." 28 | if input_order not in ["HWC", "CHW"]: 29 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"') 30 | img1 = reorder_image(img1, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img1 - img2) ** 2) 44 | if mse == 0: 45 | return float("inf") 46 | return 20.0 * np.log10(255.0 / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img1, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by func:`calculate_ssim`. 53 | 54 | Args: 55 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | 58 | Returns: 59 | float: ssim result. 60 | """ 61 | 62 | C1 = (0.01 * 255) ** 2 63 | C2 = (0.03 * 255) ** 2 64 | 65 | img1 = img1.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5) 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1 ** 2 73 | mu2_sq = mu2 ** 2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 80 | return ssim_map.mean() 81 | 82 | 83 | @METRIC_REGISTRY.register() 84 | def calculate_ssim(img1, img2, crop_border, input_order="HWC", test_y_channel=False): 85 | """Calculate SSIM (structural similarity). 
86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 92 | 93 | For three-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img1 (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img1.shape == img2.shape, f"Image shapes are differnet: {img1.shape}, {img2.shape}." 110 | if input_order not in ["HWC", "CHW"]: 111 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"') 112 | img1 = reorder_image(img1, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img1 = img1.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 120 | 121 | if test_y_channel: 122 | img1 = to_y_channel(img1) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img1.shape[2]): 127 | ssims.append(_ssim(img1[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /codeformer/basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from codeformer.basicsr.utils import get_root_logger, scandir 6 | from codeformer.basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ["build_model"] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith("_model.py")] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f"codeformer.basicsr.models.{file_name}") for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must constain: 24 | model_type (str): Model type. 
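    Returns:
        The model built from MODEL_REGISTRY according to model_type.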
25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt["model_type"])(opt) 28 | logger = get_root_logger() 29 | logger.info(f"Model [{model.__class__.__name__}] is created.") 30 | return model 31 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /codeformer/basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from codeformer.basicsr.ops.dcn.deform_conv import ( 2 | DeformConv, 3 | DeformConvPack, 4 | ModulatedDeformConv, 5 | ModulatedDeformConvPack, 6 | deform_conv, 7 | modulated_deform_conv, 8 | ) 9 | 10 | __all__ = [ 11 | "DeformConv", 12 | "DeformConvPack", 13 | "ModulatedDeformConv", 14 | "ModulatedDeformConvPack", 15 | "deform_conv", 16 | "modulated_deform_conv", 17 | ] 18 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define WITH_CUDA // always use cuda 11 | #ifdef WITH_CUDA 12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 13 | at::Tensor offset, at::Tensor output, 14 | at::Tensor columns, at::Tensor ones, int kW, 15 | int kH, int dW, int dH, int padW, int padH, 16 | int dilationW, int dilationH, int group, 17 | int deformable_group, int im2col_step); 18 | 19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 20 | at::Tensor gradOutput, at::Tensor gradInput, 21 | at::Tensor gradOffset, at::Tensor weight, 22 | at::Tensor columns, int kW, int kH, int dW, 23 | int dH, int padW, int padH, int dilationW, 24 | int dilationH, int group, 25 | int deformable_group, int im2col_step); 26 | 27 | int deform_conv_backward_parameters_cuda( 28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 29 | at::Tensor gradWeight, // at::Tensor gradBias, 30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 31 | int padW, int padH, int dilationW, int dilationH, int group, 32 | int deformable_group, float scale, int im2col_step); 33 | 34 | void modulated_deform_conv_cuda_forward( 35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 38 | const int pad_h, const int pad_w, const int dilation_h, 39 | const int dilation_w, const int group, const int deformable_group, 40 | const bool with_bias); 41 | 42 | void modulated_deform_conv_cuda_backward( 43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 44 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 48 | int pad_w, int dilation_h, int dilation_w, int group, 
int deformable_group, 49 | const bool with_bias); 50 | #endif 51 | 52 | int deform_conv_forward(at::Tensor input, at::Tensor weight, 53 | at::Tensor offset, at::Tensor output, 54 | at::Tensor columns, at::Tensor ones, int kW, 55 | int kH, int dW, int dH, int padW, int padH, 56 | int dilationW, int dilationH, int group, 57 | int deformable_group, int im2col_step) { 58 | if (input.device().is_cuda()) { 59 | #ifdef WITH_CUDA 60 | return deform_conv_forward_cuda(input, weight, offset, output, columns, 61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, 62 | deformable_group, im2col_step); 63 | #else 64 | AT_ERROR("deform conv is not compiled with GPU support"); 65 | #endif 66 | } 67 | AT_ERROR("deform conv is not implemented on CPU"); 68 | } 69 | 70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset, 71 | at::Tensor gradOutput, at::Tensor gradInput, 72 | at::Tensor gradOffset, at::Tensor weight, 73 | at::Tensor columns, int kW, int kH, int dW, 74 | int dH, int padW, int padH, int dilationW, 75 | int dilationH, int group, 76 | int deformable_group, int im2col_step) { 77 | if (input.device().is_cuda()) { 78 | #ifdef WITH_CUDA 79 | return deform_conv_backward_input_cuda(input, offset, gradOutput, 80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, 81 | dilationW, dilationH, group, deformable_group, im2col_step); 82 | #else 83 | AT_ERROR("deform conv is not compiled with GPU support"); 84 | #endif 85 | } 86 | AT_ERROR("deform conv is not implemented on CPU"); 87 | } 88 | 89 | int deform_conv_backward_parameters( 90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 91 | at::Tensor gradWeight, // at::Tensor gradBias, 92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 93 | int padW, int padH, int dilationW, int dilationH, int group, 94 | int deformable_group, float scale, int im2col_step) { 95 | if (input.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput, 98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, 99 | dilationH, group, deformable_group, scale, im2col_step); 100 | #else 101 | AT_ERROR("deform conv is not compiled with GPU support"); 102 | #endif 103 | } 104 | AT_ERROR("deform conv is not implemented on CPU"); 105 | } 106 | 107 | void modulated_deform_conv_forward( 108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 111 | const int pad_h, const int pad_w, const int dilation_h, 112 | const int dilation_w, const int group, const int deformable_group, 113 | const bool with_bias) { 114 | if (input.device().is_cuda()) { 115 | #ifdef WITH_CUDA 116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones, 117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h, 118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 119 | deformable_group, with_bias); 120 | #else 121 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 122 | #endif 123 | } 124 | AT_ERROR("modulated deform conv is not implemented on CPU"); 125 | } 126 | 127 | void modulated_deform_conv_backward( 128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 129 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 131 | at::Tensor grad_offset, 
at::Tensor grad_mask, at::Tensor grad_output, 132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 134 | const bool with_bias) { 135 | if (input.device().is_cuda()) { 136 | #ifdef WITH_CUDA 137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones, 138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, 139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, 140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, 141 | with_bias); 142 | #else 143 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 144 | #endif 145 | } 146 | AT_ERROR("modulated deform conv is not implemented on CPU"); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("deform_conv_forward", &deform_conv_forward, 152 | "deform forward"); 153 | m.def("deform_conv_backward_input", &deform_conv_backward_input, 154 | "deform_conv_backward_input"); 155 | m.def("deform_conv_backward_parameters", 156 | &deform_conv_backward_parameters, 157 | "deform_conv_backward_parameters"); 158 | m.def("modulated_deform_conv_forward", 159 | &modulated_deform_conv_forward, 160 | "modulated deform conv forward"); 161 | m.def("modulated_deform_conv_backward", 162 | &modulated_deform_conv_backward, 163 | "modulated deform conv backward"); 164 | } 165 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ["FusedLeakyReLU", "fused_leaky_relu"] 4 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | try: 8 | from . 
import fused_act_ext 9 | except ImportError: 10 | import os 11 | 12 | BASICSR_JIT = os.getenv("BASICSR_JIT") 13 | if BASICSR_JIT == "True": 14 | from torch.utils.cpp_extension import load 15 | 16 | module_path = os.path.dirname(__file__) 17 | fused_act_ext = load( 18 | "fused", 19 | sources=[ 20 | os.path.join(module_path, "src", "fused_bias_act.cpp"), 21 | os.path.join(module_path, "src", "fused_bias_act_kernel.cu"), 22 | ], 23 | ) 24 | 25 | 26 | class FusedLeakyReLUFunctionBackward(Function): 27 | @staticmethod 28 | def forward(ctx, grad_output, out, negative_slope, scale): 29 | ctx.save_for_backward(out) 30 | ctx.negative_slope = negative_slope 31 | ctx.scale = scale 32 | 33 | empty = grad_output.new_empty(0) 34 | 35 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 36 | 37 | dim = [0] 38 | 39 | if grad_input.ndim > 2: 40 | dim += list(range(2, grad_input.ndim)) 41 | 42 | grad_bias = grad_input.sum(dim).detach() 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | (out,) = ctx.saved_tensors 49 | gradgrad_out = fused_act_ext.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 61 | ctx.save_for_backward(out) 62 | ctx.negative_slope = negative_slope 63 | ctx.scale = scale 64 | 65 | return out 66 | 67 | @staticmethod 68 | def backward(ctx, grad_output): 69 | (out,) = ctx.saved_tensors 70 | 71 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 72 | 73 | return grad_input, grad_bias, None, None 74 | 75 | 76 | class FusedLeakyReLU(nn.Module): 77 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 78 | super().__init__() 79 | 80 | self.bias = nn.Parameter(torch.zeros(channel)) 81 | self.negative_slope = negative_slope 82 | self.scale = scale 83 | 84 | def forward(self, input): 85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 86 | 87 | 88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 90 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, 
scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ["upfirdn2d"] 4 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | 7 | try: 8 | from . 
import upfirdn2d_ext 9 | except ImportError: 10 | import os 11 | 12 | BASICSR_JIT = os.getenv("BASICSR_JIT") 13 | if BASICSR_JIT == "True": 14 | from torch.utils.cpp_extension import load 15 | 16 | module_path = os.path.dirname(__file__) 17 | upfirdn2d_ext = load( 18 | "upfirdn2d", 19 | sources=[ 20 | os.path.join(module_path, "src", "upfirdn2d.cpp"), 21 | os.path.join(module_path, "src", "upfirdn2d_kernel.cu"), 22 | ], 23 | ) 24 | 25 | 26 | class UpFirDn2dBackward(Function): 27 | @staticmethod 28 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 29 | 30 | up_x, up_y = up 31 | down_x, down_y = down 32 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 33 | 34 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 35 | 36 | grad_input = upfirdn2d_ext.upfirdn2d( 37 | grad_output, 38 | grad_kernel, 39 | down_x, 40 | down_y, 41 | up_x, 42 | up_y, 43 | g_pad_x0, 44 | g_pad_x1, 45 | g_pad_y0, 46 | g_pad_y1, 47 | ) 48 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 49 | 50 | ctx.save_for_backward(kernel) 51 | 52 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 53 | 54 | ctx.up_x = up_x 55 | ctx.up_y = up_y 56 | ctx.down_x = down_x 57 | ctx.down_y = down_y 58 | ctx.pad_x0 = pad_x0 59 | ctx.pad_x1 = pad_x1 60 | ctx.pad_y0 = pad_y0 61 | ctx.pad_y1 = pad_y1 62 | ctx.in_size = in_size 63 | ctx.out_size = out_size 64 | 65 | return grad_input 66 | 67 | @staticmethod 68 | def backward(ctx, gradgrad_input): 69 | (kernel,) = ctx.saved_tensors 70 | 71 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 72 | 73 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 74 | gradgrad_input, 75 | kernel, 76 | ctx.up_x, 77 | ctx.up_y, 78 | ctx.down_x, 79 | ctx.down_y, 80 | ctx.pad_x0, 81 | ctx.pad_x1, 82 | ctx.pad_y0, 83 | ctx.pad_y1, 84 | ) 85 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 86 | # ctx.out_size[1], ctx.in_size[3]) 87 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 88 | 89 | return gradgrad_out, None, None, None, None, None, None, None, None 90 | 91 | 92 | class UpFirDn2d(Function): 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = 
UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | if input.device.type == "cpu": 149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 150 | else: 151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 152 | 153 | return out 154 | 155 | 156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 157 | _, channel, in_h, in_w = input.shape 158 | input = input.reshape(-1, in_h, in_w, 1) 159 | 160 | _, in_h, in_w, minor = input.shape 161 | kernel_h, kernel_w = kernel.shape 162 | 163 | out = input.view(-1, in_h, 1, in_w, 1, minor) 164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 166 | 167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 168 | out = out[ 169 | :, 170 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 171 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 172 | :, 173 | ] 174 | 175 | out = out.permute(0, 3, 1, 2) 176 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 177 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 178 | out = F.conv2d(out, w) 179 | out = out.reshape( 180 | -1, 181 | minor, 182 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 183 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 184 | ) 185 | out = out.permute(0, 2, 3, 1) 186 | out = out[:, ::down_y, ::down_x, :] 187 | 188 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 189 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 190 | 191 | return out.view(-1, channel, out_h, out_w) 192 | -------------------------------------------------------------------------------- /codeformer/basicsr/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | import sys 6 | import time 7 | 8 | from setuptools import find_packages, setup 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | from utils.misc import gpu_is_available 11 | 12 | version_file = ".codeformer/basicsr/version.py" 13 | 14 | 15 | def readme(): 16 | with open("README.md", encoding="utf-8") as f: 17 | content = f.read() 18 | return content 19 | 20 | 21 | def get_git_hash(): 22 | def _minimal_ext_cmd(cmd): 23 | # construct minimal environment 24 | env = {} 25 | for k in ["SYSTEMROOT", "PATH", "HOME"]: 26 | v = os.environ.get(k) 27 | if v is not None: 28 | env[k] = v 29 | # LANGUAGE is used on win32 30 | env["LANGUAGE"] = "C" 31 | env["LANG"] = "C" 32 | env["LC_ALL"] = "C" 33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 34 | return out 35 | 36 | try: 37 | out = _minimal_ext_cmd(["git", "rev-parse", "HEAD"]) 38 | sha = out.strip().decode("ascii") 39 | except OSError: 40 | sha = "unknown" 41 | 42 | return sha 43 | 44 | 45 | def get_hash(): 46 | if os.path.exists(".git"): 47 | sha = get_git_hash()[:7] 48 | elif os.path.exists(version_file): 49 | try: 50 | from version import __version__ 51 | 52 | sha = __version__.split("+")[-1] 53 | except ImportError: 54 | raise ImportError("Unable to get git 
version") 55 | else: 56 | sha = "unknown" 57 | 58 | return sha 59 | 60 | 61 | def write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open("./basicsr/VERSION", "r") as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ", ".join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split(".")]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, "w") as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, "r") as f: 80 | exec(compile(f.read(), version_file, "exec")) 81 | return locals()["__version__"] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args = {"cxx": []} 89 | 90 | # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | if gpu_is_available or os.getenv("FORCE_CUDA", "0") == "1": 92 | define_macros += [("WITH_CUDA", None)] 93 | extension = CUDAExtension 94 | extra_compile_args["nvcc"] = [ 95 | "-D__CUDA_NO_HALF_OPERATORS__", 96 | "-D__CUDA_NO_HALF_CONVERSIONS__", 97 | "-D__CUDA_NO_HALF2_OPERATORS__", 98 | ] 99 | sources += sources_cuda 100 | else: 101 | print(f"Compiling {name} without CUDA") 102 | extension = CppExtension 103 | 104 | return extension( 105 | name=f"{module}.{name}", 106 | sources=[os.path.join(*module.split("."), p) for p in sources], 107 | define_macros=define_macros, 108 | extra_compile_args=extra_compile_args, 109 | ) 110 | 111 | 112 | def get_requirements(filename="requirements.txt"): 113 | with open(os.path.join(".", filename), "r") as f: 114 | requires = [line.replace("\n", "") for line in f.readlines()] 115 | return requires 116 | 117 | 118 | if __name__ == "__main__": 119 | if "--cuda_ext" in sys.argv: 120 | ext_modules = [ 121 | make_cuda_ext( 122 | name="deform_conv_ext", 123 | module="ops.dcn", 124 | sources=["src/deform_conv_ext.cpp"], 125 | sources_cuda=["src/deform_conv_cuda.cpp", "src/deform_conv_cuda_kernel.cu"], 126 | ), 127 | make_cuda_ext( 128 | name="fused_act_ext", 129 | module="ops.fused_act", 130 | sources=["src/fused_bias_act.cpp"], 131 | sources_cuda=["src/fused_bias_act_kernel.cu"], 132 | ), 133 | make_cuda_ext( 134 | name="upfirdn2d_ext", 135 | module="ops.upfirdn2d", 136 | sources=["src/upfirdn2d.cpp"], 137 | sources_cuda=["src/upfirdn2d_kernel.cu"], 138 | ), 139 | ] 140 | sys.argv.remove("--cuda_ext") 141 | else: 142 | ext_modules = [] 143 | 144 | write_version_py() 145 | setup( 146 | name="basicsr", 147 | version=get_version(), 148 | description="Open Source Image and Video Super-Resolution Toolbox", 149 | long_description=readme(), 150 | long_description_content_type="text/markdown", 151 | author="Xintao Wang", 152 | author_email="xintao.wang@outlook.com", 153 | keywords="computer vision, restoration, super resolution", 154 | url="https://github.com/xinntao/BasicSR", 155 | include_package_data=True, 156 | packages=find_packages(exclude=("options", "datasets", "experiments", "results", "tb_logger", "wandb")), 157 | classifiers=[ 158 | "Development Status :: 4 - Beta", 159 | "License :: OSI Approved :: Apache Software License", 160 | "Operating System :: OS Independent", 161 | "Programming Language :: Python :: 3", 162 | "Programming Language :: Python :: 3.7", 163 | "Programming Language :: Python :: 3.8", 164 | ], 165 | 
license="Apache License 2.0", 166 | setup_requires=["cython", "numpy"], 167 | install_requires=get_requirements(), 168 | ext_modules=ext_modules, 169 | cmdclass={"build_ext": BuildExtension}, 170 | zip_safe=False, 171 | ) 172 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from codeformer.basicsr.utils.file_client import FileClient 2 | from codeformer.basicsr.utils.img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from codeformer.basicsr.utils.logger import ( 4 | MessageLogger, 5 | get_env_info, 6 | get_root_logger, 7 | init_tb_logger, 8 | init_wandb_logger, 9 | ) 10 | from codeformer.basicsr.utils.misc import * 11 | from codeformer.basicsr.utils.misc import ( 12 | check_resume, 13 | get_time_str, 14 | make_exp_dirs, 15 | mkdir_and_rename, 16 | scandir, 17 | set_random_seed, 18 | sizeof_fmt, 19 | ) 20 | from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer 21 | 22 | __all__ = [ 23 | # file_client.py 24 | "FileClient", 25 | # img_util.py 26 | "img2tensor", 27 | "tensor2img", 28 | "imfrombytes", 29 | "imwrite", 30 | "crop_border", 31 | # logger.py 32 | "MessageLogger", 33 | "init_tb_logger", 34 | "init_wandb_logger", 35 | "get_root_logger", 36 | "get_env_info", 37 | # misc.py 38 | "set_random_seed", 39 | "get_time_str", 40 | "mkdir_and_rename", 41 | "make_exp_dirs", 42 | "scandir", 43 | "check_resume", 44 | "sizeof_fmt", 45 | ] 46 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | 10 | 11 | def init_dist(launcher, backend="nccl", **kwargs): 12 | if mp.get_start_method(allow_none=True) is None: 13 | mp.set_start_method("spawn") 14 | if launcher == "pytorch": 15 | _init_dist_pytorch(backend, **kwargs) 16 | elif launcher == "slurm": 17 | _init_dist_slurm(backend, **kwargs) 18 | else: 19 | raise ValueError(f"Invalid launcher type: {launcher}") 20 | 21 | 22 | def _init_dist_pytorch(backend, **kwargs): 23 | rank = int(os.environ["RANK"]) 24 | num_gpus = torch.cuda.device_count() 25 | torch.cuda.set_device(rank % num_gpus) 26 | dist.init_process_group(backend=backend, **kwargs) 27 | 28 | 29 | def _init_dist_slurm(backend, port=None): 30 | """Initialize slurm distributed training environment. 31 | 32 | If argument ``port`` is not specified, then the master port will be system 33 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 34 | environment variable, then a default port ``29500`` will be used. 35 | 36 | Args: 37 | backend (str): Backend of torch.distributed. 38 | port (int, optional): Master port. Defaults to None. 
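        Example (illustrative; this only works inside a SLURM job step where
        ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST`` are set by
        the scheduler, and it is normally reached through ``init_dist``):

            init_dist("slurm", backend="nccl", port=29500)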
39 | """ 40 | proc_id = int(os.environ["SLURM_PROCID"]) 41 | ntasks = int(os.environ["SLURM_NTASKS"]) 42 | node_list = os.environ["SLURM_NODELIST"] 43 | num_gpus = torch.cuda.device_count() 44 | torch.cuda.set_device(proc_id % num_gpus) 45 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 46 | # specify master port 47 | if port is not None: 48 | os.environ["MASTER_PORT"] = str(port) 49 | elif "MASTER_PORT" in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ["MASTER_PORT"] = "29500" 54 | os.environ["MASTER_ADDR"] = addr 55 | os.environ["WORLD_SIZE"] = str(ntasks) 56 | os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) 57 | os.environ["RANK"] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from urllib.parse import urlparse 4 | 5 | import requests 6 | from torch.hub import download_url_to_file, get_dir 7 | from tqdm import tqdm 8 | 9 | from codeformer.basicsr.utils.misc import sizeof_fmt 10 | 11 | 12 | def download_file_from_google_drive(file_id, save_path): 13 | """Download files from google drive. 14 | Ref: 15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 
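        Example (illustrative; the file id and destination path below are
        hypothetical placeholders, not real assets of this repository):

            download_file_from_google_drive("1AbCdEfGhIjKlMnOp", "weights/codeformer.pth")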
19 | """ 20 | 21 | session = requests.Session() 22 | URL = "https://docs.google.com/uc?export=download" 23 | params = {"id": file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params["confirm"] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={"Range": "bytes=0-2"}) 33 | print(response_file_size) 34 | if "Content-Range" in response_file_size.headers: 35 | file_size = int(response_file_size.headers["Content-Range"].split("/")[1]) 36 | else: 37 | file_size = None 38 | 39 | save_response_content(response, save_path, file_size) 40 | 41 | 42 | def get_confirm_token(response): 43 | for key, value in response.cookies.items(): 44 | if key.startswith("download_warning"): 45 | return value 46 | return None 47 | 48 | 49 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 50 | if file_size is not None: 51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit="chunk") 52 | 53 | readable_file_size = sizeof_fmt(file_size) 54 | else: 55 | pbar = None 56 | 57 | with open(destination, "wb") as f: 58 | downloaded_size = 0 59 | for chunk in response.iter_content(chunk_size): 60 | downloaded_size += chunk_size 61 | if pbar is not None: 62 | pbar.update(1) 63 | pbar.set_description(f"Download {sizeof_fmt(downloaded_size)} / {readable_file_size}") 64 | if chunk: # filter out keep-alive new chunks 65 | f.write(chunk) 66 | if pbar is not None: 67 | pbar.close() 68 | 69 | 70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 71 | """Load file form http url, will download models if necessary. 72 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | Args: 74 | url (str): URL to be downloaded. 75 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 76 | Default: None. 77 | progress (bool): Whether to show the download progress. Default: True. 78 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 79 | Returns: 80 | str: The path to the downloaded file. 81 | """ 82 | if model_dir is None: # use the pytorch hub_dir 83 | hub_dir = get_dir() 84 | model_dir = os.path.join(hub_dir, "checkpoints") 85 | 86 | os.makedirs(model_dir, exist_ok=True) 87 | 88 | parts = urlparse(url) 89 | filename = os.path.basename(parts.path) 90 | if file_name is not None: 91 | filename = file_name 92 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 93 | if not os.path.exists(cached_file): 94 | print(f'Downloading: "{url}" to {cached_file}\n') 95 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 96 | return cached_file 97 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | 36 | sys.path.append(sys_path) 37 | try: 38 | import mc 39 | except ImportError: 40 | raise ImportError("Please install memcached to enable MemcachedBackend.") 41 | 42 | self.server_list_cfg = server_list_cfg 43 | self.client_cfg = client_cfg 44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 45 | # mc.pyvector servers as a point which points to a memory cache 46 | self._mc_buffer = mc.pyvector() 47 | 48 | def get(self, filepath): 49 | filepath = str(filepath) 50 | import mc 51 | 52 | self._client.Get(filepath, self._mc_buffer) 53 | value_buf = mc.ConvertBuffer(self._mc_buffer) 54 | return value_buf 55 | 56 | def get_text(self, filepath): 57 | raise NotImplementedError 58 | 59 | 60 | class HardDiskBackend(BaseStorageBackend): 61 | """Raw hard disks storage backend.""" 62 | 63 | def get(self, filepath): 64 | filepath = str(filepath) 65 | with open(filepath, "rb") as f: 66 | value_buf = f.read() 67 | return value_buf 68 | 69 | def get_text(self, filepath): 70 | filepath = str(filepath) 71 | with open(filepath, "r") as f: 72 | value_buf = f.read() 73 | return value_buf 74 | 75 | 76 | class LmdbBackend(BaseStorageBackend): 77 | """Lmdb storage backend. 78 | 79 | Args: 80 | db_paths (str | list[str]): Lmdb database paths. 81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 82 | readonly (bool, optional): Lmdb environment parameter. If True, 83 | disallow any write operations. Default: True. 84 | lock (bool, optional): Lmdb environment parameter. If False, when 85 | concurrent access occurs, do not lock the database. Default: False. 86 | readahead (bool, optional): Lmdb environment parameter. If False, 87 | disable the OS filesystem readahead mechanism, which may improve 88 | random read performance when a database is larger than RAM. 89 | Default: False. 90 | 91 | Attributes: 92 | db_paths (list): Lmdb database path. 93 | _client (list): A list of several lmdb envs. 94 | """ 95 | 96 | def __init__(self, db_paths, client_keys="default", readonly=True, lock=False, readahead=False, **kwargs): 97 | try: 98 | import lmdb 99 | except ImportError: 100 | raise ImportError("Please install lmdb to enable LmdbBackend.") 101 | 102 | if isinstance(client_keys, str): 103 | client_keys = [client_keys] 104 | 105 | if isinstance(db_paths, list): 106 | self.db_paths = [str(v) for v in db_paths] 107 | elif isinstance(db_paths, str): 108 | self.db_paths = [str(db_paths)] 109 | assert len(client_keys) == len(self.db_paths), ( 110 | "client_keys and db_paths should have the same length, " 111 | f"but received {len(client_keys)} and {len(self.db_paths)}." 
112 | ) 113 | 114 | self._client = {} 115 | for client, path in zip(client_keys, self.db_paths): 116 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 117 | 118 | def get(self, filepath, client_key): 119 | """Get values according to the filepath from one lmdb named client_key. 120 | 121 | Args: 122 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 123 | client_key (str): Used for distinguishing differnet lmdb envs. 124 | """ 125 | filepath = str(filepath) 126 | assert client_key in self._client, f"client_key {client_key} is not " "in lmdb clients." 127 | client = self._client[client_key] 128 | with client.begin(write=False) as txn: 129 | value_buf = txn.get(filepath.encode("ascii")) 130 | return value_buf 131 | 132 | def get_text(self, filepath): 133 | raise NotImplementedError 134 | 135 | 136 | class FileClient(object): 137 | """A general file client to access files in different backend. 138 | 139 | The client loads a file or text in a specified backend from its path 140 | and return it as a binary file. it can also register other backend 141 | accessor with a given name and backend class. 142 | 143 | Attributes: 144 | backend (str): The storage backend type. Options are "disk", 145 | "memcached" and "lmdb". 146 | client (:obj:`BaseStorageBackend`): The backend object. 147 | """ 148 | 149 | _backends = { 150 | "disk": HardDiskBackend, 151 | "memcached": MemcachedBackend, 152 | "lmdb": LmdbBackend, 153 | } 154 | 155 | def __init__(self, backend="disk", **kwargs): 156 | if backend not in self._backends: 157 | raise ValueError( 158 | f"Backend {backend} is not supported. Currently supported ones" f" are {list(self._backends.keys())}" 159 | ) 160 | self.backend = backend 161 | self.client = self._backends[backend](**kwargs) 162 | 163 | def get(self, filepath, client_key="default"): 164 | # client_key is used only for lmdb, where different fileclients have 165 | # different lmdb environments. 166 | if self.backend == "lmdb": 167 | return self.client.get(filepath, client_key) 168 | else: 169 | return self.client.get(filepath) 170 | 171 | def get_text(self, filepath): 172 | return self.client.get_text(filepath) 173 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torchvision.utils import make_grid 8 | 9 | 10 | def img2tensor(imgs, bgr2rgb=True, float32=True): 11 | """Numpy array to tensor. 12 | 13 | Args: 14 | imgs (list[ndarray] | ndarray): Input images. 15 | bgr2rgb (bool): Whether to change bgr to rgb. 16 | float32 (bool): Whether to change to float32. 17 | 18 | Returns: 19 | list[tensor] | tensor: Tensor images. If returned results only have 20 | one element, just return tensor. 
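        Example (illustrative sketch; ``face.png`` is a hypothetical input file):

            img = cv2.imread("face.png").astype(np.float32) / 255.0  # HWC, BGR, [0, 1]
            tensor = img2tensor(img, bgr2rgb=True, float32=True)     # CHW, RGB, float32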
21 | """ 22 | 23 | def _totensor(img, bgr2rgb, float32): 24 | if img.shape[2] == 3 and bgr2rgb: 25 | if img.dtype == "float64": 26 | img = img.astype("float32") 27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 28 | img = torch.from_numpy(img.transpose(2, 0, 1)) 29 | if float32: 30 | img = img.float() 31 | return img 32 | 33 | if isinstance(imgs, list): 34 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 35 | else: 36 | return _totensor(imgs, bgr2rgb, float32) 37 | 38 | 39 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 40 | """Convert torch Tensors into image numpy arrays. 41 | 42 | After clamping to [min, max], values will be normalized to [0, 1]. 43 | 44 | Args: 45 | tensor (Tensor or list[Tensor]): Accept shapes: 46 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 47 | 2) 3D Tensor of shape (3/1 x H x W); 48 | 3) 2D Tensor of shape (H x W). 49 | Tensor channel should be in RGB order. 50 | rgb2bgr (bool): Whether to change rgb to bgr. 51 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 52 | to uint8 type with range [0, 255]; otherwise, float type with 53 | range [0, 1]. Default: ``np.uint8``. 54 | min_max (tuple[int]): min and max values for clamp. 55 | 56 | Returns: 57 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 58 | shape (H x W). The channel order is BGR. 59 | """ 60 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 61 | raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") 62 | 63 | if torch.is_tensor(tensor): 64 | tensor = [tensor] 65 | result = [] 66 | for _tensor in tensor: 67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 69 | 70 | n_dim = _tensor.dim() 71 | if n_dim == 4: 72 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 73 | img_np = img_np.transpose(1, 2, 0) 74 | if rgb2bgr: 75 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 76 | elif n_dim == 3: 77 | img_np = _tensor.numpy() 78 | img_np = img_np.transpose(1, 2, 0) 79 | if img_np.shape[2] == 1: # gray image 80 | img_np = np.squeeze(img_np, axis=2) 81 | else: 82 | if rgb2bgr: 83 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 84 | elif n_dim == 2: 85 | img_np = _tensor.numpy() 86 | else: 87 | raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}") 88 | if out_type == np.uint8: 89 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 90 | img_np = (img_np * 255.0).round() 91 | img_np = img_np.astype(out_type) 92 | result.append(img_np) 93 | if len(result) == 1: 94 | result = result[0] 95 | return result 96 | 97 | 98 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 99 | """This implementation is slightly faster than tensor2img. 100 | It now only supports torch tensor with shape (1, c, h, w). 101 | 102 | Args: 103 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 104 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 105 | min_max (tuple[int]): min and max values for clamp. 
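        Example (illustrative; a random tensor stands in for a real network output):

            t = torch.rand(1, 3, 64, 64)   # (1, C, H, W), values in [0, 1]
            img = tensor2img_fast(t)       # (64, 64, 3) uint8 ndarray in BGR order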
106 | """ 107 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 108 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 109 | output = output.type(torch.uint8).cpu().numpy() 110 | if rgb2bgr: 111 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 112 | return output 113 | 114 | 115 | def imfrombytes(content, flag="color", float32=False): 116 | """Read an image from bytes. 117 | 118 | Args: 119 | content (bytes): Image bytes got from files or other streams. 120 | flag (str): Flags specifying the color type of a loaded image, 121 | candidates are `color`, `grayscale` and `unchanged`. 122 | float32 (bool): Whether to change to float32., If True, will also norm 123 | to [0, 1]. Default: False. 124 | 125 | Returns: 126 | ndarray: Loaded image array. 127 | """ 128 | img_np = np.frombuffer(content, np.uint8) 129 | imread_flags = {"color": cv2.IMREAD_COLOR, "grayscale": cv2.IMREAD_GRAYSCALE, "unchanged": cv2.IMREAD_UNCHANGED} 130 | img = cv2.imdecode(img_np, imread_flags[flag]) 131 | if float32: 132 | img = img.astype(np.float32) / 255.0 133 | return img 134 | 135 | 136 | def imwrite(img, file_path, params=None, auto_mkdir=True): 137 | """Write image to file. 138 | 139 | Args: 140 | img (ndarray): Image array to be written. 141 | file_path (str): Image file path. 142 | params (None or list): Same as opencv's :func:`imwrite` interface. 143 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 144 | whether to create it automatically. 145 | 146 | Returns: 147 | bool: Successful or not. 148 | """ 149 | if auto_mkdir: 150 | dir_name = os.path.abspath(os.path.dirname(file_path)) 151 | os.makedirs(dir_name, exist_ok=True) 152 | return cv2.imwrite(file_path, img, params) 153 | 154 | 155 | def crop_border(imgs, crop_border): 156 | """Crop borders of images. 157 | 158 | Args: 159 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 160 | crop_border (int): Crop border for each end of height and weight. 161 | 162 | Returns: 163 | list[ndarray]: Cropped images. 164 | """ 165 | if crop_border == 0: 166 | return imgs 167 | else: 168 | if isinstance(imgs, list): 169 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 170 | else: 171 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 172 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from multiprocessing import Pool 3 | from os import path as osp 4 | 5 | import cv2 6 | import lmdb 7 | from tqdm import tqdm 8 | 9 | 10 | def make_lmdb_from_imgs( 11 | data_path, 12 | lmdb_path, 13 | img_path_list, 14 | keys, 15 | batch=5000, 16 | compress_level=1, 17 | multiprocessing_read=False, 18 | n_thread=40, 19 | map_size=None, 20 | ): 21 | """Make lmdb from images. 22 | 23 | Contents of lmdb. The file structure is: 24 | example.lmdb 25 | ├── data.mdb 26 | ├── lock.mdb 27 | ├── meta_info.txt 28 | 29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 30 | https://lmdb.readthedocs.io/en/release/ for more details. 31 | 32 | The meta_info.txt is a specified txt file to record the meta information 33 | of our datasets. It will be automatically created when preparing 34 | datasets by our provided dataset tools. 35 | Each line in the txt file records 1)image name (with extension), 36 | 2)image shape, and 3)compression level, separated by a white space. 
37 | 38 | For example, the meta information could be: 39 | `000_00000000.png (720,1280,3) 1`, which means: 40 | 1) image name (with extension): 000_00000000.png; 41 | 2) image shape: (720,1280,3); 42 | 3) compression level: 1 43 | 44 | We use the image name without extension as the lmdb key. 45 | 46 | If `multiprocessing_read` is True, it will read all the images to memory 47 | using multiprocessing. Thus, your server needs to have enough memory. 48 | 49 | Args: 50 | data_path (str): Data path for reading images. 51 | lmdb_path (str): Lmdb save path. 52 | img_path_list (str): Image path list. 53 | keys (str): Used for lmdb keys. 54 | batch (int): After processing batch images, lmdb commits. 55 | Default: 5000. 56 | compress_level (int): Compress level when encoding images. Default: 1. 57 | multiprocessing_read (bool): Whether use multiprocessing to read all 58 | the images to memory. Default: False. 59 | n_thread (int): For multiprocessing. 60 | map_size (int | None): Map size for lmdb env. If None, use the 61 | estimated size from images. Default: None 62 | """ 63 | 64 | assert len(img_path_list) == len(keys), ( 65 | "img_path_list and keys should have the same length, " f"but got {len(img_path_list)} and {len(keys)}" 66 | ) 67 | print(f"Create lmdb for {data_path}, save to {lmdb_path}...") 68 | print(f"Totoal images: {len(img_path_list)}") 69 | if not lmdb_path.endswith(".lmdb"): 70 | raise ValueError("lmdb_path must end with '.lmdb'.") 71 | if osp.exists(lmdb_path): 72 | print(f"Folder {lmdb_path} already exists. Exit.") 73 | sys.exit(1) 74 | 75 | if multiprocessing_read: 76 | # read all the images to memory (multiprocessing) 77 | dataset = {} # use dict to keep the order for multiprocessing 78 | shapes = {} 79 | print(f"Read images with multiprocessing, #thread: {n_thread} ...") 80 | pbar = tqdm(total=len(img_path_list), unit="image") 81 | 82 | def callback(arg): 83 | """get the image data and update pbar.""" 84 | key, dataset[key], shapes[key] = arg 85 | pbar.update(1) 86 | pbar.set_description(f"Read {key}") 87 | 88 | pool = Pool(n_thread) 89 | for path, key in zip(img_path_list, keys): 90 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 91 | pool.close() 92 | pool.join() 93 | pbar.close() 94 | print(f"Finish reading {len(img_path_list)} images.") 95 | 96 | # create lmdb environment 97 | if map_size is None: 98 | # obtain data size for one image 99 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 100 | _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 101 | data_size_per_img = img_byte.nbytes 102 | print("Data size per image is: ", data_size_per_img) 103 | data_size = data_size_per_img * len(img_path_list) 104 | map_size = data_size * 10 105 | 106 | env = lmdb.open(lmdb_path, map_size=map_size) 107 | 108 | # write data to lmdb 109 | pbar = tqdm(total=len(img_path_list), unit="chunk") 110 | txn = env.begin(write=True) 111 | txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w") 112 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 113 | pbar.update(1) 114 | pbar.set_description(f"Write {key}") 115 | key_byte = key.encode("ascii") 116 | if multiprocessing_read: 117 | img_byte = dataset[key] 118 | h, w, c = shapes[key] 119 | else: 120 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 121 | h, w, c = img_shape 122 | 123 | txn.put(key_byte, img_byte) 124 | # write meta information 125 | 
txt_file.write(f"{key}.png ({h},{w},{c}) {compress_level}\n") 126 | if idx % batch == 0: 127 | txn.commit() 128 | txn = env.begin(write=True) 129 | pbar.close() 130 | txn.commit() 131 | env.close() 132 | txt_file.close() 133 | print("\nFinish writing lmdb.") 134 | 135 | 136 | def read_img_worker(path, key, compress_level): 137 | """Read image worker. 138 | 139 | Args: 140 | path (str): Image path. 141 | key (str): Image key. 142 | compress_level (int): Compress level when encoding images. 143 | 144 | Returns: 145 | str: Image key. 146 | byte: Image byte. 147 | tuple[int]: Image shape. 148 | """ 149 | 150 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 151 | if img.ndim == 2: 152 | h, w = img.shape 153 | c = 1 154 | else: 155 | h, w, c = img.shape 156 | _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 157 | return (key, img_byte, (h, w, c)) 158 | 159 | 160 | class LmdbMaker: 161 | """LMDB Maker. 162 | 163 | Args: 164 | lmdb_path (str): Lmdb save path. 165 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 166 | batch (int): After processing batch images, lmdb commits. 167 | Default: 5000. 168 | compress_level (int): Compress level when encoding images. Default: 1. 169 | """ 170 | 171 | def __init__(self, lmdb_path, map_size=1024 ** 4, batch=5000, compress_level=1): 172 | if not lmdb_path.endswith(".lmdb"): 173 | raise ValueError("lmdb_path must end with '.lmdb'.") 174 | if osp.exists(lmdb_path): 175 | print(f"Folder {lmdb_path} already exists. Exit.") 176 | sys.exit(1) 177 | 178 | self.lmdb_path = lmdb_path 179 | self.batch = batch 180 | self.compress_level = compress_level 181 | self.env = lmdb.open(lmdb_path, map_size=map_size) 182 | self.txn = self.env.begin(write=True) 183 | self.txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w") 184 | self.counter = 0 185 | 186 | def put(self, img_byte, key, img_shape): 187 | self.counter += 1 188 | key_byte = key.encode("ascii") 189 | self.txn.put(key_byte, img_byte) 190 | # write meta information 191 | h, w, c = img_shape 192 | self.txt_file.write(f"{key}.png ({h},{w},{c}) {self.compress_level}\n") 193 | if self.counter % self.batch == 0: 194 | self.txn.commit() 195 | self.txn = self.env.begin(write=True) 196 | 197 | def close(self): 198 | self.txn.commit() 199 | self.env.close() 200 | self.txt_file.close() 201 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from codeformer.basicsr.utils.dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger: 11 | """Message logger for printing. 12 | Args: 13 | opt (dict): Config. It contains the following keys: 14 | name (str): Exp name. 15 | logger (dict): Contains 'print_freq' (str) for logger interval. 16 | train (dict): Contains 'total_iter' (int) for total iters. 17 | use_tb_logger (bool): Use tensorboard logger. 18 | start_iter (int): Start iter. Default: 1. 19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 
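        Example (illustrative config; the values below are placeholders, not
        defaults taken from this repository):

            opt = {
                "name": "codeformer_exp",
                "logger": {"print_freq": 100, "use_tb_logger": False},
                "train": {"total_iter": 100000},
            }
            msg_logger = MessageLogger(opt)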
20 | """ 21 | 22 | def __init__(self, opt, start_iter=1, tb_logger=None): 23 | self.exp_name = opt["name"] 24 | self.interval = opt["logger"]["print_freq"] 25 | self.start_iter = start_iter 26 | self.max_iters = opt["train"]["total_iter"] 27 | self.use_tb_logger = opt["logger"]["use_tb_logger"] 28 | self.tb_logger = tb_logger 29 | self.start_time = time.time() 30 | self.logger = get_root_logger() 31 | 32 | @master_only 33 | def __call__(self, log_vars): 34 | """Format logging message. 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | time (float): Iter time. 41 | data_time (float): Data time for each iter. 42 | """ 43 | # epoch, iter, learning rates 44 | epoch = log_vars.pop("epoch") 45 | current_iter = log_vars.pop("iter") 46 | lrs = log_vars.pop("lrs") 47 | 48 | message = f"[{self.exp_name[:5]}..][epoch:{epoch:3d}, " f"iter:{current_iter:8,d}, lr:(" 49 | for v in lrs: 50 | message += f"{v:.3e}," 51 | message += ")] " 52 | 53 | # time and estimated time 54 | if "time" in log_vars.keys(): 55 | iter_time = log_vars.pop("time") 56 | data_time = log_vars.pop("data_time") 57 | 58 | total_time = time.time() - self.start_time 59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | message += f"[eta: {eta_str}, " 63 | message += f"time (data): {iter_time:.3f} ({data_time:.3f})] " 64 | 65 | # other items, especially losses 66 | for k, v in log_vars.items(): 67 | message += f"{k}: {v:.4e} " 68 | # tensorboard logger 69 | if self.use_tb_logger: 70 | if k.startswith("l_"): 71 | self.tb_logger.add_scalar(f"losses/{k}", v, current_iter) 72 | else: 73 | self.tb_logger.add_scalar(k, v, current_iter) 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | 81 | tb_logger = SummaryWriter(log_dir=log_dir) 82 | return tb_logger 83 | 84 | 85 | @master_only 86 | def init_wandb_logger(opt): 87 | """We now only use wandb to sync tensorboard log.""" 88 | import wandb 89 | 90 | logger = logging.getLogger("basicsr") 91 | 92 | project = opt["logger"]["wandb"]["project"] 93 | resume_id = opt["logger"]["wandb"].get("resume_id") 94 | if resume_id: 95 | wandb_id = resume_id 96 | resume = "allow" 97 | logger.warning(f"Resume wandb logger with id={wandb_id}.") 98 | else: 99 | wandb_id = wandb.util.generate_id() 100 | resume = "never" 101 | 102 | wandb.init(id=wandb_id, resume=resume, name=opt["name"], config=opt, project=project, sync_tensorboard=True) 103 | 104 | logger.info(f"Use wandb logger with id={wandb_id}; project={project}.") 105 | 106 | 107 | def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None): 108 | """Get the root logger. 109 | The logger will be initialized if it has not been initialized. By default a 110 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 111 | also be added. 112 | Args: 113 | logger_name (str): root logger name. Default: 'basicsr'. 114 | log_file (str | None): The log filename. If specified, a FileHandler 115 | will be added to the root logger. 116 | log_level (int): The root logger level. Note that only the process of 117 | rank 0 is affected, while other processes will set the level to 118 | "Error" and be silent most of the time. 
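    Example (illustrative; the log file path is hypothetical):

        logger = get_root_logger(log_file="experiments/train.log")
        logger.info("goes to the console and, on rank 0, also to the file")
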
119 | Returns: 120 | logging.Logger: The root logger. 121 | """ 122 | logger = logging.getLogger(logger_name) 123 | # if the logger has been initialized, just return it 124 | if logger_name in initialized_logger: 125 | return logger 126 | 127 | format_str = "%(asctime)s %(levelname)s: %(message)s" 128 | stream_handler = logging.StreamHandler() 129 | stream_handler.setFormatter(logging.Formatter(format_str)) 130 | logger.addHandler(stream_handler) 131 | logger.propagate = False 132 | rank, _ = get_dist_info() 133 | if rank != 0: 134 | logger.setLevel("ERROR") 135 | elif log_file is not None: 136 | logger.setLevel(log_level) 137 | # add file handler 138 | # file_handler = logging.FileHandler(log_file, 'w') 139 | file_handler = logging.FileHandler(log_file, "a") # Shangchen: keep the previous log 140 | file_handler.setFormatter(logging.Formatter(format_str)) 141 | file_handler.setLevel(log_level) 142 | logger.addHandler(file_handler) 143 | initialized_logger[logger_name] = True 144 | return logger 145 | 146 | 147 | def get_env_info(): 148 | """Get environment information. 149 | Currently, only log the software version. 150 | """ 151 | import torch 152 | import torchvision 153 | from basicsr.version import __version__ 154 | 155 | msg = r""" 156 | ____ _ _____ ____ 157 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 158 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 159 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 160 | /_____/ \__,_//____//_/ \___//____//_/ |_| 161 | ______ __ __ __ __ 162 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 163 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 164 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 165 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 166 | """ 167 | msg += ( 168 | "\nVersion Information: " 169 | f"\n\tBasicSR: {__version__}" 170 | f"\n\tPyTorch: {torch.__version__}" 171 | f"\n\tTorchVision: {torchvision.__version__}" 172 | ) 173 | return msg 174 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | import time 5 | from os import path as osp 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from codeformer.basicsr.utils.dist_util import master_only 11 | from codeformer.basicsr.utils.logger import get_root_logger 12 | 13 | IS_HIGH_VERSION = [ 14 | int(m) 15 | for m in list( 16 | re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$", torch.__version__)[0][:3] 17 | ) 18 | ] >= [1, 12, 0] 19 | 20 | 21 | def gpu_is_available(): 22 | if IS_HIGH_VERSION: 23 | if torch.backends.mps.is_available(): 24 | return True 25 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 26 | 27 | 28 | def get_device(gpu_id=None): 29 | if gpu_id is None: 30 | gpu_str = "" 31 | elif isinstance(gpu_id, int): 32 | gpu_str = f":{gpu_id}" 33 | else: 34 | raise TypeError("Input should be int value.") 35 | 36 | if IS_HIGH_VERSION: 37 | if torch.backends.mps.is_available(): 38 | return torch.device("mps" + gpu_str) 39 | return torch.device( 40 | "cuda" + gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else "cpu" 41 | ) 42 | 43 | 44 | def set_random_seed(seed): 45 | """Set random seeds.""" 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | torch.cuda.manual_seed(seed) 50 | torch.cuda.manual_seed_all(seed) 51 | 52 | 53 | def 
get_time_str(): 54 | return time.strftime("%Y%m%d_%H%M%S", time.localtime()) 55 | 56 | 57 | def mkdir_and_rename(path): 58 | """mkdirs. If path exists, rename it with timestamp and create a new one. 59 | 60 | Args: 61 | path (str): Folder path. 62 | """ 63 | if osp.exists(path): 64 | new_name = path + "_archived_" + get_time_str() 65 | print(f"Path already exists. Rename it to {new_name}", flush=True) 66 | os.rename(path, new_name) 67 | os.makedirs(path, exist_ok=True) 68 | 69 | 70 | @master_only 71 | def make_exp_dirs(opt): 72 | """Make dirs for experiments.""" 73 | path_opt = opt["path"].copy() 74 | if opt["is_train"]: 75 | mkdir_and_rename(path_opt.pop("experiments_root")) 76 | else: 77 | mkdir_and_rename(path_opt.pop("results_root")) 78 | for key, path in path_opt.items(): 79 | if ("strict_load" not in key) and ("pretrain_network" not in key) and ("resume" not in key): 80 | os.makedirs(path, exist_ok=True) 81 | 82 | 83 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 84 | """Scan a directory to find the interested files. 85 | 86 | Args: 87 | dir_path (str): Path of the directory. 88 | suffix (str | tuple(str), optional): File suffix that we are 89 | interested in. Default: None. 90 | recursive (bool, optional): If set to True, recursively scan the 91 | directory. Default: False. 92 | full_path (bool, optional): If set to True, include the dir_path. 93 | Default: False. 94 | 95 | Returns: 96 | A generator for all the interested files with relative pathes. 97 | """ 98 | 99 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 100 | raise TypeError('"suffix" must be a string or tuple of strings') 101 | 102 | root = dir_path 103 | 104 | def _scandir(dir_path, suffix, recursive): 105 | for entry in os.scandir(dir_path): 106 | if not entry.name.startswith(".") and entry.is_file(): 107 | if full_path: 108 | return_path = entry.path 109 | else: 110 | return_path = osp.relpath(entry.path, root) 111 | 112 | if suffix is None: 113 | yield return_path 114 | elif return_path.endswith(suffix): 115 | yield return_path 116 | else: 117 | if recursive: 118 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 119 | else: 120 | continue 121 | 122 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 123 | 124 | 125 | def check_resume(opt, resume_iter): 126 | """Check resume states and pretrain_network paths. 127 | 128 | Args: 129 | opt (dict): Options. 130 | resume_iter (int): Resume iteration. 131 | """ 132 | logger = get_root_logger() 133 | if opt["path"]["resume_state"]: 134 | # get all the networks 135 | networks = [key for key in opt.keys() if key.startswith("network_")] 136 | flag_pretrain = False 137 | for network in networks: 138 | if opt["path"].get(f"pretrain_{network}") is not None: 139 | flag_pretrain = True 140 | if flag_pretrain: 141 | logger.warning("pretrain_network path will be ignored during resuming.") 142 | # set pretrained model paths 143 | for network in networks: 144 | name = f"pretrain_{network}" 145 | basename = network.replace("network_", "") 146 | if opt["path"].get("ignore_resume_networks") is None or ( 147 | basename not in opt["path"]["ignore_resume_networks"] 148 | ): 149 | opt["path"][name] = osp.join(opt["path"]["models"], f"net_{basename}_{resume_iter}.pth") 150 | logger.info(f"Set {name} to {opt['path'][name]}") 151 | 152 | 153 | def sizeof_fmt(size, suffix="B"): 154 | """Get human readable file size. 155 | 156 | Args: 157 | size (int): File size. 158 | suffix (str): Suffix. Default: 'B'. 
159 | 160 | Return: 161 | str: Formatted file size. 162 | """ 163 | for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: 164 | if abs(size) < 1024.0: 165 | return f"{size:3.1f} {unit}{suffix}" 166 | size /= 1024.0 167 | return f"{size:3.1f} Y{suffix}" 168 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | import yaml 6 | 7 | from codeformer.basicsr.utils.misc import get_time_str 8 | 9 | 10 | def ordered_yaml(): 11 | """Support OrderedDict for yaml. 12 | 13 | Returns: 14 | yaml Loader and Dumper. 15 | """ 16 | try: 17 | from yaml import CDumper as Dumper 18 | from yaml import CLoader as Loader 19 | except ImportError: 20 | from yaml import Dumper, Loader 21 | 22 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 23 | 24 | def dict_representer(dumper, data): 25 | return dumper.represent_dict(data.items()) 26 | 27 | def dict_constructor(loader, node): 28 | return OrderedDict(loader.construct_pairs(node)) 29 | 30 | Dumper.add_representer(OrderedDict, dict_representer) 31 | Loader.add_constructor(_mapping_tag, dict_constructor) 32 | return Loader, Dumper 33 | 34 | 35 | def parse(opt_path, root_path, is_train=True): 36 | """Parse option file. 37 | 38 | Args: 39 | opt_path (str): Option file path. 40 | is_train (bool): Indicate whether in training or not. Default: True. 41 | 42 | Returns: 43 | (dict): Options. 44 | """ 45 | with open(opt_path, mode="r") as f: 46 | Loader, _ = ordered_yaml() 47 | opt = yaml.load(f, Loader=Loader) 48 | 49 | opt["is_train"] = is_train 50 | 51 | # opt['name'] = f"{get_time_str()}_{opt['name']}" 52 | if opt["path"].get("resume_state", None): # Shangchen added 53 | resume_state_path = opt["path"].get("resume_state") 54 | opt["name"] = resume_state_path.split("/")[-3] 55 | else: 56 | opt["name"] = f"{get_time_str()}_{opt['name']}" 57 | 58 | # datasets 59 | for phase, dataset in opt["datasets"].items(): 60 | # for several datasets, e.g., test_1, test_2 61 | phase = phase.split("_")[0] 62 | dataset["phase"] = phase 63 | if "scale" in opt: 64 | dataset["scale"] = opt["scale"] 65 | if dataset.get("dataroot_gt") is not None: 66 | dataset["dataroot_gt"] = osp.expanduser(dataset["dataroot_gt"]) 67 | if dataset.get("dataroot_lq") is not None: 68 | dataset["dataroot_lq"] = osp.expanduser(dataset["dataroot_lq"]) 69 | 70 | # paths 71 | for key, val in opt["path"].items(): 72 | if (val is not None) and ("resume_state" in key or "pretrain_network" in key): 73 | opt["path"][key] = osp.expanduser(val) 74 | 75 | if is_train: 76 | experiments_root = osp.join(root_path, "experiments", opt["name"]) 77 | opt["path"]["experiments_root"] = experiments_root 78 | opt["path"]["models"] = osp.join(experiments_root, "models") 79 | opt["path"]["training_states"] = osp.join(experiments_root, "training_states") 80 | opt["path"]["log"] = experiments_root 81 | opt["path"]["visualization"] = osp.join(experiments_root, "visualization") 82 | 83 | else: # test 84 | results_root = osp.join(root_path, "results", opt["name"]) 85 | opt["path"]["results_root"] = results_root 86 | opt["path"]["log"] = results_root 87 | opt["path"]["visualization"] = osp.join(results_root, "visualization") 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict.
97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = "\n" 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += " " * (indent_level * 2) + k + ":[" 106 | msg += dict2str(v, indent_level + 1) 107 | msg += " " * (indent_level * 2) + "]\n" 108 | else: 109 | msg += " " * (indent_level * 2) + k + ": " + str(v) + "\n" 110 | return msg 111 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry: 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert name not in self._obj_map, ( 40 | f"An object named '{name}' was already registered " f"in '{self._name}' registry!" 41 | ) 42 | self._obj_map[name] = obj 43 | 44 | def register(self, obj=None): 45 | """ 46 | Register the given object under the name `obj.__name__`. 47 | Can be used as either a decorator or not. 48 | See docstring of this class for usage.
49 | """ 50 | if obj is None: 51 | # used as a decorator 52 | def deco(func_or_class): 53 | name = func_or_class.__name__ 54 | self._do_register(name, func_or_class) 55 | return func_or_class 56 | 57 | return deco 58 | 59 | # used as a function call 60 | name = obj.__name__ 61 | self._do_register(name, obj) 62 | 63 | def get(self, name): 64 | ret = self._obj_map.get(name) 65 | if ret is None: 66 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 67 | return ret 68 | 69 | def __contains__(self, name): 70 | return name in self._obj_map 71 | 72 | def __iter__(self): 73 | return iter(self._obj_map.items()) 74 | 75 | def keys(self): 76 | return self._obj_map.keys() 77 | 78 | 79 | DATASET_REGISTRY = Registry("dataset") 80 | ARCH_REGISTRY = Registry("arch") 81 | MODEL_REGISTRY = Registry("model") 82 | LOSS_REGISTRY = Registry("loss") 83 | METRIC_REGISTRY = Registry("metric") 84 | -------------------------------------------------------------------------------- /codeformer/basicsr/utils/video_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is modified from the Real-ESRGAN: 3 | https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py 4 | 5 | """ 6 | import sys 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | try: 12 | import ffmpeg 13 | except ImportError: 14 | import pip 15 | 16 | pip.main(["install", "--user", "ffmpeg-python"]) 17 | import ffmpeg 18 | 19 | 20 | def get_video_meta_info(video_path): 21 | ret = {} 22 | probe = ffmpeg.probe(video_path) 23 | video_streams = [stream for stream in probe["streams"] if stream["codec_type"] == "video"] 24 | has_audio = any(stream["codec_type"] == "audio" for stream in probe["streams"]) 25 | ret["width"] = video_streams[0]["width"] 26 | ret["height"] = video_streams[0]["height"] 27 | ret["fps"] = eval(video_streams[0]["avg_frame_rate"]) 28 | ret["audio"] = ffmpeg.input(video_path).audio if has_audio else None 29 | ret["nb_frames"] = int(video_streams[0]["nb_frames"]) 30 | return ret 31 | 32 | 33 | class VideoReader: 34 | def __init__(self, video_path): 35 | self.paths = [] # for image&folder type 36 | self.audio = None 37 | try: 38 | self.stream_reader = ( 39 | ffmpeg.input(video_path) 40 | .output("pipe:", format="rawvideo", pix_fmt="bgr24", loglevel="error") 41 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg") 42 | ) 43 | except FileNotFoundError: 44 | print("Please install ffmpeg (not ffmpeg-python) by running\n", "\t$ conda install -c conda-forge ffmpeg") 45 | sys.exit(0) 46 | 47 | meta = get_video_meta_info(video_path) 48 | self.width = meta["width"] 49 | self.height = meta["height"] 50 | self.input_fps = meta["fps"] 51 | self.audio = meta["audio"] 52 | self.nb_frames = meta["nb_frames"] 53 | 54 | self.idx = 0 55 | 56 | def get_resolution(self): 57 | return self.height, self.width 58 | 59 | def get_fps(self): 60 | if self.input_fps is not None: 61 | return self.input_fps 62 | return 24 63 | 64 | def get_audio(self): 65 | return self.audio 66 | 67 | def __len__(self): 68 | return self.nb_frames 69 | 70 | def get_frame_from_stream(self): 71 | img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel 72 | if not img_bytes: 73 | return None 74 | img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) 75 | return img 76 | 77 | def get_frame_from_list(self): 78 | if self.idx >= self.nb_frames: 79 | return None 80 | img = cv2.imread(self.paths[self.idx]) 81 | self.idx += 
1 82 | return img 83 | 84 | def get_frame(self): 85 | return self.get_frame_from_stream() 86 | 87 | def close(self): 88 | self.stream_reader.stdin.close() 89 | self.stream_reader.wait() 90 | 91 | 92 | class VideoWriter: 93 | def __init__(self, video_save_path, height, width, fps, audio): 94 | if height > 2160: 95 | print( 96 | "You are generating video that is larger than 4K, which will be very slow due to IO speed.", 97 | "We highly recommend to decrease the outscale(aka, -s).", 98 | ) 99 | if audio is not None: 100 | self.stream_writer = ( 101 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="bgr24", s=f"{width}x{height}", framerate=fps) 102 | .output(audio, video_save_path, pix_fmt="yuv420p", vcodec="libx264", loglevel="error", acodec="copy") 103 | .overwrite_output() 104 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg") 105 | ) 106 | else: 107 | self.stream_writer = ( 108 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="bgr24", s=f"{width}x{height}", framerate=fps) 109 | .output(video_save_path, pix_fmt="yuv420p", vcodec="libx264", loglevel="error") 110 | .overwrite_output() 111 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg") 112 | ) 113 | 114 | def write_frame(self, frame): 115 | try: 116 | frame = frame.astype(np.uint8).tobytes() 117 | self.stream_writer.stdin.write(frame) 118 | except BrokenPipeError: 119 | print( 120 | "Please re-install ffmpeg and libx264 by running\n", 121 | "\t$ conda install -c conda-forge ffmpeg\n", 122 | "\t$ conda install -c conda-forge x264", 123 | ) 124 | sys.exit(0) 125 | 126 | def close(self): 127 | self.stream_writer.stdin.close() 128 | self.stream_writer.wait() 129 | -------------------------------------------------------------------------------- /codeformer/facelib/__init__.py: -------------------------------------------------------------------------------- 1 | from codeformer.facelib.utils import * 2 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | 4 | import torch 5 | from codeformer.facelib.detection.yolov5face.models.common import Conv 6 | from codeformer.facelib.utils import download_pretrained_models, load_file_from_url 7 | from torch import nn 8 | 9 | from codeformer.facelib.detection.retinaface.retinaface import RetinaFace 10 | from codeformer.facelib.detection.yolov5face.face_detector import YoloDetector 11 | 12 | 13 | def init_detection_model(model_name, half=False, device="cuda"): 14 | if "retinaface" in model_name: 15 | model = init_retinaface_model(model_name, half, device) 16 | elif "YOLOv5" in model_name: 17 | model = init_yolov5face_model(model_name, device) 18 | else: 19 | raise NotImplementedError(f"{model_name} is not implemented.") 20 | 21 | return model 22 | 23 | 24 | def init_retinaface_model(model_name, half=False, device="cuda"): 25 | if model_name == "retinaface_resnet50": 26 | model = RetinaFace(network_name="resnet50", half=half) 27 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth" 28 | elif model_name == "retinaface_mobile0.25": 29 | model = RetinaFace(network_name="mobile0.25", half=half) 30 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth" 31 | else: 32 | raise NotImplementedError(f"{model_name} is not implemented.") 33 | 34 | model_path = load_file_from_url(url=model_url, 
model_dir="weights/facelib", progress=True, file_name=None) 35 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 36 | # remove unnecessary 'module.' 37 | for k, v in deepcopy(load_net).items(): 38 | if k.startswith("module."): 39 | load_net[k[7:]] = v 40 | load_net.pop(k) 41 | model.load_state_dict(load_net, strict=True) 42 | model.eval() 43 | model = model.to(device) 44 | 45 | return model 46 | 47 | 48 | def init_yolov5face_model(model_name, device="cuda"): 49 | if model_name == "YOLOv5l": 50 | model = YoloDetector(config_name="facelib/detection/yolov5face/models/yolov5l.yaml", device=device) 51 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth" 52 | elif model_name == "YOLOv5n": 53 | model = YoloDetector(config_name="facelib/detection/yolov5face/models/yolov5n.yaml", device=device) 54 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth" 55 | else: 56 | raise NotImplementedError(f"{model_name} is not implemented.") 57 | 58 | model_path = load_file_from_url(url=model_url, model_dir="weights/facelib", progress=True, file_name=None) 59 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 60 | model.detector.load_state_dict(load_net, strict=True) 61 | model.detector.eval() 62 | model.detector = model.detector.to(device).float() 63 | 64 | for m in model.detector.modules(): 65 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 66 | m.inplace = True # pytorch 1.7.0 compatibility 67 | elif isinstance(m, Conv): 68 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 69 | 70 | return model 71 | 72 | 73 | # Download from Google Drive 74 | # def init_yolov5face_model(model_name, device='cuda'): 75 | # if model_name == 'YOLOv5l': 76 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 77 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} 78 | # elif model_name == 'YOLOv5n': 79 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 80 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} 81 | # else: 82 | # raise NotImplementedError(f'{model_name} is not implemented.') 83 | 84 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) 85 | # if not os.path.exists(model_path): 86 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') 87 | 88 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 89 | # model.detector.load_state_dict(load_net, strict=True) 90 | # model.detector.eval() 91 | # model.detector = model.detector.to(device).float() 92 | 93 | # for m in model.detector.modules(): 94 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 95 | # m.inplace = True # pytorch 1.7.0 compatibility 96 | # elif isinstance(m, Conv): 97 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 98 | 99 | # return model 100 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/align_trans.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from codeformer.facelib.detection.matlab_cp2tform import get_similarity_transform_for_cv2 5 | 6 | # reference facial points, a list of coordinates (x,y) 7 | REFERENCE_FACIAL_POINTS = [ 8 | [30.29459953, 51.69630051], 9 | [65.53179932, 
51.50139999], 10 | [48.02519989, 71.73660278], 11 | [33.54930115, 92.3655014], 12 | [62.72990036, 92.20410156], 13 | ] 14 | 15 | DEFAULT_CROP_SIZE = (96, 112) 16 | 17 | 18 | class FaceWarpException(Exception): 19 | def __str__(self): 20 | return "In File {}:{}".format(__file__, super().__str__()) 21 | 22 | 23 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): 24 | """ 25 | Function: 26 | ---------- 27 | get reference 5 key points according to crop settings: 28 | 0. Set default crop_size: 29 | if default_square: 30 | crop_size = (112, 112) 31 | else: 32 | crop_size = (96, 112) 33 | 1. Pad the crop_size by inner_padding_factor in each side; 34 | 2. Resize crop_size into (output_size - outer_padding*2), 35 | pad into output_size with outer_padding; 36 | 3. Output reference_5point; 37 | Parameters: 38 | ---------- 39 | @output_size: (w, h) or None 40 | size of aligned face image 41 | @inner_padding_factor: (w_factor, h_factor) 42 | padding factor for inner (w, h) 43 | @outer_padding: (w_pad, h_pad) 44 | outer padding in pixels added around the resized inner region 45 | @default_square: True or False 46 | if True: 47 | default crop_size = (112, 112) 48 | else: 49 | default crop_size = (96, 112); 50 | !!! make sure, if output_size is not None: 51 | (output_size - outer_padding) 52 | = some_scale * (default crop_size * (1.0 + 53 | inner_padding_factor)) 54 | Returns: 55 | ---------- 56 | @reference_5point: 5x2 np.array 57 | each row is a pair of transformed coordinates (x, y) 58 | """ 59 | 60 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 61 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 62 | 63 | # 0) make the inner region a square 64 | if default_square: 65 | size_diff = max(tmp_crop_size) - tmp_crop_size 66 | tmp_5pts += size_diff / 2 67 | tmp_crop_size += size_diff 68 | 69 | if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]: 70 | 71 | return tmp_5pts 72 | 73 | if inner_padding_factor == 0 and outer_padding == (0, 0): 74 | if output_size is None: 75 | return tmp_5pts 76 | else: 77 | raise FaceWarpException("No paddings to do, output_size must be None or {}".format(tmp_crop_size)) 78 | 79 | # check output size 80 | if not (0 <= inner_padding_factor <= 1.0): 81 | raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)") 82 | 83 | if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None: 84 | output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32) # cast the scaled size, not the scalar factor 85 | output_size += np.array(outer_padding) 86 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): 87 | raise FaceWarpException("Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])") 88 | 89 | # 1) pad the inner region according to inner_padding_factor 90 | if inner_padding_factor > 0: 91 | size_diff = tmp_crop_size * inner_padding_factor * 2 92 | tmp_5pts += size_diff / 2 93 | tmp_crop_size += np.round(size_diff).astype(np.int32) 94 | 95 | # 2) resize the padded inner region 96 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 97 | 98 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 99 | raise FaceWarpException( 100 | "Must have (output_size - outer_padding) " "= some_scale * (crop_size * (1.0 + inner_padding_factor))" 101 | ) 102 | 103 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 104 | tmp_5pts = tmp_5pts * scale_factor 105 | # size_diff =
tmp_crop_size * (scale_factor - min(scale_factor)) 106 | # tmp_5pts = tmp_5pts + size_diff / 2 107 | tmp_crop_size = size_bf_outer_pad 108 | 109 | # 3) add outer_padding to make output_size 110 | reference_5point = tmp_5pts + np.array(outer_padding) 111 | tmp_crop_size = output_size 112 | 113 | return reference_5point 114 | 115 | 116 | def get_affine_transform_matrix(src_pts, dst_pts): 117 | """ 118 | Function: 119 | ---------- 120 | get affine transform matrix 'tfm' from src_pts to dst_pts 121 | Parameters: 122 | ---------- 123 | @src_pts: Kx2 np.array 124 | source points matrix, each row is a pair of coordinates (x, y) 125 | @dst_pts: Kx2 np.array 126 | destination points matrix, each row is a pair of coordinates (x, y) 127 | Returns: 128 | ---------- 129 | @tfm: 2x3 np.array 130 | transform matrix from src_pts to dst_pts 131 | """ 132 | 133 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 134 | n_pts = src_pts.shape[0] 135 | ones = np.ones((n_pts, 1), src_pts.dtype) 136 | src_pts_ = np.hstack([src_pts, ones]) 137 | dst_pts_ = np.hstack([dst_pts, ones]) 138 | 139 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None) 140 | 141 | if rank == 3: 142 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) 143 | elif rank == 2: 144 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 145 | 146 | return tfm 147 | 148 | 149 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"): 150 | """ 151 | Function: 152 | ---------- 153 | warp src_img so that facial_pts align with reference_pts, then crop to crop_size 154 | Parameters: 155 | ---------- 156 | @src_img: HxWxC np.array 157 | input image 158 | @facial_pts: could be 159 | 1) a list of K coordinates (x,y) 160 | or 161 | 2) Kx2 or 2xK np.array 162 | each row or col is a pair of coordinates (x, y) 163 | @reference_pts: could be 164 | 1) a list of K coordinates (x,y) 165 | or 166 | 2) Kx2 or 2xK np.array 167 | each row or col is a pair of coordinates (x, y) 168 | or 169 | 3) None 170 | if None, use default reference facial points 171 | @crop_size: (w, h) 172 | output face image size 173 | @align_type: transform type, could be one of 174 | 1) 'similarity': use similarity transform 175 | 2) 'cv2_affine': use the first 3 points to do affine transform, 176 | by calling cv2.getAffineTransform() 177 | 3) 'affine': use all points to do affine transform 178 | Returns: 179 | ---------- 180 | @face_img: output face image with size (w, h) = @crop_size 181 | """ 182 | 183 | if reference_pts is None: 184 | if crop_size[0] == 96 and crop_size[1] == 112: 185 | reference_pts = REFERENCE_FACIAL_POINTS 186 | else: 187 | default_square = False 188 | inner_padding_factor = 0 189 | outer_padding = (0, 0) 190 | output_size = crop_size 191 | 192 | reference_pts = get_reference_facial_points( 193 | output_size, inner_padding_factor, outer_padding, default_square 194 | ) 195 | 196 | ref_pts = np.float32(reference_pts) 197 | ref_pts_shp = ref_pts.shape 198 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 199 | raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2") 200 | 201 | if ref_pts_shp[0] == 2: 202 | ref_pts = ref_pts.T 203 | 204 | src_pts = np.float32(facial_pts) 205 | src_pts_shp = src_pts.shape 206 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 207 | raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2") 208 | 209 | if src_pts_shp[0] == 2: 210 | src_pts = src_pts.T 211 | 212 | if src_pts.shape != ref_pts.shape: 213 | raise FaceWarpException("facial_pts and reference_pts must
have the same shape") 214 | 215 | if align_type == "cv2_affine": 216 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 217 | elif align_type == "affine": 218 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 219 | else: 220 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) 221 | 222 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) 223 | 224 | return face_img 225 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/retinaface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/retinaface/__init__.py -------------------------------------------------------------------------------- /codeformer/facelib/detection/retinaface/retinaface_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 11 | ) 12 | 13 | 14 | def conv_bn_no_relu(inp, oup, stride): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 17 | nn.BatchNorm2d(oup), 18 | ) 19 | 20 | 21 | def conv_bn1X1(inp, oup, stride, leaky=0): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), 24 | nn.BatchNorm2d(oup), 25 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 26 | ) 27 | 28 | 29 | def conv_dw(inp, oup, stride, leaky=0.1): 30 | return nn.Sequential( 31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 32 | nn.BatchNorm2d(inp), 33 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 34 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 35 | nn.BatchNorm2d(oup), 36 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 37 | ) 38 | 39 | 40 | class SSH(nn.Module): 41 | def __init__(self, in_channel, out_channel): 42 | super(SSH, self).__init__() 43 | assert out_channel % 4 == 0 44 | leaky = 0 45 | if out_channel <= 64: 46 | leaky = 0.1 47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 48 | 49 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 50 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 51 | 52 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 53 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 54 | 55 | def forward(self, input): 56 | conv3X3 = self.conv3X3(input) 57 | 58 | conv5X5_1 = self.conv5X5_1(input) 59 | conv5X5 = self.conv5X5_2(conv5X5_1) 60 | 61 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 62 | conv7X7 = self.conv7x7_3(conv7X7_2) 63 | 64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class FPN(nn.Module): 70 | def __init__(self, in_channels_list, out_channels): 71 | super(FPN, self).__init__() 72 | leaky = 0 73 | if out_channels <= 64: 74 | leaky = 0.1 75 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 76 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 77 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 78 | 79 | self.merge1 = conv_bn(out_channels, 
out_channels, leaky=leaky) 80 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 81 | 82 | def forward(self, input): 83 | # names = list(input.keys()) 84 | # input = list(input.values()) 85 | 86 | output1 = self.output1(input[0]) 87 | output2 = self.output2(input[1]) 88 | output3 = self.output3(input[2]) 89 | 90 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") 91 | output2 = output2 + up3 92 | output2 = self.merge2(output2) 93 | 94 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") 95 | output1 = output1 + up2 96 | output1 = self.merge1(output1) 97 | 98 | out = [output1, output2, output3] 99 | return out 100 | 101 | 102 | class MobileNetV1(nn.Module): 103 | def __init__(self): 104 | super(MobileNetV1, self).__init__() 105 | self.stage1 = nn.Sequential( 106 | conv_bn(3, 8, 2, leaky=0.1), # 3 107 | conv_dw(8, 16, 1), # 7 108 | conv_dw(16, 32, 2), # 11 109 | conv_dw(32, 32, 1), # 19 110 | conv_dw(32, 64, 2), # 27 111 | conv_dw(64, 64, 1), # 43 112 | ) 113 | self.stage2 = nn.Sequential( 114 | conv_dw(64, 128, 2), # 43 + 16 = 59 115 | conv_dw(128, 128, 1), # 59 + 32 = 91 116 | conv_dw(128, 128, 1), # 91 + 32 = 123 117 | conv_dw(128, 128, 1), # 123 + 32 = 155 118 | conv_dw(128, 128, 1), # 155 + 32 = 187 119 | conv_dw(128, 128, 1), # 187 + 32 = 219 120 | ) 121 | self.stage3 = nn.Sequential( 122 | conv_dw(128, 256, 2), # 219 +3 2 = 241 123 | conv_dw(256, 256, 1), # 241 + 64 = 301 124 | ) 125 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 126 | self.fc = nn.Linear(256, 1000) 127 | 128 | def forward(self, x): 129 | x = self.stage1(x) 130 | x = self.stage2(x) 131 | x = self.stage3(x) 132 | x = self.avg(x) 133 | # x = self.model(x) 134 | x = x.view(-1, 256) 135 | x = self.fc(x) 136 | return x 137 | 138 | 139 | class ClassHead(nn.Module): 140 | def __init__(self, inchannels=512, num_anchors=3): 141 | super(ClassHead, self).__init__() 142 | self.num_anchors = num_anchors 143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 144 | 145 | def forward(self, x): 146 | out = self.conv1x1(x) 147 | out = out.permute(0, 2, 3, 1).contiguous() 148 | 149 | return out.view(out.shape[0], -1, 2) 150 | 151 | 152 | class BboxHead(nn.Module): 153 | def __init__(self, inchannels=512, num_anchors=3): 154 | super(BboxHead, self).__init__() 155 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 156 | 157 | def forward(self, x): 158 | out = self.conv1x1(x) 159 | out = out.permute(0, 2, 3, 1).contiguous() 160 | 161 | return out.view(out.shape[0], -1, 4) 162 | 163 | 164 | class LandmarkHead(nn.Module): 165 | def __init__(self, inchannels=512, num_anchors=3): 166 | super(LandmarkHead, self).__init__() 167 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 168 | 169 | def forward(self, x): 170 | out = self.conv1x1(x) 171 | out = out.permute(0, 2, 3, 1).contiguous() 172 | 173 | return out.view(out.shape[0], -1, 10) 174 | 175 | 176 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): 177 | classhead = nn.ModuleList() 178 | for i in range(fpn_num): 179 | classhead.append(ClassHead(inchannels, anchor_num)) 180 | return classhead 181 | 182 | 183 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): 184 | bboxhead = nn.ModuleList() 185 | for i in range(fpn_num): 186 | bboxhead.append(BboxHead(inchannels, anchor_num)) 187 | return bboxhead 188 | 189 | 190 | def make_landmark_head(fpn_num=3, 
inchannels=64, anchor_num=2): 191 | landmarkhead = nn.ModuleList() 192 | for i in range(fpn_num): 193 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 194 | return landmarkhead 195 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/__init__.py -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/face_detector.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | from codeformer.facelib.detection.yolov5face.models.yolo import Model 10 | from codeformer.facelib.detection.yolov5face.utils.datasets import letterbox 11 | from codeformer.facelib.detection.yolov5face.utils.general import ( 12 | check_img_size, 13 | non_max_suppression_face, 14 | scale_coords, 15 | scale_coords_landmarks, 16 | ) 17 | 18 | # IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9) 19 | IS_HIGH_VERSION = [ 20 | int(m) 21 | for m in list( 22 | re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$", torch.__version__)[0][:3] 23 | ) 24 | ] >= [1, 9, 0] 25 | 26 | 27 | def isListempty(inList): 28 | if isinstance(inList, list): # Is a list 29 | return all(map(isListempty, inList)) 30 | return False # Not a list 31 | 32 | 33 | class YoloDetector: 34 | def __init__( 35 | self, 36 | config_name, 37 | min_face=10, 38 | target_size=None, 39 | device="cuda", 40 | ): 41 | """ 42 | config_name: name of .yaml config with network configuration from models/ folder. 43 | min_face : minimal face size in pixels. 44 | target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. 45 | None for original resolution. 46 | """ 47 | self._class_path = Path(__file__).parent.absolute() 48 | self.target_size = target_size 49 | self.min_face = min_face 50 | self.detector = Model(cfg=config_name) 51 | self.device = device 52 | 53 | def _preprocess(self, imgs): 54 | """ 55 | Preprocessing image before passing through the network. Resize and conversion to torch tensor. 56 | """ 57 | pp_imgs = [] 58 | for img in imgs: 59 | h0, w0 = img.shape[:2] # orig hw 60 | if self.target_size: 61 | r = self.target_size / min(h0, w0) # resize image to img_size 62 | if r < 1: 63 | img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) 64 | 65 | imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size 66 | img = letterbox(img, new_shape=imgsz)[0] 67 | pp_imgs.append(img) 68 | pp_imgs = np.array(pp_imgs) 69 | pp_imgs = pp_imgs.transpose(0, 3, 1, 2) 70 | pp_imgs = torch.from_numpy(pp_imgs).to(self.device) 71 | pp_imgs = pp_imgs.float() # uint8 to fp16/32 72 | return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 73 | 74 | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): 75 | """ 76 | Postprocessing of raw pytorch model output. 77 | Returns: 78 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 79 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
80 | """ 81 | bboxes = [[] for _ in range(len(origimgs))] 82 | landmarks = [[] for _ in range(len(origimgs))] 83 | 84 | pred = non_max_suppression_face(pred, conf_thres, iou_thres) 85 | 86 | for image_id, origimg in enumerate(origimgs): 87 | img_shape = origimg.shape 88 | image_height, image_width = img_shape[:2] 89 | gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh 90 | gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks 91 | det = pred[image_id].cpu() 92 | scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() 93 | scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() 94 | 95 | for j in range(det.size()[0]): 96 | box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() 97 | box = list( 98 | map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) 99 | ) 100 | if box[3] - box[1] < self.min_face: 101 | continue 102 | lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() 103 | lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) 104 | lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] 105 | bboxes[image_id].append(box) 106 | landmarks[image_id].append(lm) 107 | return bboxes, landmarks 108 | 109 | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): 110 | """ 111 | Get bbox coordinates and keypoints of faces on original image. 112 | Params: 113 | imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) 114 | conf_thres: confidence threshold for each prediction 115 | iou_thres: threshold for NMS (filter of intersecting bboxes) 116 | Returns: 117 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 118 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
119 | """ 120 | # Pass input images through face detector 121 | images = imgs if isinstance(imgs, list) else [imgs] 122 | images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] 123 | origimgs = copy.deepcopy(images) 124 | 125 | images = self._preprocess(images) 126 | 127 | if IS_HIGH_VERSION: 128 | with torch.inference_mode(): # for pytorch>=1.9 129 | pred = self.detector(images)[0] 130 | else: 131 | with torch.no_grad(): # for pytorch<1.9 132 | pred = self.detector(images)[0] 133 | 134 | bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) 135 | 136 | # return bboxes, points 137 | if not isListempty(points): 138 | bboxes = np.array(bboxes).reshape(-1, 4) 139 | points = np.array(points).reshape(-1, 10) 140 | padding = bboxes[:, 0].reshape(-1, 1) 141 | return np.concatenate((bboxes, padding, points), axis=1) 142 | else: 143 | return None 144 | 145 | def __call__(self, *args): 146 | return self.detect_faces(*args) # this class defines detect_faces; there is no predict method 147 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/models/__init__.py -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/models/experimental.py: -------------------------------------------------------------------------------- 1 | # # This file contains experimental modules 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from codeformer.facelib.detection.yolov5face.models.common import Conv 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super().__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class MixConv2d(nn.Module): 25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 27 | super().__init__() 28 | groups = len(k) 29 | if equal_ch: # equal c_ per group 30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices 31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 32 | else: # equal weight.numel() per group 33 | b = [c2] + [0] * groups 34 | a = np.eye(groups + 1, groups, k=-1) 35 | a -= np.roll(a, 1, axis=1) 36 | a *= np.array(k) ** 2 37 | a[0] = 1 38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 39 | 40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 41 | self.bn = nn.BatchNorm2d(c2) 42 | self.act = nn.LeakyReLU(0.1, inplace=True) 43 | 44 | def forward(self, x): 45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 46 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/models/yolov5l.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 |
depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 16 | [-1, 3, C3, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 18 | [-1, 9, C3, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 20 | [-1, 9, C3, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 22 | [-1, 1, SPP, [1024, [3,5,7]]], 23 | [-1, 3, C3, [1024, False]], # 8 24 | ] 25 | 26 | # YOLOv5 head 27 | head: 28 | [[-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4 31 | [-1, 3, C3, [512, False]], # 12 32 | 33 | [-1, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3 36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small) 37 | 38 | [-1, 1, Conv, [256, 3, 2]], 39 | [[-1, 13], 1, Concat, [1]], # cat head P4 40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium) 41 | 42 | [-1, 1, Conv, [512, 3, 2]], 43 | [[-1, 9], 1, Concat, [1]], # cat head P5 44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large) 45 | 46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 47 | ] -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/models/yolov5n.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2 18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4 20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6 22 | ] 23 | 24 | # YOLOv5 head 25 | head: 26 | [[-1, 1, Conv, [128, 1, 1]], 27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4 29 | [-1, 1, C3, [128, False]], # 10 30 | 31 | [-1, 1, Conv, [128, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3 34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small) 35 | 36 | [-1, 1, Conv, [128, 3, 2]], 37 | [[-1, 11], 1, Concat, [1]], # cat head P4 38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium) 39 | 40 | [-1, 1, Conv, [128, 3, 2]], 41 | [[-1, 7], 1, Concat, [1]], # cat head P5 42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large) 43 | 44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 45 | ] 46 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/utils/__init__.py -------------------------------------------------------------------------------- 
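A minimal usage sketch for the YOLOv5-face detector defined by the YAML configs above (a hedged example, assuming the packaged init_detection_model entry point in codeformer/facelib/detection/__init__.py, relative config paths that resolve from the working directory, a successful download of the pretrained weights into weights/facelib/, and a hypothetical input image path):

    import cv2
    from codeformer.facelib.detection import init_detection_model

    # Build the YOLOv5n face detector from yolov5n.yaml and fetch yolov5n-face.pth on first use.
    detector = init_detection_model("YOLOv5n", device="cpu")

    img_bgr = cv2.imread("example.jpg")  # hypothetical image; detect_faces expects BGR order
    # Returns an N x 15 array [x1, y1, x2, y2, placeholder, 10 landmark coordinates] for the detected
    # faces, or None when nothing passes the confidence / NMS thresholds (see face_detector.py above).
    faces = detector.detect_faces(img_bgr, conf_thres=0.7, iou_thres=0.5)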
/codeformer/facelib/detection/yolov5face/utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | 4 | def check_anchor_order(m): 5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 7 | da = a[-1] - a[0] # delta a 8 | ds = m.stride[-1] - m.stride[0] # delta s 9 | if da.sign() != ds.sign(): # same order 10 | print("Reversing anchor order") 11 | m.anchors[:] = m.anchors.flip(0) 12 | m.anchor_grid[:] = m.anchor_grid.flip(0) 13 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): 6 | # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 7 | shape = img.shape[:2] # current shape [height, width] 8 | if isinstance(new_shape, int): 9 | new_shape = (new_shape, new_shape) 10 | 11 | # Scale ratio (new / old) 12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 13 | if not scaleup: # only scale down, do not scale up (for better test mAP) 14 | r = min(r, 1.0) 15 | 16 | # Compute padding 17 | ratio = r, r # width, height ratios 18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 20 | if auto: # minimum rectangle 21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding 22 | elif scale_fill: # stretch 23 | dw, dh = 0.0, 0.0 24 | new_unpad = (new_shape[1], new_shape[0]) 25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 26 | 27 | dw /= 2 # divide padding into 2 sides 28 | dh /= 2 29 | 30 | if shape[::-1] != new_unpad: # resize 31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 35 | return img, ratio, (dw, dh) 36 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | 5 | sys.path.insert(0, "./facelib/detection/yolov5face") 6 | model = torch.load("facelib/detection/yolov5face/yolov5n-face.pt", map_location="cpu")["model"] 7 | torch.save(model.state_dict(), "weights/facelib/yolov5n-face.pth") 8 | -------------------------------------------------------------------------------- /codeformer/facelib/detection/yolov5face/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fuse_conv_and_bn(conv, bn): 6 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 7 | fusedconv = ( 8 | nn.Conv2d( 9 | conv.in_channels, 10 | conv.out_channels, 11 | kernel_size=conv.kernel_size, 12 | stride=conv.stride, 13 | padding=conv.padding, 14 | groups=conv.groups, 15 | bias=True, 16 | ) 17 | .requires_grad_(False) 18 | .to(conv.weight.device) 19 | ) 20 
| 21 | # prepare filters 22 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 25 | 26 | # prepare spatial bias 27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 30 | 31 | return fusedconv 32 | 33 | 34 | def copy_attr(a, b, include=(), exclude=()): 35 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 36 | for k, v in b.__dict__.items(): 37 | if (include and k not in include) or k.startswith("_") or k in exclude: 38 | continue 39 | 40 | setattr(a, k, v) 41 | -------------------------------------------------------------------------------- /codeformer/facelib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from codeformer.facelib.parsing.bisenet import BiSeNet 4 | from codeformer.facelib.parsing.parsenet import ParseNet 5 | from codeformer.facelib.utils import load_file_from_url 6 | 7 | 8 | def init_parsing_model(model_name="bisenet", half=False, device="cuda"): 9 | if model_name == "bisenet": 10 | model = BiSeNet(num_class=19) 11 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth" 12 | elif model_name == "parsenet": 13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 14 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth" 15 | else: 16 | raise NotImplementedError(f"{model_name} is not implemented.") 17 | 18 | model_path = load_file_from_url(url=model_url, model_dir="weights/facelib", progress=True, file_name=None) 19 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 20 | model.load_state_dict(load_net, strict=True) 21 | model.eval() 22 | model = model.to(device) 23 | return model 24 | -------------------------------------------------------------------------------- /codeformer/facelib/parsing/bisenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnet import ResNet18 6 | 7 | 8 | class ConvBNReLU(nn.Module): 9 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): 10 | super(ConvBNReLU, self).__init__() 11 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 12 | self.bn = nn.BatchNorm2d(out_chan) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | x = F.relu(self.bn(x)) 17 | return x 18 | 19 | 20 | class BiSeNetOutput(nn.Module): 21 | def __init__(self, in_chan, mid_chan, num_class): 22 | super(BiSeNetOutput, self).__init__() 23 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 24 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) 25 | 26 | def forward(self, x): 27 | feat = self.conv(x) 28 | out = self.conv_out(feat) 29 | return out, feat 30 | 31 | 32 | class AttentionRefinementModule(nn.Module): 33 | def __init__(self, in_chan, out_chan): 34 | super(AttentionRefinementModule, self).__init__() 35 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 36 | self.conv_atten = nn.Conv2d(out_chan, 
out_chan, kernel_size=1, bias=False) 37 | self.bn_atten = nn.BatchNorm2d(out_chan) 38 | self.sigmoid_atten = nn.Sigmoid() 39 | 40 | def forward(self, x): 41 | feat = self.conv(x) 42 | atten = F.avg_pool2d(feat, feat.size()[2:]) 43 | atten = self.conv_atten(atten) 44 | atten = self.bn_atten(atten) 45 | atten = self.sigmoid_atten(atten) 46 | out = torch.mul(feat, atten) 47 | return out 48 | 49 | 50 | class ContextPath(nn.Module): 51 | def __init__(self): 52 | super(ContextPath, self).__init__() 53 | self.resnet = ResNet18() 54 | self.arm16 = AttentionRefinementModule(256, 128) 55 | self.arm32 = AttentionRefinementModule(512, 128) 56 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 57 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 58 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 59 | 60 | def forward(self, x): 61 | feat8, feat16, feat32 = self.resnet(x) 62 | h8, w8 = feat8.size()[2:] 63 | h16, w16 = feat16.size()[2:] 64 | h32, w32 = feat32.size()[2:] 65 | 66 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 67 | avg = self.conv_avg(avg) 68 | avg_up = F.interpolate(avg, (h32, w32), mode="nearest") 69 | 70 | feat32_arm = self.arm32(feat32) 71 | feat32_sum = feat32_arm + avg_up 72 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode="nearest") 73 | feat32_up = self.conv_head32(feat32_up) 74 | 75 | feat16_arm = self.arm16(feat16) 76 | feat16_sum = feat16_arm + feat32_up 77 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode="nearest") 78 | feat16_up = self.conv_head16(feat16_up) 79 | 80 | return feat8, feat16_up, feat32_up # x8, x8, x16 81 | 82 | 83 | class FeatureFusionModule(nn.Module): 84 | def __init__(self, in_chan, out_chan): 85 | super(FeatureFusionModule, self).__init__() 86 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 87 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 88 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.sigmoid = nn.Sigmoid() 91 | 92 | def forward(self, fsp, fcp): 93 | fcat = torch.cat([fsp, fcp], dim=1) 94 | feat = self.convblk(fcat) 95 | atten = F.avg_pool2d(feat, feat.size()[2:]) 96 | atten = self.conv1(atten) 97 | atten = self.relu(atten) 98 | atten = self.conv2(atten) 99 | atten = self.sigmoid(atten) 100 | feat_atten = torch.mul(feat, atten) 101 | feat_out = feat_atten + feat 102 | return feat_out 103 | 104 | 105 | class BiSeNet(nn.Module): 106 | def __init__(self, num_class): 107 | super(BiSeNet, self).__init__() 108 | self.cp = ContextPath() 109 | self.ffm = FeatureFusionModule(256, 256) 110 | self.conv_out = BiSeNetOutput(256, 256, num_class) 111 | self.conv_out16 = BiSeNetOutput(128, 64, num_class) 112 | self.conv_out32 = BiSeNetOutput(128, 64, num_class) 113 | 114 | def forward(self, x, return_feat=False): 115 | h, w = x.size()[2:] 116 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature 117 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature 118 | feat_fuse = self.ffm(feat_sp, feat_cp8) 119 | 120 | out, feat = self.conv_out(feat_fuse) 121 | out16, feat16 = self.conv_out16(feat_cp8) 122 | out32, feat32 = self.conv_out32(feat_cp16) 123 | 124 | out = F.interpolate(out, (h, w), mode="bilinear", align_corners=True) 125 | out16 = F.interpolate(out16, (h, w), mode="bilinear", align_corners=True) 126 | out32 = F.interpolate(out32, (h, w), mode="bilinear", align_corners=True) 127 | 128 | if 
return_feat: 129 | feat = F.interpolate(feat, (h, w), mode="bilinear", align_corners=True) 130 | feat16 = F.interpolate(feat16, (h, w), mode="bilinear", align_corners=True) 131 | feat32 = F.interpolate(feat32, (h, w), mode="bilinear", align_corners=True) 132 | return out, out16, out32, feat, feat16, feat32 133 | else: 134 | return out, out16, out32 135 | -------------------------------------------------------------------------------- /codeformer/facelib/parsing/parsenet.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/chaofengc/PSFRGAN 2 | """ 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class NormLayer(nn.Module): 9 | """Normalization Layers. 10 | 11 | Args: 12 | channels: input channels, for batch norm and instance norm. 13 | input_size: input shape without batch size, for layer norm. 14 | """ 15 | 16 | def __init__(self, channels, normalize_shape=None, norm_type="bn"): 17 | super(NormLayer, self).__init__() 18 | norm_type = norm_type.lower() 19 | self.norm_type = norm_type 20 | if norm_type == "bn": 21 | self.norm = nn.BatchNorm2d(channels, affine=True) 22 | elif norm_type == "in": 23 | self.norm = nn.InstanceNorm2d(channels, affine=False) 24 | elif norm_type == "gn": 25 | self.norm = nn.GroupNorm(32, channels, affine=True) 26 | elif norm_type == "pixel": 27 | self.norm = lambda x: F.normalize(x, p=2, dim=1) 28 | elif norm_type == "layer": 29 | self.norm = nn.LayerNorm(normalize_shape) 30 | elif norm_type == "none": 31 | self.norm = lambda x: x * 1.0 32 | else: 33 | assert 1 == 0, f"Norm type {norm_type} not support." 34 | 35 | def forward(self, x, ref=None): 36 | if self.norm_type == "spade": 37 | return self.norm(x, ref) 38 | else: 39 | return self.norm(x) 40 | 41 | 42 | class ReluLayer(nn.Module): 43 | """Relu Layer. 44 | 45 | Args: 46 | relu type: type of relu layer, candidates are 47 | - ReLU 48 | - LeakyReLU: default relu slope 0.2 49 | - PRelu 50 | - SELU 51 | - none: direct pass 52 | """ 53 | 54 | def __init__(self, channels, relu_type="relu"): 55 | super(ReluLayer, self).__init__() 56 | relu_type = relu_type.lower() 57 | if relu_type == "relu": 58 | self.func = nn.ReLU(True) 59 | elif relu_type == "leakyrelu": 60 | self.func = nn.LeakyReLU(0.2, inplace=True) 61 | elif relu_type == "prelu": 62 | self.func = nn.PReLU(channels) 63 | elif relu_type == "selu": 64 | self.func = nn.SELU(True) 65 | elif relu_type == "none": 66 | self.func = lambda x: x * 1.0 67 | else: 68 | assert 1 == 0, f"Relu type {relu_type} not support." 
69 | 70 | def forward(self, x): 71 | return self.func(x) 72 | 73 | 74 | class ConvLayer(nn.Module): 75 | def __init__( 76 | self, 77 | in_channels, 78 | out_channels, 79 | kernel_size=3, 80 | scale="none", 81 | norm_type="none", 82 | relu_type="none", 83 | use_pad=True, 84 | bias=True, 85 | ): 86 | super(ConvLayer, self).__init__() 87 | self.use_pad = use_pad 88 | self.norm_type = norm_type 89 | if norm_type in ["bn"]: 90 | bias = False 91 | 92 | stride = 2 if scale == "down" else 1 93 | 94 | self.scale_func = lambda x: x 95 | if scale == "up": 96 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode="nearest") 97 | 98 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2))) 99 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 100 | 101 | self.relu = ReluLayer(out_channels, relu_type) 102 | self.norm = NormLayer(out_channels, norm_type=norm_type) 103 | 104 | def forward(self, x): 105 | out = self.scale_func(x) 106 | if self.use_pad: 107 | out = self.reflection_pad(out) 108 | out = self.conv2d(out) 109 | out = self.norm(out) 110 | out = self.relu(out) 111 | return out 112 | 113 | 114 | class ResidualBlock(nn.Module): 115 | """ 116 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html 117 | """ 118 | 119 | def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"): 120 | super(ResidualBlock, self).__init__() 121 | 122 | if scale == "none" and c_in == c_out: 123 | self.shortcut_func = lambda x: x 124 | else: 125 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) 126 | 127 | scale_config_dict = {"down": ["none", "down"], "up": ["up", "none"], "none": ["none", "none"]} 128 | scale_conf = scale_config_dict[scale] 129 | 130 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) 131 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none") 132 | 133 | def forward(self, x): 134 | identity = self.shortcut_func(x) 135 | 136 | res = self.conv1(x) 137 | res = self.conv2(res) 138 | return identity + res 139 | 140 | 141 | class ParseNet(nn.Module): 142 | def __init__( 143 | self, 144 | in_size=128, 145 | out_size=128, 146 | min_feat_size=32, 147 | base_ch=64, 148 | parsing_ch=19, 149 | res_depth=10, 150 | relu_type="LeakyReLU", 151 | norm_type="bn", 152 | ch_range=[32, 256], 153 | ): 154 | super().__init__() 155 | self.res_depth = res_depth 156 | act_args = {"norm_type": norm_type, "relu_type": relu_type} 157 | min_ch, max_ch = ch_range 158 | 159 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 160 | min_feat_size = min(in_size, min_feat_size) 161 | 162 | down_steps = int(np.log2(in_size // min_feat_size)) 163 | up_steps = int(np.log2(out_size // min_feat_size)) 164 | 165 | # =============== define encoder-body-decoder ==================== 166 | self.encoder = [] 167 | self.encoder.append(ConvLayer(3, base_ch, 3, 1)) 168 | head_ch = base_ch 169 | for i in range(down_steps): 170 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) 171 | self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args)) 172 | head_ch = head_ch * 2 173 | 174 | self.body = [] 175 | for i in range(res_depth): 176 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) 177 | 178 | self.decoder = [] 179 | for i in range(up_steps): 180 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) 181 | self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args)) 182 | 
head_ch = head_ch // 2 183 | 184 | self.encoder = nn.Sequential(*self.encoder) 185 | self.body = nn.Sequential(*self.body) 186 | self.decoder = nn.Sequential(*self.decoder) 187 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) 188 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) 189 | 190 | def forward(self, x): 191 | feat = self.encoder(x) 192 | x = feat + self.body(feat) 193 | x = self.decoder(x) 194 | out_img = self.out_img_conv(x) 195 | out_mask = self.out_mask_conv(x) 196 | return out_mask, out_img 197 | -------------------------------------------------------------------------------- /codeformer/facelib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_chan, out_chan, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = conv3x3(in_chan, out_chan, stride) 14 | self.bn1 = nn.BatchNorm2d(out_chan) 15 | self.conv2 = conv3x3(out_chan, out_chan) 16 | self.bn2 = nn.BatchNorm2d(out_chan) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.downsample = None 19 | if in_chan != out_chan or stride != 1: 20 | self.downsample = nn.Sequential( 21 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(out_chan), 23 | ) 24 | 25 | def forward(self, x): 26 | residual = self.conv1(x) 27 | residual = F.relu(self.bn1(residual)) 28 | residual = self.conv2(residual) 29 | residual = self.bn2(residual) 30 | 31 | shortcut = x 32 | if self.downsample is not None: 33 | shortcut = self.downsample(x) 34 | 35 | out = shortcut + residual 36 | out = self.relu(out) 37 | return out 38 | 39 | 40 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 41 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 42 | for i in range(bnum - 1): 43 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 44 | return nn.Sequential(*layers) 45 | 46 | 47 | class ResNet18(nn.Module): 48 | def __init__(self): 49 | super(ResNet18, self).__init__() 50 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 51 | self.bn1 = nn.BatchNorm2d(64) 52 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 53 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 54 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 55 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 56 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 57 | 58 | def forward(self, x): 59 | x = self.conv1(x) 60 | x = F.relu(self.bn1(x)) 61 | x = self.maxpool(x) 62 | 63 | x = self.layer1(x) 64 | feat8 = self.layer2(x) # 1/8 65 | feat16 = self.layer3(feat8) # 1/16 66 | feat32 = self.layer4(feat16) # 1/32 67 | return feat8, feat16, feat32 68 | -------------------------------------------------------------------------------- /codeformer/facelib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import download_pretrained_models, img2tensor, load_file_from_url, scandir 3 | 4 | __all__ = [ 5 | "align_crop_face_landmarks", 6 | "compute_increased_bbox", 7 | "get_valid_bboxes", 8 | 
"load_file_from_url", 9 | "download_pretrained_models", 10 | "paste_face_back", 11 | "img2tensor", 12 | "scandir", 13 | ] 14 | -------------------------------------------------------------------------------- /codeformer/facelib/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from urllib.parse import urlparse 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.hub import download_url_to_file, get_dir 10 | 11 | # from basicsr.utils.download_util import download_file_from_google_drive 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | 15 | 16 | def download_pretrained_models(file_ids, save_path_root): 17 | import gdown 18 | 19 | os.makedirs(save_path_root, exist_ok=True) 20 | 21 | for file_name, file_id in file_ids.items(): 22 | file_url = "https://drive.google.com/uc?id=" + file_id 23 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 24 | if osp.exists(save_path): 25 | user_response = input(f"{file_name} already exist. Do you want to cover it? Y/N\n") 26 | if user_response.lower() == "y": 27 | print(f"Covering {file_name} to {save_path}") 28 | gdown.download(file_url, save_path, quiet=False) 29 | # download_file_from_google_drive(file_id, save_path) 30 | elif user_response.lower() == "n": 31 | print(f"Skipping {file_name}") 32 | else: 33 | raise ValueError("Wrong input. Only accepts Y/N.") 34 | else: 35 | print(f"Downloading {file_name} to {save_path}") 36 | gdown.download(file_url, save_path, quiet=False) 37 | # download_file_from_google_drive(file_id, save_path) 38 | 39 | 40 | def imwrite(img, file_path, params=None, auto_mkdir=True): 41 | """Write image to file. 42 | 43 | Args: 44 | img (ndarray): Image array to be written. 45 | file_path (str): Image file path. 46 | params (None or list): Same as opencv's :func:`imwrite` interface. 47 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 48 | whether to create it automatically. 49 | 50 | Returns: 51 | bool: Successful or not. 52 | """ 53 | if auto_mkdir: 54 | dir_name = os.path.abspath(os.path.dirname(file_path)) 55 | os.makedirs(dir_name, exist_ok=True) 56 | return cv2.imwrite(file_path, img, params) 57 | 58 | 59 | def img2tensor(imgs, bgr2rgb=True, float32=True): 60 | """Numpy array to tensor. 61 | 62 | Args: 63 | imgs (list[ndarray] | ndarray): Input images. 64 | bgr2rgb (bool): Whether to change bgr to rgb. 65 | float32 (bool): Whether to change to float32. 66 | 67 | Returns: 68 | list[tensor] | tensor: Tensor images. If returned results only have 69 | one element, just return tensor. 
70 | """ 71 | 72 | def _totensor(img, bgr2rgb, float32): 73 | if img.shape[2] == 3 and bgr2rgb: 74 | if img.dtype == "float64": 75 | img = img.astype("float32") 76 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 77 | img = torch.from_numpy(img.transpose(2, 0, 1)) 78 | if float32: 79 | img = img.float() 80 | return img 81 | 82 | if isinstance(imgs, list): 83 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 84 | else: 85 | return _totensor(imgs, bgr2rgb, float32) 86 | 87 | 88 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 89 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py""" 90 | if model_dir is None: 91 | hub_dir = get_dir() 92 | model_dir = os.path.join(hub_dir, "checkpoints") 93 | 94 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) 95 | 96 | parts = urlparse(url) 97 | filename = os.path.basename(parts.path) 98 | if file_name is not None: 99 | filename = file_name 100 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) 101 | if not os.path.exists(cached_file): 102 | print(f'Downloading: "{url}" to {cached_file}\n') 103 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 104 | return cached_file 105 | 106 | 107 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 108 | """Scan a directory to find the interested files. 109 | Args: 110 | dir_path (str): Path of the directory. 111 | suffix (str | tuple(str), optional): File suffix that we are 112 | interested in. Default: None. 113 | recursive (bool, optional): If set to True, recursively scan the 114 | directory. Default: False. 115 | full_path (bool, optional): If set to True, include the dir_path. 116 | Default: False. 117 | Returns: 118 | A generator for all the interested files with relative paths. 
119 | """ 120 | 121 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 122 | raise TypeError('"suffix" must be a string or tuple of strings') 123 | 124 | root = dir_path 125 | 126 | def _scandir(dir_path, suffix, recursive): 127 | for entry in os.scandir(dir_path): 128 | if not entry.name.startswith(".") and entry.is_file(): 129 | if full_path: 130 | return_path = entry.path 131 | else: 132 | return_path = osp.relpath(entry.path, root) 133 | 134 | if suffix is None: 135 | yield return_path 136 | elif return_path.endswith(suffix): 137 | yield return_path 138 | else: 139 | if recursive: 140 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 141 | else: 142 | continue 143 | 144 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 145 | 146 | 147 | def is_gray(img, threshold=10): 148 | img = Image.fromarray(img) 149 | if len(img.getbands()) == 1: 150 | return True 151 | img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) 152 | img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) 153 | img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) 154 | diff1 = (img1 - img2).var() 155 | diff2 = (img2 - img3).var() 156 | diff3 = (img3 - img1).var() 157 | diff_sum = (diff1 + diff2 + diff3) / 3.0 158 | if diff_sum <= threshold: 159 | return True 160 | else: 161 | return False 162 | 163 | 164 | def rgb2gray(img, out_channel=3): 165 | r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2] 166 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 167 | if out_channel == 3: 168 | gray = gray[:, :, np.newaxis].repeat(3, axis=2) 169 | return gray 170 | 171 | 172 | def bgr2gray(img, out_channel=3): 173 | b, g, r = img[:, :, 0], img[:, :, 1], img[:, :, 2] 174 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 175 | if out_channel == 3: 176 | gray = gray[:, :, np.newaxis].repeat(3, axis=2) 177 | return gray 178 | 179 | 180 | def calc_mean_std(feat, eps=1e-5): 181 | """ 182 | Args: 183 | feat (numpy): 3D [w h c]s 184 | """ 185 | size = feat.shape 186 | assert len(size) == 3, "The input feature should be 3D tensor." 187 | c = size[2] 188 | feat_var = feat.reshape(-1, c).var(axis=0) + eps 189 | feat_std = np.sqrt(feat_var).reshape(1, 1, c) 190 | feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c) 191 | return feat_mean, feat_std 192 | 193 | 194 | def adain_npy(content_feat, style_feat): 195 | """Adaptive instance normalization for numpy. 196 | 197 | Args: 198 | content_feat (numpy): The input feature. 199 | style_feat (numpy): The reference feature. 200 | """ 201 | size = content_feat.shape 202 | style_mean, style_std = calc_mean_std(style_feat) 203 | content_mean, content_std = calc_mean_std(content_feat) 204 | normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size) 205 | return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size) 206 | -------------------------------------------------------------------------------- /codeformer/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/scripts/__init__.py -------------------------------------------------------------------------------- /codeformer/scripts/code_format.sh: -------------------------------------------------------------------------------- 1 | black . --config pyproject.toml 2 | isort . 
-------------------------------------------------------------------------------- /codeformer/scripts/crop_align_face.py: -------------------------------------------------------------------------------- 1 | """ 2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) 3 | author: lzhbrian (https://lzhbrian.me) 4 | link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5 5 | date: 2020.1.5 6 | note: code is heavily borrowed from 7 | https://github.com/NVlabs/ffhq-dataset 8 | http://dlib.net/face_landmark_detection.py.html 9 | requirements: 10 | conda install Pillow numpy scipy 11 | conda install -c conda-forge dlib 12 | # download face landmark model from: 13 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 14 | """ 15 | 16 | import argparse 17 | import glob 18 | import os 19 | import sys 20 | 21 | import cv2 22 | import dlib 23 | import numpy as np 24 | import PIL 25 | import PIL.Image 26 | import scipy 27 | import scipy.ndimage 28 | 29 | # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 30 | predictor = dlib.shape_predictor("weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat") 31 | 32 | 33 | def get_landmark(filepath, only_keep_largest=True): 34 | """get landmark with dlib 35 | :return: np.array shape=(68, 2) 36 | """ 37 | detector = dlib.get_frontal_face_detector() 38 | 39 | img = dlib.load_rgb_image(filepath) 40 | dets = detector(img, 1) 41 | 42 | # Shangchen modified 43 | print("Number of faces detected: {}".format(len(dets))) 44 | if only_keep_largest: 45 | print("Detect several faces and only keep the largest.") 46 | face_areas = [] 47 | for k, d in enumerate(dets): 48 | face_area = (d.right() - d.left()) * (d.bottom() - d.top()) 49 | face_areas.append(face_area) 50 | 51 | largest_idx = face_areas.index(max(face_areas)) 52 | d = dets[largest_idx] 53 | shape = predictor(img, d) 54 | print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1))) 55 | else: 56 | for k, d in enumerate(dets): 57 | print( 58 | "Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format( 59 | k, d.left(), d.top(), d.right(), d.bottom() 60 | ) 61 | ) 62 | # Get the landmarks/parts for the face in box d. 63 | shape = predictor(img, d) 64 | print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1))) 65 | 66 | t = list(shape.parts()) 67 | a = [] 68 | for tt in t: 69 | a.append([tt.x, tt.y]) 70 | lm = np.array(a) 71 | # lm is a shape=(68,2) np.array 72 | return lm 73 | 74 | 75 | def align_face(filepath, out_path): 76 | """ 77 | :param filepath: str 78 | :return: PIL Image 79 | """ 80 | try: 81 | lm = get_landmark(filepath) 82 | except: 83 | print("No landmark ...") 84 | return 85 | 86 | lm_chin = lm[0:17] # left-right 87 | lm_eyebrow_left = lm[17:22] # left-right 88 | lm_eyebrow_right = lm[22:27] # left-right 89 | lm_nose = lm[27:31] # top-down 90 | lm_nostrils = lm[31:36] # top-down 91 | lm_eye_left = lm[36:42] # left-clockwise 92 | lm_eye_right = lm[42:48] # left-clockwise 93 | lm_mouth_outer = lm[48:60] # left-clockwise 94 | lm_mouth_inner = lm[60:68] # left-clockwise 95 | 96 | # Calculate auxiliary vectors. 
97 | eye_left = np.mean(lm_eye_left, axis=0) 98 | eye_right = np.mean(lm_eye_right, axis=0) 99 | eye_avg = (eye_left + eye_right) * 0.5 100 | eye_to_eye = eye_right - eye_left 101 | mouth_left = lm_mouth_outer[0] 102 | mouth_right = lm_mouth_outer[6] 103 | mouth_avg = (mouth_left + mouth_right) * 0.5 104 | eye_to_mouth = mouth_avg - eye_avg 105 | 106 | # Choose oriented crop rectangle. 107 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 108 | x /= np.hypot(*x) 109 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 110 | y = np.flipud(x) * [-1, 1] 111 | c = eye_avg + eye_to_mouth * 0.1 112 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 113 | qsize = np.hypot(*x) * 2 114 | 115 | # read image 116 | img = PIL.Image.open(filepath) 117 | 118 | output_size = 512 119 | transform_size = 4096 120 | enable_padding = False 121 | 122 | # Shrink. 123 | shrink = int(np.floor(qsize / output_size * 0.5)) 124 | if shrink > 1: 125 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 126 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 127 | quad /= shrink 128 | qsize /= shrink 129 | 130 | # Crop. 131 | border = max(int(np.rint(qsize * 0.1)), 3) 132 | crop = ( 133 | int(np.floor(min(quad[:, 0]))), 134 | int(np.floor(min(quad[:, 1]))), 135 | int(np.ceil(max(quad[:, 0]))), 136 | int(np.ceil(max(quad[:, 1]))), 137 | ) 138 | crop = ( 139 | max(crop[0] - border, 0), 140 | max(crop[1] - border, 0), 141 | min(crop[2] + border, img.size[0]), 142 | min(crop[3] + border, img.size[1]), 143 | ) 144 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 145 | img = img.crop(crop) 146 | quad -= crop[0:2] 147 | 148 | # Pad. 149 | pad = ( 150 | int(np.floor(min(quad[:, 0]))), 151 | int(np.floor(min(quad[:, 1]))), 152 | int(np.ceil(max(quad[:, 0]))), 153 | int(np.ceil(max(quad[:, 1]))), 154 | ) 155 | pad = ( 156 | max(-pad[0] + border, 0), 157 | max(-pad[1] + border, 0), 158 | max(pad[2] - img.size[0] + border, 0), 159 | max(pad[3] - img.size[1] + border, 0), 160 | ) 161 | if enable_padding and max(pad) > border - 4: 162 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 163 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect") 164 | h, w, _ = img.shape 165 | y, x, _ = np.ogrid[:h, :w, :1] 166 | mask = np.maximum( 167 | 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 168 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]), 169 | ) 170 | blur = qsize * 0.02 171 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 172 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 173 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), "RGB") 174 | quad += pad[:2] 175 | 176 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 177 | 178 | if output_size < transform_size: 179 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 180 | 181 | # Save aligned image. 
182 | print("saveing: ", out_path) 183 | img.save(out_path) 184 | 185 | return img, np.max(quad[:, 0]) - np.min(quad[:, 0]) 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--in_dir", type=str, default="./inputs/whole_imgs") 191 | parser.add_argument("--out_dir", type=str, default="./inputs/cropped_faces") 192 | args = parser.parse_args() 193 | 194 | img_list = sorted(glob.glob(f"{args.in_dir}/*.png")) 195 | img_list = sorted(img_list) 196 | 197 | for in_path in img_list: 198 | out_path = os.path.join(args.out_dir, in_path.split("/")[-1]) 199 | out_path = out_path.replace(".jpg", ".png") 200 | size_ = align_face(in_path, out_path) 201 | -------------------------------------------------------------------------------- /codeformer/scripts/download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | 8 | def download_pretrained_models(method, file_urls): 9 | save_path_root = f"./weights/{method}" 10 | os.makedirs(save_path_root, exist_ok=True) 11 | 12 | for file_name, file_url in file_urls.items(): 13 | save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument( 20 | "method", type=str, help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models.") 21 | ) 22 | args = parser.parse_args() 23 | 24 | file_urls = { 25 | "CodeFormer": { 26 | "codeformer.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" 27 | }, 28 | "facelib": { 29 | # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth', 30 | "detection_Resnet50_Final.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth", 31 | "parsing_parsenet.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth", 32 | }, 33 | "dlib": { 34 | "mmod_human_face_detector-4cb19393.dat": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat", 35 | "shape_predictor_5_face_landmarks-c4b1e980.dat": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat", 36 | }, 37 | } 38 | 39 | if args.method == "all": 40 | for method in file_urls.keys(): 41 | download_pretrained_models(method, file_urls[method]) 42 | else: 43 | download_pretrained_models(args.method, file_urls[args.method]) 44 | -------------------------------------------------------------------------------- /codeformer/scripts/download_pretrained_models_from_gdrive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | # from basicsr.utils.download_util import download_file_from_google_drive 6 | import gdown 7 | 8 | 9 | def download_pretrained_models(method, file_ids): 10 | save_path_root = f"./weights/{method}" 11 | os.makedirs(save_path_root, exist_ok=True) 12 | 13 | for file_name, file_id in file_ids.items(): 14 | file_url = "https://drive.google.com/uc?id=" + file_id 15 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 16 | if osp.exists(save_path): 17 | user_response = input(f"{file_name} already exist. 
Do you want to cover it? Y/N\n") 18 | if user_response.lower() == "y": 19 | print(f"Covering {file_name} to {save_path}") 20 | gdown.download(file_url, save_path, quiet=False) 21 | # download_file_from_google_drive(file_id, save_path) 22 | elif user_response.lower() == "n": 23 | print(f"Skipping {file_name}") 24 | else: 25 | raise ValueError("Wrong input. Only accepts Y/N.") 26 | else: 27 | print(f"Downloading {file_name} to {save_path}") 28 | gdown.download(file_url, save_path, quiet=False) 29 | # download_file_from_google_drive(file_id, save_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument( 36 | "method", type=str, help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models.") 37 | ) 38 | args = parser.parse_args() 39 | 40 | # file name: file id 41 | # 'dlib': { 42 | # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX', 43 | # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg', 44 | # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq' 45 | # } 46 | file_ids = { 47 | "CodeFormer": {"codeformer.pth": "1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB"}, 48 | "facelib": { 49 | "yolov5l-face.pth": "131578zMA6B2x8VQHyHfa6GEPtulMCNzV", 50 | "parsing_parsenet.pth": "16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK", 51 | }, 52 | } 53 | 54 | if args.method == "all": 55 | for method in file_ids.keys(): 56 | download_pretrained_models(method, file_ids[method]) 57 | else: 58 | download_pretrained_models(args.method, file_ids[args.method]) 59 | -------------------------------------------------------------------------------- /codeformer/scripts/package.sh: -------------------------------------------------------------------------------- 1 | python setup.py sdist 2 | twine upload dist/* -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.isort] 5 | line_length = 120 6 | profile = "black" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | torch>=1.7.1 13 | torchvision 14 | tqdm 15 | yapf 16 | lpips 17 | gdown # supports downloading the large file from Google Drive -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | max-complexity = 18 4 | exclude =.git,__pycache__,docs/source/conf.py,build,dist 5 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,E501,E722,W503,B006 6 | inline-quotes = " 7 | statistics = true 8 | count = true 9 | [mypy] 10 | ignore_missing_imports = True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | import setuptools 6 | 7 | 8 | def get_long_description(): 9 | base_dir = os.path.abspath(os.path.dirname(__file__)) 10 | with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f: 11 | return 
f.read() 12 | 13 | 14 | def get_requirements(): 15 | with open("requirements.txt") as f: 16 | return f.read().splitlines() 17 | 18 | 19 | def get_version(): 20 | current_dir = os.path.abspath(os.path.dirname(__file__)) 21 | version_file = os.path.join(current_dir, "codeformer", "__init__.py") 22 | with io.open(version_file, encoding="utf-8") as f: 23 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1) 24 | 25 | 26 | _DEV_REQUIREMENTS = [ 27 | "black==21.7b0", 28 | "flake8==3.9.2", 29 | "isort==5.9.2", 30 | "click==8.0.4", 31 | "importlib-metadata>=1.1.0,<4.3;python_version<'3.8'", 32 | ] 33 | 34 | 35 | setuptools.setup( 36 | name="codeformer-pip", 37 | version=get_version(), 38 | author="kadirnar", 39 | license="S-Lab", 40 | description="PyTorch implementation of CodeFormer", 41 | long_description=get_long_description(), 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/kadirnar/codeformer-pip", 44 | packages=setuptools.find_packages(exclude=["tests"]), 45 | include_package_data=True, 46 | python_requires=">=3.6", 47 | install_requires=get_requirements(), 48 | classifiers=[ 49 | "Development Status :: 5 - Production/Stable", 50 | "Operating System :: OS Independent", 51 | "Intended Audience :: Developers", 52 | "Intended Audience :: Science/Research", 53 | "Programming Language :: Python :: 3", 54 | "Programming Language :: Python :: 3.7", 55 | "Programming Language :: Python :: 3.8", 56 | "Programming Language :: Python :: 3.9", 57 | "Topic :: Software Development :: Libraries", 58 | "Topic :: Software Development :: Libraries :: Python Modules", 59 | "Topic :: Education", 60 | "Topic :: Scientific/Engineering", 61 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 62 | ], 63 | keywords="pytorch, codeformer, face, face-swap, face-swapping, face-swap-pytorch, face-swapping-pytorch", 64 | ) 65 | --------------------------------------------------------------------------------
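Note: to close, a minimal sketch of building one of the face-parsing networks packaged above through codeformer.facelib.parsing.init_parsing_model. The CPU fallback and the random 512x512 input are illustrative assumptions used only for a shape check, not part of the repository; the first call downloads parsing_parsenet.pth from the release URL hard-coded in init_parsing_model.

import torch

from codeformer.facelib.parsing import init_parsing_model

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: fall back to CPU when no GPU is present
net = init_parsing_model(model_name="parsenet", device=device)  # ParseNet(in_size=512, out_size=512, parsing_ch=19)

dummy = torch.randn(1, 3, 512, 512, device=device)  # random tensor purely for a shape check
with torch.no_grad():
    out_mask, out_img = net(dummy)  # ParseNet.forward returns (mask logits, auxiliary image output)
print(out_mask.shape)  # torch.Size([1, 19, 512, 512]), one channel per parsing class
print(out_img.shape)   # torch.Size([1, 3, 512, 512])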