├── .gitignore ├── LICENSE ├── README.md ├── assets ├── CodeFormer_logo.png ├── color_enhancement_result1.png ├── color_enhancement_result2.png ├── imgsli_1.jpg ├── imgsli_2.jpg ├── imgsli_3.jpg ├── inpainting_result1.png ├── inpainting_result2.png ├── network.jpg ├── restoration_result1.png ├── restoration_result2.png ├── restoration_result3.png └── restoration_result4.png ├── basicsr ├── VERSION ├── __init__.py ├── archs │ ├── __init__.py │ ├── arcface_arch.py │ ├── arch_util.py │ ├── codeformer_arch.py │ ├── rrdbnet_arch.py │ ├── vgg_arch.py │ └── vqgan_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_blind_dataset.py │ ├── ffhq_blind_joint_dataset.py │ ├── gaussian_kernels.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ └── transforms.py ├── losses │ ├── __init__.py │ ├── loss_util.py │ └── losses.py ├── metrics │ ├── __init__.py │ ├── metric_util.py │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── codeformer_idx_model.py │ ├── codeformer_joint_model.py │ ├── codeformer_model.py │ ├── lr_scheduler.py │ ├── sr_model.py │ └── vqgan_model.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── setup.py ├── train.py └── utils │ ├── __init__.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── realesrgan_utils.py │ ├── registry.py │ └── video_util.py ├── docs ├── history_changelog.md ├── train.md └── train_CN.md ├── facelib ├── detection │ ├── __init__.py │ ├── align_trans.py │ ├── matlab_cp2tform.py │ ├── retinaface │ │ ├── retinaface.py │ │ ├── retinaface_net.py │ │ └── retinaface_utils.py │ └── yolov5face │ │ ├── __init__.py │ │ ├── face_detector.py │ │ ├── models │ │ ├── __init__.py │ │ ├── common.py │ │ ├── experimental.py │ │ ├── yolo.py │ │ ├── yolov5l.yaml │ │ └── yolov5n.yaml │ │ └── utils │ │ ├── __init__.py │ │ ├── autoanchor.py │ │ ├── datasets.py │ │ ├── extract_ckpt.py │ │ ├── general.py │ │ └── torch_utils.py ├── parsing │ ├── __init__.py │ ├── bisenet.py │ ├── parsenet.py │ └── resnet.py └── utils │ ├── __init__.py │ ├── face_restoration_helper.py │ ├── face_utils.py │ └── misc.py ├── inference_codeformer.py ├── inference_colorization.py ├── inference_inpainting.py ├── inputs ├── cropped_faces │ ├── 0143.png │ ├── 0240.png │ ├── 0342.png │ ├── 0345.png │ ├── 0368.png │ ├── 0412.png │ ├── 0444.png │ ├── 0478.png │ ├── 0500.png │ ├── 0599.png │ ├── 0717.png │ ├── 0720.png │ ├── 0729.png │ ├── 0763.png │ ├── 0770.png │ ├── 0777.png │ ├── 0885.png │ ├── 0934.png │ ├── Solvay_conference_1927_0018.png │ └── Solvay_conference_1927_2_16.png ├── gray_faces │ ├── 067_David_Beckham_00.png │ ├── 089_Miley_Cyrus_00.png │ ├── 099_Victoria_Beckham_00.png │ ├── 111_Alexa_Chung_00.png │ ├── 132_Robert_Downey_Jr_00.png │ ├── 158_Jimmy_Fallon_00.png │ ├── 161_Zac_Efron_00.png │ ├── 169_John_Lennon_00.png │ ├── 170_Marilyn_Monroe_00.png │ ├── Einstein01.png │ ├── Einstein02.png │ ├── Hepburn01.png │ └── Hepburn02.png ├── masked_faces │ ├── 00105.png │ ├── 00108.png │ ├── 00169.png │ ├── 00588.png │ └── 
00664.png └── whole_imgs │ ├── 00.jpg │ ├── 01.jpg │ ├── 02.png │ ├── 03.jpg │ ├── 04.jpg │ ├── 05.jpg │ └── 06.png ├── options ├── CodeFormer_colorization.yml ├── CodeFormer_inpainting.yml ├── CodeFormer_stage2.yml ├── CodeFormer_stage3.yml └── VQGAN_512_ds32_nearest_stage1.yml ├── requirements.txt ├── scripts ├── crop_align_face.py ├── download_pretrained_models.py ├── download_pretrained_models_from_gdrive.py ├── generate_latent_gt.py └── inference_vqgan.py ├── web-demos ├── hugging_face │ └── app.py └── replicate │ ├── cog.yaml │ └── predict.py └── weights ├── CodeFormer └── .gitkeep ├── README.md └── facelib └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # ignored files 4 | version.py 5 | 6 | # ignored files with suffix 7 | *.html 8 | # *.png 9 | # *.jpeg 10 | # *.jpg 11 | *.pt 12 | *.gif 13 | *.pth 14 | *.dat 15 | *.zip 16 | 17 | # template 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | 124 | # project 125 | results/ 126 | experiments/ 127 | tb_logger/ 128 | run.sh 129 | *debug* 130 | *_old* 131 | 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /assets/CodeFormer_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/CodeFormer_logo.png -------------------------------------------------------------------------------- /assets/color_enhancement_result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/color_enhancement_result1.png -------------------------------------------------------------------------------- /assets/color_enhancement_result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/color_enhancement_result2.png -------------------------------------------------------------------------------- /assets/imgsli_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/imgsli_1.jpg -------------------------------------------------------------------------------- /assets/imgsli_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/imgsli_2.jpg -------------------------------------------------------------------------------- /assets/imgsli_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/imgsli_3.jpg -------------------------------------------------------------------------------- /assets/inpainting_result1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/inpainting_result1.png -------------------------------------------------------------------------------- /assets/inpainting_result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/inpainting_result2.png -------------------------------------------------------------------------------- /assets/network.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/network.jpg -------------------------------------------------------------------------------- /assets/restoration_result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/restoration_result1.png -------------------------------------------------------------------------------- /assets/restoration_result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/restoration_result2.png -------------------------------------------------------------------------------- /assets/restoration_result3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/restoration_result3.png -------------------------------------------------------------------------------- /assets/restoration_result4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/assets/restoration_result4.png -------------------------------------------------------------------------------- /basicsr/VERSION: -------------------------------------------------------------------------------- 1 | 1.3.2 2 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .train import * 10 | from .utils import * 11 | from .version import __gitsha__, __version__ 12 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = 
[importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs.
81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn 
layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether to use normalization 59 | in the input feature and the type of vgg network. Note that the pretrained 60 | path must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must be in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 |
141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | Tensor: Forward results. 149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | output = {} 155 | 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must constain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | 84 | prefetch_mode = dataset_opt.get('prefetch_mode') 85 | if prefetch_mode == 'cpu': # CPUPrefetcher 86 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 87 | logger = get_root_logger() 88 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') 89 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 90 | else: 91 | # prefetch_mode=None: Normal dataloader 92 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 93 | return torch.utils.data.DataLoader(**dataloader_args) 94 | 95 | 96 | def worker_init_fn(worker_id, num_workers, rank, seed): 97 | # Set the worker seed to num_workers * rank + worker_id + seed 98 | worker_seed = num_workers * rank + worker_id + seed 99 | np.random.seed(worker_seed) 100 | random.seed(worker_seed) 101 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and 15 | GT image pairs. 16 | 17 | There are three modes: 18 | 1. 'lmdb': Use lmdb files. 19 | If opt['io_backend'] == lmdb. 20 | 2. 'meta_info_file': Use meta information file to generate paths. 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. 'folder': Scan folders to generate paths. 23 | The rest. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | dataroot_lq (str): Data root path for lq. 29 | meta_info_file (str): Path for meta information file. 30 | io_backend (dict): IO backend type and other kwarg. 31 | filename_tmpl (str): Template for each filename. Note that the 32 | template excludes the file extension. Default: '{}'. 33 | gt_size (int): Cropped patched size for gt patches. 34 | use_flip (bool): Use horizontal flips. 35 | use_rot (bool): Use rotation (use vertical flip and transposing h 36 | and w for implementation). 37 | 38 | scale (bool): Scale, which will be added automatically. 39 | phase (str): 'train' or 'val'. 
40 | """ 41 | 42 | def __init__(self, opt): 43 | super(PairedImageDataset, self).__init__() 44 | self.opt = opt 45 | # file client (io backend) 46 | self.file_client = None 47 | self.io_backend_opt = opt['io_backend'] 48 | self.mean = opt['mean'] if 'mean' in opt else None 49 | self.std = opt['std'] if 'std' in opt else None 50 | 51 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 52 | if 'filename_tmpl' in opt: 53 | self.filename_tmpl = opt['filename_tmpl'] 54 | else: 55 | self.filename_tmpl = '{}' 56 | 57 | if self.io_backend_opt['type'] == 'lmdb': 58 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 59 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 60 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 61 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 62 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 63 | self.opt['meta_info_file'], self.filename_tmpl) 64 | else: 65 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 66 | 67 | def __getitem__(self, index): 68 | if self.file_client is None: 69 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 70 | 71 | scale = self.opt['scale'] 72 | 73 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 74 | # image range: [0, 1], float32. 75 | gt_path = self.paths[index]['gt_path'] 76 | img_bytes = self.file_client.get(gt_path, 'gt') 77 | img_gt = imfrombytes(img_bytes, float32=True) 78 | lq_path = self.paths[index]['lq_path'] 79 | img_bytes = self.file_client.get(lq_path, 'lq') 80 | img_lq = imfrombytes(img_bytes, float32=True) 81 | 82 | # augmentation for training 83 | if self.opt['phase'] == 'train': 84 | gt_size = self.opt['gt_size'] 85 | # random crop 86 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 87 | # flip, rotation 88 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) 89 | 90 | # TODO: color space transform 91 | # BGR to RGB, HWC to CHW, numpy to tensor 92 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 93 | # normalize 94 | if self.mean is not None or self.std is not None: 95 | normalize(img_lq, self.mean, self.std, inplace=True) 96 | normalize(img_gt, self.mean, self.std, inplace=True) 97 | 98 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 99 | 100 | def __len__(self): 101 | return len(self.paths) 102 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 
16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | 4 | 5 | def mod_crop(img, scale): 6 | """Mod crop images, used during testing. 7 | 8 | Args: 9 | img (ndarray): Input image. 10 | scale (int): Scale factor. 11 | 12 | Returns: 13 | ndarray: Result image. 14 | """ 15 | img = img.copy() 16 | if img.ndim in (2, 3): 17 | h, w = img.shape[0], img.shape[1] 18 | h_remainder, w_remainder = h % scale, w % scale 19 | img = img[:h - h_remainder, :w - w_remainder, ...] 
20 | else: 21 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 22 | return img 23 | 24 | 25 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 26 | """Paired random crop. 27 | 28 | It crops lists of lq and gt images with corresponding locations. 29 | 30 | Args: 31 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 32 | should have the same shape. If the input is an ndarray, it will 33 | be transformed to a list containing itself. 34 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 35 | should have the same shape. If the input is an ndarray, it will 36 | be transformed to a list containing itself. 37 | gt_patch_size (int): GT patch size. 38 | scale (int): Scale factor. 39 | gt_path (str): Path to ground-truth. 40 | 41 | Returns: 42 | list[ndarray] | ndarray: GT images and LQ images. If returned results 43 | only have one element, just return ndarray. 44 | """ 45 | 46 | if not isinstance(img_gts, list): 47 | img_gts = [img_gts] 48 | if not isinstance(img_lqs, list): 49 | img_lqs = [img_lqs] 50 | 51 | h_lq, w_lq, _ = img_lqs[0].shape 52 | h_gt, w_gt, _ = img_gts[0].shape 53 | lq_patch_size = gt_patch_size // scale 54 | 55 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 56 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 57 | f'multiplication of LQ ({h_lq}, {w_lq}).') 58 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 60 | f'({lq_patch_size}, {lq_patch_size}). ' 61 | f'Please remove {gt_path}.') 62 | 63 | # randomly choose top and left coordinates for lq patch 64 | top = random.randint(0, h_lq - lq_patch_size) 65 | left = random.randint(0, w_lq - lq_patch_size) 66 | 67 | # crop lq patch 68 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 69 | 70 | # crop corresponding gt patch 71 | top_gt, left_gt = int(top * scale), int(left * scale) 72 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 73 | if len(img_gts) == 1: 74 | img_gts = img_gts[0] 75 | if len(img_lqs) == 1: 76 | img_lqs = img_lqs[0] 77 | return img_gts, img_lqs 78 | 79 | 80 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 81 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 82 | 83 | We use vertical flip and transpose for rotation implementation. 84 | All the images in the list use the same augmentation. 85 | 86 | Args: 87 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 88 | is an ndarray, it will be transformed to a list. 89 | hflip (bool): Horizontal flip. Default: True. 90 | rotation (bool): Rotation. Default: True. 91 | flows (list[ndarray]): Flows to be augmented. If the input is an 92 | ndarray, it will be transformed to a list. 93 | Dimension is (h, w, 2). Default: None. 94 | return_status (bool): Return the status of flip and rotation. 95 | Default: False. 96 | 97 | Returns: 98 | list[ndarray] | ndarray: Augmented images and flows. If returned 99 | results only have one element, just return ndarray.
100 | 101 | """ 102 | hflip = hflip and random.random() < 0.5 103 | vflip = rotation and random.random() < 0.5 104 | rot90 = rotation and random.random() < 0.5 105 | 106 | def _augment(img): 107 | if hflip: # horizontal 108 | cv2.flip(img, 1, img) 109 | if vflip: # vertical 110 | cv2.flip(img, 0, img) 111 | if rot90: 112 | img = img.transpose(1, 0, 2) 113 | return img 114 | 115 | def _augment_flow(flow): 116 | if hflip: # horizontal 117 | cv2.flip(flow, 1, flow) 118 | flow[:, :, 0] *= -1 119 | if vflip: # vertical 120 | cv2.flip(flow, 0, flow) 121 | flow[:, :, 1] *= -1 122 | if rot90: 123 | flow = flow.transpose(1, 0, 2) 124 | flow = flow[:, :, [1, 0]] 125 | return flow 126 | 127 | if not isinstance(imgs, list): 128 | imgs = [imgs] 129 | imgs = [_augment(img) for img in imgs] 130 | if len(imgs) == 1: 131 | imgs = imgs[0] 132 | 133 | if flows is not None: 134 | if not isinstance(flows, list): 135 | flows = [flows] 136 | flows = [_augment_flow(flow) for flow in flows] 137 | if len(flows) == 1: 138 | flows = flows[0] 139 | return imgs, flows 140 | else: 141 | if return_status: 142 | return imgs, (hflip, vflip, rot90) 143 | else: 144 | return imgs 145 | 146 | 147 | def img_rotate(img, angle, center=None, scale=1.0): 148 | """Rotate image. 149 | 150 | Args: 151 | img (ndarray): Image to be rotated. 152 | angle (float): Rotation angle in degrees. Positive values mean 153 | counter-clockwise rotation. 154 | center (tuple[int]): Rotation center. If the center is None, 155 | initialize it as the center of the image. Default: None. 156 | scale (float): Isotropic scale factor. Default: 1.0. 157 | """ 158 | (h, w) = img.shape[:2] 159 | 160 | if center is None: 161 | center = (w // 2, h // 2) 162 | 163 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 164 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 165 | return rotated_img 166 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils import get_root_logger 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must constain: 19 | type (str): Model type. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 
14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must constain: 14 | type (str): Model type. 
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 5 | from basicsr.utils.registry import METRIC_REGISTRY 6 | 7 | 8 | @METRIC_REGISTRY.register() 9 | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img1 (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result. 25 | """ 26 | 27 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 28 | if input_order not in ['HWC', 'CHW']: 29 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 30 | img1 = reorder_image(img1, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img1 - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | return 20. 
* np.log10(255. / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img1, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by func:`calculate_ssim`. 53 | 54 | Args: 55 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | 58 | Returns: 59 | float: ssim result. 60 | """ 61 | 62 | C1 = (0.01 * 255)**2 63 | C2 = (0.03 * 255)**2 64 | 65 | img1 = img1.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5) 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1**2 73 | mu2_sq = mu2**2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 80 | return ssim_map.mean() 81 | 82 | 83 | @METRIC_REGISTRY.register() 84 | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 85 | """Calculate SSIM (structural similarity). 86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 92 | 93 | For three-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img1 (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 110 | if input_order not in ['HWC', 'CHW']: 111 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 112 | img1 = reorder_image(img1, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img1 = img1.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
120 | 121 | if test_y_channel: 122 | img1 = to_y_channel(img1) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img1.shape[2]): 127 | ssims.append(_ssim(img1[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must constain: 24 | model_type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 
40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The mimimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from 
https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | try: 8 | from . import fused_act_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | fused_act_ext = load( 16 | 'fused', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 19 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class FusedLeakyReLUFunctionBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, out, negative_slope, scale): 28 | ctx.save_for_backward(out) 29 | ctx.negative_slope = negative_slope 30 | ctx.scale = scale 31 | 32 | empty = grad_output.new_empty(0) 33 | 34 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 35 | 36 | dim = [0] 37 | 38 | if grad_input.ndim > 2: 39 | dim += list(range(2, grad_input.ndim)) 40 | 41 | grad_bias = grad_input.sum(dim).detach() 42 | 43 | return grad_input, grad_bias 44 | 45 | @staticmethod 46 | def backward(ctx, gradgrad_input, gradgrad_bias): 47 | out, = ctx.saved_tensors 48 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 49 | ctx.scale) 50 | 51 | return gradgrad_out, None, None, None 52 | 53 | 54 | class FusedLeakyReLUFunction(Function): 55 | 56 | @staticmethod 57 | def forward(ctx, input, bias, negative_slope, scale): 58 | empty = input.new_empty(0) 59 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 60 | ctx.save_for_backward(out) 61 | ctx.negative_slope = negative_slope 62 | ctx.scale = scale 63 | 64 | return out 65 | 66 | @staticmethod 67 | def backward(ctx, grad_output): 68 | out, = ctx.saved_tensors 69 | 70 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | 77 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 78 | super().__init__() 79 | 80 | self.bias = nn.Parameter(torch.zeros(channel)) 81 | self.negative_slope = negative_slope 82 | self.scale = scale 83 | 84 | def forward(self, input): 85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 86 | 87 | 88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 90 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& 
refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | 7 | try: 8 | from . 
import upfirdn2d_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | upfirdn2d_ext = load( 16 | 'upfirdn2d', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 19 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class UpFirDn2dBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 28 | 29 | up_x, up_y = up 30 | down_x, down_y = down 31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 32 | 33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 34 | 35 | grad_input = upfirdn2d_ext.upfirdn2d( 36 | grad_output, 37 | grad_kernel, 38 | down_x, 39 | down_y, 40 | up_x, 41 | up_y, 42 | g_pad_x0, 43 | g_pad_x1, 44 | g_pad_y0, 45 | g_pad_y1, 46 | ) 47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 48 | 49 | ctx.save_for_backward(kernel) 50 | 51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 52 | 53 | ctx.up_x = up_x 54 | ctx.up_y = up_y 55 | ctx.down_x = down_x 56 | ctx.down_y = down_y 57 | ctx.pad_x0 = pad_x0 58 | ctx.pad_x1 = pad_x1 59 | ctx.pad_y0 = pad_y0 60 | ctx.pad_y1 = pad_y1 61 | ctx.in_size = in_size 62 | ctx.out_size = out_size 63 | 64 | return grad_input 65 | 66 | @staticmethod 67 | def backward(ctx, gradgrad_input): 68 | kernel, = ctx.saved_tensors 69 | 70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 71 | 72 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 73 | gradgrad_input, 74 | kernel, 75 | ctx.up_x, 76 | ctx.up_y, 77 | ctx.down_x, 78 | ctx.down_y, 79 | ctx.pad_x0, 80 | ctx.pad_x1, 81 | ctx.pad_y0, 82 | ctx.pad_y1, 83 | ) 84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 85 | # ctx.out_size[1], ctx.in_size[3]) 86 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 87 | 88 | return gradgrad_out, None, None, None, None, None, None, None, None 89 | 90 | 91 | class UpFirDn2d(Function): 92 | 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = 
UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | if input.device.type == 'cpu': 149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 150 | else: 151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 152 | 153 | return out 154 | 155 | 156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 157 | _, channel, in_h, in_w = input.shape 158 | input = input.reshape(-1, in_h, in_w, 1) 159 | 160 | _, in_h, in_w, minor = input.shape 161 | kernel_h, kernel_w = kernel.shape 162 | 163 | out = input.view(-1, in_h, 1, in_w, 1, minor) 164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 166 | 167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 168 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 172 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 173 | out = F.conv2d(out, w) 174 | out = out.reshape( 175 | -1, 176 | minor, 177 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 178 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 179 | ) 180 | out = out.permute(0, 2, 3, 1) 181 | out = out[:, ::down_y, ::down_x, :] 182 | 183 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 184 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 185 | 186 | return out.view(-1, channel, out_h, out_w) 187 | -------------------------------------------------------------------------------- /basicsr/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | from utils.misc import gpu_is_available 11 | 12 | version_file = './basicsr/version.py' 13 | 14 | 15 | def readme(): 16 | with open('README.md', encoding='utf-8') as f: 17 | content = f.read() 18 | return content 19 | 20 | 21 | def get_git_hash(): 22 | 23 | def _minimal_ext_cmd(cmd): 24 | # construct minimal environment 25 | env = {} 26 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 27 | v = os.environ.get(k) 28 | if v is not None: 29 | env[k] = v 30 | # LANGUAGE is used on win32 31 | env['LANGUAGE'] = 'C' 32 | env['LANG'] = 'C' 33 | env['LC_ALL'] = 'C' 34 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 35 | return out 36 | 37 | try: 38 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 39 | sha = out.strip().decode('ascii') 40 | except OSError: 41 | sha = 'unknown' 42 | 43 | return sha 44 | 45 | 46 | def get_hash(): 47 | if os.path.exists('.git'): 48 | sha = get_git_hash()[:7] 49 | elif os.path.exists(version_file): 50 | try: 51 | from version import __version__ 52 | sha = __version__.split('+')[-1] 53 | except ImportError: 54 | raise ImportError('Unable to get git version') 55 | else: 56 | sha = 'unknown' 57 | 58 | return sha 
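# Typical invocation from the repository root (a hedged sketch; the exact command
# depends on the user's environment and is not mandated by this file):
#   python basicsr/setup.py develop             # pure-Python install
#   python basicsr/setup.py develop --cuda_ext  # additionally build the CUDA extensions declared below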
59 | 60 | 61 | def write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open('./basicsr/VERSION', 'r') as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, 'w') as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, 'r') as f: 80 | exec(compile(f.read(), version_file, 'exec')) 81 | return locals()['__version__'] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args = {'cxx': []} 89 | 90 | # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | if gpu_is_available() or os.getenv('FORCE_CUDA', '0') == '1': 92 | define_macros += [('WITH_CUDA', None)] 93 | extension = CUDAExtension 94 | extra_compile_args['nvcc'] = [ 95 | '-D__CUDA_NO_HALF_OPERATORS__', 96 | '-D__CUDA_NO_HALF_CONVERSIONS__', 97 | '-D__CUDA_NO_HALF2_OPERATORS__', 98 | ] 99 | sources += sources_cuda 100 | else: 101 | print(f'Compiling {name} without CUDA') 102 | extension = CppExtension 103 | 104 | return extension( 105 | name=f'{module}.{name}', 106 | sources=[os.path.join(*module.split('.'), p) for p in sources], 107 | define_macros=define_macros, 108 | extra_compile_args=extra_compile_args) 109 | 110 | 111 | def get_requirements(filename='requirements.txt'): 112 | with open(os.path.join('.', filename), 'r') as f: 113 | requires = [line.replace('\n', '') for line in f.readlines()] 114 | return requires 115 | 116 | 117 | if __name__ == '__main__': 118 | if '--cuda_ext' in sys.argv: 119 | ext_modules = [ 120 | make_cuda_ext( 121 | name='deform_conv_ext', 122 | module='ops.dcn', 123 | sources=['src/deform_conv_ext.cpp'], 124 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 125 | make_cuda_ext( 126 | name='fused_act_ext', 127 | module='ops.fused_act', 128 | sources=['src/fused_bias_act.cpp'], 129 | sources_cuda=['src/fused_bias_act_kernel.cu']), 130 | make_cuda_ext( 131 | name='upfirdn2d_ext', 132 | module='ops.upfirdn2d', 133 | sources=['src/upfirdn2d.cpp'], 134 | sources_cuda=['src/upfirdn2d_kernel.cu']), 135 | ] 136 | sys.argv.remove('--cuda_ext') 137 | else: 138 | ext_modules = [] 139 | 140 | write_version_py() 141 | setup( 142 | name='basicsr', 143 | version=get_version(), 144 | description='Open Source Image and Video Super-Resolution Toolbox', 145 | long_description=readme(), 146 | long_description_content_type='text/markdown', 147 | author='Xintao Wang', 148 | author_email='xintao.wang@outlook.com', 149 | keywords='computer vision, restoration, super resolution', 150 | url='https://github.com/xinntao/BasicSR', 151 | include_package_data=True, 152 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 153 | classifiers=[ 154 | 'Development Status :: 4 - Beta', 155 | 'License :: OSI Approved :: Apache Software License', 156 | 'Operating System :: OS Independent', 157 | 'Programming Language :: Python :: 3', 158 | 'Programming Language :: Python :: 3.7', 159 | 'Programming Language :: Python :: 3.8', 160 | ], 161 | license='Apache License 2.0', 162 | setup_requires=['cython', 'numpy'], 163 | 
install_requires=get_requirements(), 164 | ext_modules=ext_modules, 165 | cmdclass={'build_ext': BuildExtension}, 166 | zip_safe=False) 167 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'init_tb_logger', 18 | 'init_wandb_logger', 19 | 'get_root_logger', 20 | 'get_env_info', 21 | # misc.py 22 | 'set_random_seed', 23 | 'get_time_str', 24 | 'mkdir_and_rename', 25 | 'make_exp_dirs', 26 | 'scandir', 27 | 'check_resume', 28 | 'sizeof_fmt' 29 | ] 30 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
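    Example (a sketch; this helper is normally reached via ``init_dist('slurm', port=29500)``
    inside an ``srun`` job, where ``SLURM_PROCID``, ``SLURM_NTASKS`` and
    ``SLURM_NODELIST`` are already set by the scheduler):

        >>> _init_dist_slurm(backend='nccl', port=29500)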
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | Ref: 14 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 15 | Args: 16 | file_id (str): File id. 17 | save_path (str): Save path. 
18 | """ 19 | 20 | session = requests.Session() 21 | URL = 'https://docs.google.com/uc?export=download' 22 | params = {'id': file_id} 23 | 24 | response = session.get(URL, params=params, stream=True) 25 | token = get_confirm_token(response) 26 | if token: 27 | params['confirm'] = token 28 | response = session.get(URL, params=params, stream=True) 29 | 30 | # get file size 31 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | print(response_file_size) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 72 | Args: 73 | url (str): URL to be downloaded. 74 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 75 | Default: None. 76 | progress (bool): Whether to show the download progress. Default: True. 77 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 78 | Returns: 79 | str: The path to the downloaded file. 80 | """ 81 | if model_dir is None: # use the pytorch hub_dir 82 | hub_dir = get_dir() 83 | model_dir = os.path.join(hub_dir, 'checkpoints') 84 | 85 | os.makedirs(model_dir, exist_ok=True) 86 | 87 | parts = urlparse(url) 88 | filename = os.path.basename(parts.path) 89 | if file_name is not None: 90 | filename = file_name 91 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 92 | if not os.path.exists(cached_file): 93 | print(f'Downloading: "{url}" to {cached_file}\n') 94 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 95 | return cached_file -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing differnet lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 
49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 
141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | return cv2.imwrite(file_path, img, params) 152 | 153 | 154 | def crop_border(imgs, crop_border): 155 | """Crop borders of images. 156 | 157 | Args: 158 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 159 | crop_border (int): Crop border for each end of height and weight. 160 | 161 | Returns: 162 | list[ndarray]: Cropped images. 163 | """ 164 | if crop_border == 0: 165 | return imgs 166 | else: 167 | if isinstance(imgs, list): 168 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 169 | else: 170 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 171 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | Args: 13 | opt (dict): Config. It contains the following keys: 14 | name (str): Exp name. 15 | logger (dict): Contains 'print_freq' (str) for logger interval. 16 | train (dict): Contains 'total_iter' (int) for total iters. 17 | use_tb_logger (bool): Use tensorboard logger. 18 | start_iter (int): Start iter. Default: 1. 19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 20 | """ 21 | 22 | def __init__(self, opt, start_iter=1, tb_logger=None): 23 | self.exp_name = opt['name'] 24 | self.interval = opt['logger']['print_freq'] 25 | self.start_iter = start_iter 26 | self.max_iters = opt['train']['total_iter'] 27 | self.use_tb_logger = opt['logger']['use_tb_logger'] 28 | self.tb_logger = tb_logger 29 | self.start_time = time.time() 30 | self.logger = get_root_logger() 31 | 32 | @master_only 33 | def __call__(self, log_vars): 34 | """Format logging message. 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | time (float): Iter time. 41 | data_time (float): Data time for each iter. 
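        Example (a sketch assuming ``opt`` follows the structure documented in the
        class docstring, with ``use_tb_logger`` disabled since no tb_logger is
        passed; ``l_pix`` is an illustrative loss key):

            >>> msg_logger = MessageLogger(opt, start_iter=1, tb_logger=None)
            >>> msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
            ...             'time': 0.5, 'data_time': 0.1, 'l_pix': 0.02})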
42 | """ 43 | # epoch, iter, learning rates 44 | epoch = log_vars.pop('epoch') 45 | current_iter = log_vars.pop('iter') 46 | lrs = log_vars.pop('lrs') 47 | 48 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 49 | for v in lrs: 50 | message += f'{v:.3e},' 51 | message += ')] ' 52 | 53 | # time and estimated time 54 | if 'time' in log_vars.keys(): 55 | iter_time = log_vars.pop('time') 56 | data_time = log_vars.pop('data_time') 57 | 58 | total_time = time.time() - self.start_time 59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | message += f'[eta: {eta_str}, ' 63 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 64 | 65 | # other items, especially losses 66 | for k, v in log_vars.items(): 67 | message += f'{k}: {v:.4e} ' 68 | # tensorboard logger 69 | if self.use_tb_logger: 70 | # if k.startswith('l_'): 71 | # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 72 | # else: 73 | self.tb_logger.add_scalar(k, v, current_iter) 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | tb_logger = SummaryWriter(log_dir=log_dir) 81 | return tb_logger 82 | 83 | 84 | @master_only 85 | def init_wandb_logger(opt): 86 | """We now only use wandb to sync tensorboard log.""" 87 | import wandb 88 | logger = logging.getLogger('basicsr') 89 | 90 | project = opt['logger']['wandb']['project'] 91 | resume_id = opt['logger']['wandb'].get('resume_id') 92 | if resume_id: 93 | wandb_id = resume_id 94 | resume = 'allow' 95 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 96 | else: 97 | wandb_id = wandb.util.generate_id() 98 | resume = 'never' 99 | 100 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 101 | 102 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 103 | 104 | 105 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 106 | """Get the root logger. 107 | The logger will be initialized if it has not been initialized. By default a 108 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 109 | also be added. 110 | Args: 111 | logger_name (str): root logger name. Default: 'basicsr'. 112 | log_file (str | None): The log filename. If specified, a FileHandler 113 | will be added to the root logger. 114 | log_level (int): The root logger level. Note that only the process of 115 | rank 0 is affected, while other processes will set the level to 116 | "Error" and be silent most of the time. 117 | Returns: 118 | logging.Logger: The root logger. 
119 | """ 120 | logger = logging.getLogger(logger_name) 121 | # if the logger has been initialized, just return it 122 | if logger_name in initialized_logger: 123 | return logger 124 | 125 | format_str = '%(asctime)s %(levelname)s: %(message)s' 126 | stream_handler = logging.StreamHandler() 127 | stream_handler.setFormatter(logging.Formatter(format_str)) 128 | logger.addHandler(stream_handler) 129 | logger.propagate = False 130 | rank, _ = get_dist_info() 131 | if rank != 0: 132 | logger.setLevel('ERROR') 133 | elif log_file is not None: 134 | logger.setLevel(log_level) 135 | # add file handler 136 | # file_handler = logging.FileHandler(log_file, 'w') 137 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 138 | file_handler.setFormatter(logging.Formatter(format_str)) 139 | file_handler.setLevel(log_level) 140 | logger.addHandler(file_handler) 141 | initialized_logger[logger_name] = True 142 | return logger 143 | 144 | 145 | def get_env_info(): 146 | """Get environment information. 147 | Currently, only log the software version. 148 | """ 149 | import torch 150 | import torchvision 151 | 152 | from basicsr.version import __version__ 153 | msg = r""" 154 | ____ _ _____ ____ 155 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 156 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 157 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 158 | /_____/ \__,_//____//_/ \___//____//_/ |_| 159 | ______ __ __ __ __ 160 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 161 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 162 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 163 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 164 | """ 165 | msg += ('\nVersion Information: ' 166 | f'\n\tBasicSR: {__version__}' 167 | f'\n\tPyTorch: {torch.__version__}' 168 | f'\n\tTorchVision: {torchvision.__version__}') 169 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import numpy as np 7 | from os import path as osp 8 | 9 | from .dist_util import master_only 10 | from .logger import get_root_logger 11 | 12 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 13 | torch.__version__)[0][:3])] >= [1, 12, 0] 14 | 15 | def gpu_is_available(): 16 | if IS_HIGH_VERSION: 17 | if torch.backends.mps.is_available(): 18 | return True 19 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 20 | 21 | def get_device(gpu_id=None): 22 | if gpu_id is None: 23 | gpu_str = '' 24 | elif isinstance(gpu_id, int): 25 | gpu_str = f':{gpu_id}' 26 | else: 27 | raise TypeError('Input should be int value.') 28 | 29 | if IS_HIGH_VERSION: 30 | if torch.backends.mps.is_available(): 31 | return torch.device('mps'+gpu_str) 32 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 33 | 34 | 35 | def set_random_seed(seed): 36 | """Set random seeds.""" 37 | random.seed(seed) 38 | np.random.seed(seed) 39 | torch.manual_seed(seed) 40 | torch.cuda.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | 43 | 44 | def get_time_str(): 45 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 46 | 47 | 48 | def mkdir_and_rename(path): 49 | """mkdirs. 
If path exists, rename it with timestamp and create a new one. 50 | 51 | Args: 52 | path (str): Folder path. 53 | """ 54 | if osp.exists(path): 55 | new_name = path + '_archived_' + get_time_str() 56 | print(f'Path already exists. Rename it to {new_name}', flush=True) 57 | os.rename(path, new_name) 58 | os.makedirs(path, exist_ok=True) 59 | 60 | 61 | @master_only 62 | def make_exp_dirs(opt): 63 | """Make dirs for experiments.""" 64 | path_opt = opt['path'].copy() 65 | if opt['is_train']: 66 | mkdir_and_rename(path_opt.pop('experiments_root')) 67 | else: 68 | mkdir_and_rename(path_opt.pop('results_root')) 69 | for key, path in path_opt.items(): 70 | if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): 71 | os.makedirs(path, exist_ok=True) 72 | 73 | 74 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 75 | """Scan a directory to find the interested files. 76 | 77 | Args: 78 | dir_path (str): Path of the directory. 79 | suffix (str | tuple(str), optional): File suffix that we are 80 | interested in. Default: None. 81 | recursive (bool, optional): If set to True, recursively scan the 82 | directory. Default: False. 83 | full_path (bool, optional): If set to True, include the dir_path. 84 | Default: False. 85 | 86 | Returns: 87 | A generator for all the interested files with relative pathes. 88 | """ 89 | 90 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 91 | raise TypeError('"suffix" must be a string or tuple of strings') 92 | 93 | root = dir_path 94 | 95 | def _scandir(dir_path, suffix, recursive): 96 | for entry in os.scandir(dir_path): 97 | if not entry.name.startswith('.') and entry.is_file(): 98 | if full_path: 99 | return_path = entry.path 100 | else: 101 | return_path = osp.relpath(entry.path, root) 102 | 103 | if suffix is None: 104 | yield return_path 105 | elif return_path.endswith(suffix): 106 | yield return_path 107 | else: 108 | if recursive: 109 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 110 | else: 111 | continue 112 | 113 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 114 | 115 | 116 | def check_resume(opt, resume_iter): 117 | """Check resume states and pretrain_network paths. 118 | 119 | Args: 120 | opt (dict): Options. 121 | resume_iter (int): Resume iteration. 122 | """ 123 | logger = get_root_logger() 124 | if opt['path']['resume_state']: 125 | # get all the networks 126 | networks = [key for key in opt.keys() if key.startswith('network_')] 127 | flag_pretrain = False 128 | for network in networks: 129 | if opt['path'].get(f'pretrain_{network}') is not None: 130 | flag_pretrain = True 131 | if flag_pretrain: 132 | logger.warning('pretrain_network path will be ignored during resuming.') 133 | # set pretrained model paths 134 | for network in networks: 135 | name = f'pretrain_{network}' 136 | basename = network.replace('network_', '') 137 | if opt['path'].get('ignore_resume_networks') is None or (basename 138 | not in opt['path']['ignore_resume_networks']): 139 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 140 | logger.info(f"Set {name} to {opt['path'][name]}") 141 | 142 | 143 | def sizeof_fmt(size, suffix='B'): 144 | """Get human readable file size. 145 | 146 | Args: 147 | size (int): File size. 148 | suffix (str): Suffix. Default: 'B'. 149 | 150 | Return: 151 | str: Formated file siz. 
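    Example (illustrative values):

        >>> sizeof_fmt(1024)
        '1.0 KB'
        >>> sizeof_fmt(3 * 1024 ** 3)
        '3.0 GB'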
152 | """ 153 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 154 | if abs(size) < 1024.0: 155 | return f'{size:3.1f} {unit}{suffix}' 156 | size /= 1024.0 157 | return f'{size:3.1f} Y{suffix}' 158 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import time 3 | from collections import OrderedDict 4 | from os import path as osp 5 | from basicsr.utils.misc import get_time_str 6 | 7 | def ordered_yaml(): 8 | """Support OrderedDict for yaml. 9 | 10 | Returns: 11 | yaml Loader and Dumper. 12 | """ 13 | try: 14 | from yaml import CDumper as Dumper 15 | from yaml import CLoader as Loader 16 | except ImportError: 17 | from yaml import Dumper, Loader 18 | 19 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 20 | 21 | def dict_representer(dumper, data): 22 | return dumper.represent_dict(data.items()) 23 | 24 | def dict_constructor(loader, node): 25 | return OrderedDict(loader.construct_pairs(node)) 26 | 27 | Dumper.add_representer(OrderedDict, dict_representer) 28 | Loader.add_constructor(_mapping_tag, dict_constructor) 29 | return Loader, Dumper 30 | 31 | 32 | def parse(opt_path, root_path, is_train=True): 33 | """Parse option file. 34 | 35 | Args: 36 | opt_path (str): Option file path. 37 | is_train (str): Indicate whether in training or not. Default: True. 38 | 39 | Returns: 40 | (dict): Options. 41 | """ 42 | with open(opt_path, mode='r') as f: 43 | Loader, _ = ordered_yaml() 44 | opt = yaml.load(f, Loader=Loader) 45 | 46 | opt['is_train'] = is_train 47 | 48 | # opt['name'] = f"{get_time_str()}_{opt['name']}" 49 | if opt['path'].get('resume_state', None): # Shangchen added 50 | resume_state_path = opt['path'].get('resume_state') 51 | opt['name'] = resume_state_path.split("/")[-3] 52 | else: 53 | opt['name'] = f"{get_time_str()}_{opt['name']}" 54 | 55 | 56 | # datasets 57 | for phase, dataset in opt['datasets'].items(): 58 | # for several datasets, e.g., test_1, test_2 59 | phase = phase.split('_')[0] 60 | dataset['phase'] = phase 61 | if 'scale' in opt: 62 | dataset['scale'] = opt['scale'] 63 | if dataset.get('dataroot_gt') is not None: 64 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 65 | if dataset.get('dataroot_lq') is not None: 66 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 67 | 68 | # paths 69 | for key, val in opt['path'].items(): 70 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 71 | opt['path'][key] = osp.expanduser(val) 72 | 73 | if is_train: 74 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 75 | opt['path']['experiments_root'] = experiments_root 76 | opt['path']['models'] = osp.join(experiments_root, 'models') 77 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 78 | opt['path']['log'] = experiments_root 79 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 80 | 81 | else: # test 82 | results_root = osp.join(root_path, 'results', opt['name']) 83 | opt['path']['results_root'] = results_root 84 | opt['path']['log'] = results_root 85 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 86 | 87 | return opt 88 | 89 | 90 | def dict2str(opt, indent_level=1): 91 | """dict to string for printing options. 92 | 93 | Args: 94 | opt (dict): Option dict. 95 | indent_level (int): Indent level. Default: 1. 
96 | 97 | Return: 98 | (str): Option string for printing. 99 | """ 100 | msg = '\n' 101 | for k, v in opt.items(): 102 | if isinstance(v, dict): 103 | msg += ' ' * (indent_level * 2) + k + ':[' 104 | msg += dict2str(v, indent_level + 1) 105 | msg += ' ' * (indent_level * 2) + ']\n' 106 | else: 107 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 108 | return msg 109 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 
48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /basicsr/utils/video_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is modified from the Real-ESRGAN: 3 | https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py 4 | 5 | ''' 6 | import cv2 7 | import sys 8 | import numpy as np 9 | 10 | try: 11 | import ffmpeg 12 | except ImportError: 13 | import pip 14 | pip.main(['install', '--user', 'ffmpeg-python']) 15 | import ffmpeg 16 | 17 | def get_video_meta_info(video_path): 18 | ret = {} 19 | probe = ffmpeg.probe(video_path) 20 | video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] 21 | has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) 22 | ret['width'] = video_streams[0]['width'] 23 | ret['height'] = video_streams[0]['height'] 24 | ret['fps'] = eval(video_streams[0]['avg_frame_rate']) 25 | ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None 26 | ret['nb_frames'] = int(video_streams[0]['nb_frames']) 27 | return ret 28 | 29 | class VideoReader: 30 | def __init__(self, video_path): 31 | self.paths = [] # for image&folder type 32 | self.audio = None 33 | try: 34 | self.stream_reader = ( 35 | ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', 36 | loglevel='error').run_async( 37 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 38 | except FileNotFoundError: 39 | print('Please install ffmpeg (not ffmpeg-python) by running\n', 40 | '\t$ conda install -c conda-forge ffmpeg') 41 | sys.exit(0) 42 | 43 | meta = get_video_meta_info(video_path) 44 | self.width = meta['width'] 45 | self.height = meta['height'] 46 | self.input_fps = meta['fps'] 47 | self.audio = meta['audio'] 48 | self.nb_frames = meta['nb_frames'] 49 | 50 | self.idx = 0 51 | 52 | def get_resolution(self): 53 | return self.height, self.width 54 | 55 | def get_fps(self): 56 | if self.input_fps is not None: 57 | return self.input_fps 58 | return 24 59 | 60 | def get_audio(self): 61 | return self.audio 62 | 63 | def __len__(self): 64 | return self.nb_frames 65 | 66 | def get_frame_from_stream(self): 67 | img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel 68 | if not img_bytes: 69 | return None 70 | img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) 71 | return img 72 | 73 | def get_frame_from_list(self): 74 | if self.idx >= self.nb_frames: 75 | return None 76 | img = cv2.imread(self.paths[self.idx]) 77 | self.idx += 1 78 | return img 79 | 80 | def 
get_frame(self): 81 | return self.get_frame_from_stream() 82 | 83 | 84 | def close(self): 85 | self.stream_reader.stdin.close() 86 | self.stream_reader.wait() 87 | 88 | 89 | class VideoWriter: 90 | def __init__(self, video_save_path, height, width, fps, audio): 91 | if height > 2160: 92 | print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', 93 | 'We highly recommend to decrease the outscale(aka, -s).') 94 | if audio is not None: 95 | self.stream_writer = ( 96 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 97 | framerate=fps).output( 98 | audio, 99 | video_save_path, 100 | pix_fmt='yuv420p', 101 | vcodec='libx264', 102 | loglevel='error', 103 | acodec='copy').overwrite_output().run_async( 104 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 105 | else: 106 | self.stream_writer = ( 107 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 108 | framerate=fps).output( 109 | video_save_path, pix_fmt='yuv420p', vcodec='libx264', 110 | loglevel='error').overwrite_output().run_async( 111 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 112 | 113 | def write_frame(self, frame): 114 | try: 115 | frame = frame.astype(np.uint8).tobytes() 116 | self.stream_writer.stdin.write(frame) 117 | except BrokenPipeError: 118 | print('Please re-install ffmpeg and libx264 by running\n', 119 | '\t$ conda install -c conda-forge ffmpeg\n', 120 | '\t$ conda install -c conda-forge x264') 121 | sys.exit(0) 122 | 123 | def close(self): 124 | self.stream_writer.stdin.close() 125 | self.stream_writer.wait() -------------------------------------------------------------------------------- /docs/history_changelog.md: -------------------------------------------------------------------------------- 1 | # History of Changelog 2 | 3 | - **2023.04.19**: :whale: Training codes and config files are public available now. 4 | - **2023.04.09**: Add features of inpainting and colorization for cropped face images. 5 | - **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity. 6 | - **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper: 7 | - **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) 8 | - **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) 9 | - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement. 10 | - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement. 11 | - **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. 12 | - **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`. 13 | - **2022.07.17**: Add Colab demo of CodeFormer. google colab logo 14 | - **2022.07.16**: Release inference code for face restoration. :blush: 15 | - **2022.06.21**: This repo is created. 
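`basicsr/utils/video_util.py` above defines the `VideoReader`/`VideoWriter` helpers but carries no usage note, so here is a minimal sketch (not part of the repository) of how they chain together. It assumes ffmpeg and `ffmpeg-python` are available and that `input.mp4`/`output.mp4` are hypothetical paths; a real pipeline would restore each frame where the placeholder comment sits.

```python
from basicsr.utils.video_util import VideoReader, VideoWriter

reader = VideoReader('input.mp4')                  # hypothetical input video
height, width = reader.get_resolution()
writer = VideoWriter('output.mp4', height, width,  # hypothetical output video
                     fps=reader.get_fps(), audio=reader.get_audio())

while True:
    frame = reader.get_frame()                     # BGR uint8 array of shape (H, W, 3), or None at end of stream
    if frame is None:
        break
    # ... enhance `frame` here before writing it back ...
    writer.write_frame(frame)

reader.close()
writer.close()
```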
-------------------------------------------------------------------------------- /docs/train.md: -------------------------------------------------------------------------------- 1 | # :milky_way: Training Procedures 2 | [English](train.md) **|** [简体中文](train_CN.md) 3 | ## Preparing Dataset 4 | 5 | - Download the training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset) 6 | 7 | --- 8 | 9 | ## Training 10 | ``` 11 | For PyTorch versions >= 1.10, please replace `python -m torch.distributed.launch` in the commands below with `torchrun`. 12 | ``` 13 | 14 | ### 👾 Stage I - VQGAN 15 | - Training the VQGAN: 16 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch 17 | 18 | - After training the VQGAN, you can pre-calculate the code sequence for the training dataset to speed up the later training stages: 19 | > python scripts/generate_latent_gt.py 20 | 21 | - If you don't need to train your own VQGAN, you can find the pre-trained VQGAN (`vqgan_code1024.pth`) and the corresponding code sequence (`latent_gt_code1024.pth`) under Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 22 | 23 | ### 🚀 Stage II - CodeFormer (w=0) 24 | - Training the Code Sequence Prediction Module: 25 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch 26 | 27 | - The pre-trained stage II CodeFormer (`codeformer_stage2.pth`) can be found under Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 28 | 29 | ### 🛸 Stage III - CodeFormer (w=1) 30 | - Training the Controllable Module: 31 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch 32 | 33 | - The pre-trained CodeFormer (`codeformer.pth`) can be found under Releases v0.1.0: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 34 | 35 | --- 36 | 37 | :whale: The project is built on the [BasicSR](https://github.com/XPixelGroup/BasicSR) framework. For detailed information on training, resuming, and other related topics, please refer to the documentation: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest.md 38 | -------------------------------------------------------------------------------- /docs/train_CN.md: -------------------------------------------------------------------------------- 1 | # :milky_way: 训练文档 2 | [English](train.md) **|** [简体中文](train_CN.md) 3 | 4 | ## 准备数据集 5 | - 下载训练数据集: [FFHQ](https://github.com/NVlabs/ffhq-dataset) 6 | 7 | --- 8 | 9 | ## 训练 10 | ``` 11 | 对于PyTorch版本 >= 1.10, 请将下面命令中的`python -m torch.distributed.launch`替换为`torchrun`.
12 | ``` 13 | 14 | ### 👾 阶段 I - VQGAN 15 | - 训练VQGAN: 16 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4321 basicsr/train.py -opt options/VQGAN_512_ds32_nearest_stage1.yml --launcher pytorch 17 | 18 | - 训练完VQGAN后,可以通过下面代码预先获得训练数据集的密码本序列,从而加速后面阶段的训练过程: 19 | > python scripts/generate_latent_gt.py 20 | 21 | - 如果你不需要训练自己的VQGAN,可以在Release v0.1.0文档中找到预训练的VQGAN (`vqgan_code1024.pth`)和对应的密码本序列 (`latent_gt_code1024.pth`): https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 22 | 23 | ### 🚀 阶段 II - CodeFormer (w=0) 24 | - 训练密码本训练预测模块: 25 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4322 basicsr/train.py -opt options/CodeFormer_stage2.yml --launcher pytorch 26 | 27 | - 预训练CodeFormer第二阶段模型 (`codeformer_stage2.pth`)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 28 | 29 | ### 🛸 阶段 III - CodeFormer (w=1) 30 | - 训练可调模块: 31 | > python -m torch.distributed.launch --nproc_per_node=gpu_num --master_port=4323 basicsr/train.py -opt options/CodeFormer_stage3.yml --launcher pytorch 32 | 33 | - 预训练CodeFormer模型 (`codeformer.pth`)可以在Releases v0.1.0文档里下载: https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0 34 | 35 | --- 36 | 37 | :whale: 该项目是基于[BasicSR](https://github.com/XPixelGroup/BasicSR)框架搭建,有关训练、Resume等详细介绍可以查看文档: https://github.com/XPixelGroup/BasicSR/blob/master/docs/TrainTest_CN.md -------------------------------------------------------------------------------- /facelib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from copy import deepcopy 5 | 6 | from facelib.utils import load_file_from_url 7 | from facelib.utils import download_pretrained_models 8 | from facelib.detection.yolov5face.models.common import Conv 9 | 10 | from .retinaface.retinaface import RetinaFace 11 | from .yolov5face.face_detector import YoloDetector 12 | 13 | 14 | def init_detection_model(model_name, half=False, device='cuda'): 15 | if 'retinaface' in model_name: 16 | model = init_retinaface_model(model_name, half, device) 17 | elif 'YOLOv5' in model_name: 18 | model = init_yolov5face_model(model_name, device) 19 | else: 20 | raise NotImplementedError(f'{model_name} is not implemented.') 21 | 22 | return model 23 | 24 | 25 | def init_retinaface_model(model_name, half=False, device='cuda'): 26 | if model_name == 'retinaface_resnet50': 27 | model = RetinaFace(network_name='resnet50', half=half) 28 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' 29 | elif model_name == 'retinaface_mobile0.25': 30 | model = RetinaFace(network_name='mobile0.25', half=half) 31 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' 32 | else: 33 | raise NotImplementedError(f'{model_name} is not implemented.') 34 | 35 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 36 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 37 | # remove unnecessary 'module.' 
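    # (weights saved through nn.DataParallel/DistributedDataParallel prefix every key with 'module.'; the loop below strips that prefix so the keys match the bare RetinaFace module)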
38 | for k, v in deepcopy(load_net).items(): 39 | if k.startswith('module.'): 40 | load_net[k[7:]] = v 41 | load_net.pop(k) 42 | model.load_state_dict(load_net, strict=True) 43 | model.eval() 44 | model = model.to(device) 45 | 46 | return model 47 | 48 | 49 | def init_yolov5face_model(model_name, device='cuda'): 50 | if model_name == 'YOLOv5l': 51 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 52 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' 53 | elif model_name == 'YOLOv5n': 54 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 55 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' 56 | else: 57 | raise NotImplementedError(f'{model_name} is not implemented.') 58 | 59 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 60 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 61 | model.detector.load_state_dict(load_net, strict=True) 62 | model.detector.eval() 63 | model.detector = model.detector.to(device).float() 64 | 65 | for m in model.detector.modules(): 66 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 67 | m.inplace = True # pytorch 1.7.0 compatibility 68 | elif isinstance(m, Conv): 69 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 70 | 71 | return model 72 | 73 | 74 | # Download from Google Drive 75 | # def init_yolov5face_model(model_name, device='cuda'): 76 | # if model_name == 'YOLOv5l': 77 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 78 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} 79 | # elif model_name == 'YOLOv5n': 80 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 81 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} 82 | # else: 83 | # raise NotImplementedError(f'{model_name} is not implemented.') 84 | 85 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) 86 | # if not os.path.exists(model_path): 87 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') 88 | 89 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 90 | # model.detector.load_state_dict(load_net, strict=True) 91 | # model.detector.eval() 92 | # model.detector = model.detector.to(device).float() 93 | 94 | # for m in model.detector.modules(): 95 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 96 | # m.inplace = True # pytorch 1.7.0 compatibility 97 | # elif isinstance(m, Conv): 98 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 99 | 100 | # return model -------------------------------------------------------------------------------- /facelib/detection/retinaface/retinaface_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), 9 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 10 | 11 | 12 | def conv_bn_no_relu(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | ) 17 | 18 | 
19 | def conv_bn1X1(inp, oup, stride, leaky=0): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), 22 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 23 | 24 | 25 | def conv_dw(inp, oup, stride, leaky=0.1): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 33 | ) 34 | 35 | 36 | class SSH(nn.Module): 37 | 38 | def __init__(self, in_channel, out_channel): 39 | super(SSH, self).__init__() 40 | assert out_channel % 4 == 0 41 | leaky = 0 42 | if (out_channel <= 64): 43 | leaky = 0.1 44 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 45 | 46 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 47 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 48 | 49 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 50 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 51 | 52 | def forward(self, input): 53 | conv3X3 = self.conv3X3(input) 54 | 55 | conv5X5_1 = self.conv5X5_1(input) 56 | conv5X5 = self.conv5X5_2(conv5X5_1) 57 | 58 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 59 | conv7X7 = self.conv7x7_3(conv7X7_2) 60 | 61 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class FPN(nn.Module): 67 | 68 | def __init__(self, in_channels_list, out_channels): 69 | super(FPN, self).__init__() 70 | leaky = 0 71 | if (out_channels <= 64): 72 | leaky = 0.1 73 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 74 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 75 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 76 | 77 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) 78 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 79 | 80 | def forward(self, input): 81 | # names = list(input.keys()) 82 | # input = list(input.values()) 83 | 84 | output1 = self.output1(input[0]) 85 | output2 = self.output2(input[1]) 86 | output3 = self.output3(input[2]) 87 | 88 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') 89 | output2 = output2 + up3 90 | output2 = self.merge2(output2) 91 | 92 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') 93 | output1 = output1 + up2 94 | output1 = self.merge1(output1) 95 | 96 | out = [output1, output2, output3] 97 | return out 98 | 99 | 100 | class MobileNetV1(nn.Module): 101 | 102 | def __init__(self): 103 | super(MobileNetV1, self).__init__() 104 | self.stage1 = nn.Sequential( 105 | conv_bn(3, 8, 2, leaky=0.1), # 3 106 | conv_dw(8, 16, 1), # 7 107 | conv_dw(16, 32, 2), # 11 108 | conv_dw(32, 32, 1), # 19 109 | conv_dw(32, 64, 2), # 27 110 | conv_dw(64, 64, 1), # 43 111 | ) 112 | self.stage2 = nn.Sequential( 113 | conv_dw(64, 128, 2), # 43 + 16 = 59 114 | conv_dw(128, 128, 1), # 59 + 32 = 91 115 | conv_dw(128, 128, 1), # 91 + 32 = 123 116 | conv_dw(128, 128, 1), # 123 + 32 = 155 117 | conv_dw(128, 128, 1), # 155 + 32 = 187 118 | conv_dw(128, 128, 1), # 187 + 32 = 219 119 | ) 120 | self.stage3 = nn.Sequential( 121 | conv_dw(128, 256, 2), # 219 +3 2 = 241 122 | conv_dw(256, 256, 1), # 241 + 
64 = 301 123 | ) 124 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.fc = nn.Linear(256, 1000) 126 | 127 | def forward(self, x): 128 | x = self.stage1(x) 129 | x = self.stage2(x) 130 | x = self.stage3(x) 131 | x = self.avg(x) 132 | # x = self.model(x) 133 | x = x.view(-1, 256) 134 | x = self.fc(x) 135 | return x 136 | 137 | 138 | class ClassHead(nn.Module): 139 | 140 | def __init__(self, inchannels=512, num_anchors=3): 141 | super(ClassHead, self).__init__() 142 | self.num_anchors = num_anchors 143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 144 | 145 | def forward(self, x): 146 | out = self.conv1x1(x) 147 | out = out.permute(0, 2, 3, 1).contiguous() 148 | 149 | return out.view(out.shape[0], -1, 2) 150 | 151 | 152 | class BboxHead(nn.Module): 153 | 154 | def __init__(self, inchannels=512, num_anchors=3): 155 | super(BboxHead, self).__init__() 156 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 157 | 158 | def forward(self, x): 159 | out = self.conv1x1(x) 160 | out = out.permute(0, 2, 3, 1).contiguous() 161 | 162 | return out.view(out.shape[0], -1, 4) 163 | 164 | 165 | class LandmarkHead(nn.Module): 166 | 167 | def __init__(self, inchannels=512, num_anchors=3): 168 | super(LandmarkHead, self).__init__() 169 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 170 | 171 | def forward(self, x): 172 | out = self.conv1x1(x) 173 | out = out.permute(0, 2, 3, 1).contiguous() 174 | 175 | return out.view(out.shape[0], -1, 10) 176 | 177 | 178 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): 179 | classhead = nn.ModuleList() 180 | for i in range(fpn_num): 181 | classhead.append(ClassHead(inchannels, anchor_num)) 182 | return classhead 183 | 184 | 185 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): 186 | bboxhead = nn.ModuleList() 187 | for i in range(fpn_num): 188 | bboxhead.append(BboxHead(inchannels, anchor_num)) 189 | return bboxhead 190 | 191 | 192 | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): 193 | landmarkhead = nn.ModuleList() 194 | for i in range(fpn_num): 195 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 196 | return landmarkhead 197 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/facelib/detection/yolov5face/__init__.py -------------------------------------------------------------------------------- /facelib/detection/yolov5face/face_detector.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import copy 3 | import re 4 | import torch 5 | import numpy as np 6 | 7 | from pathlib import Path 8 | from facelib.detection.yolov5face.models.yolo import Model 9 | from facelib.detection.yolov5face.utils.datasets import letterbox 10 | from facelib.detection.yolov5face.utils.general import ( 11 | check_img_size, 12 | non_max_suppression_face, 13 | scale_coords, 14 | scale_coords_landmarks, 15 | ) 16 | 17 | # IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9) 18 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 19 | torch.__version__)[0][:3])] >= [1, 9, 0] 20 | 21 | 22 | def 
isListempty(inList): 23 | if isinstance(inList, list): # Is a list 24 | return all(map(isListempty, inList)) 25 | return False # Not a list 26 | 27 | class YoloDetector: 28 | def __init__( 29 | self, 30 | config_name, 31 | min_face=10, 32 | target_size=None, 33 | device='cuda', 34 | ): 35 | """ 36 | config_name: name of .yaml config with network configuration from models/ folder. 37 | min_face : minimal face size in pixels. 38 | target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. 39 | None for original resolution. 40 | """ 41 | self._class_path = Path(__file__).parent.absolute() 42 | self.target_size = target_size 43 | self.min_face = min_face 44 | self.detector = Model(cfg=config_name) 45 | self.device = device 46 | 47 | 48 | def _preprocess(self, imgs): 49 | """ 50 | Preprocessing image before passing through the network. Resize and conversion to torch tensor. 51 | """ 52 | pp_imgs = [] 53 | for img in imgs: 54 | h0, w0 = img.shape[:2] # orig hw 55 | if self.target_size: 56 | r = self.target_size / min(h0, w0) # resize image to img_size 57 | if r < 1: 58 | img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) 59 | 60 | imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size 61 | img = letterbox(img, new_shape=imgsz)[0] 62 | pp_imgs.append(img) 63 | pp_imgs = np.array(pp_imgs) 64 | pp_imgs = pp_imgs.transpose(0, 3, 1, 2) 65 | pp_imgs = torch.from_numpy(pp_imgs).to(self.device) 66 | pp_imgs = pp_imgs.float() # uint8 to fp16/32 67 | return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 68 | 69 | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): 70 | """ 71 | Postprocessing of raw pytorch model output. 72 | Returns: 73 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 74 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 75 | """ 76 | bboxes = [[] for _ in range(len(origimgs))] 77 | landmarks = [[] for _ in range(len(origimgs))] 78 | 79 | pred = non_max_suppression_face(pred, conf_thres, iou_thres) 80 | 81 | for image_id, origimg in enumerate(origimgs): 82 | img_shape = origimg.shape 83 | image_height, image_width = img_shape[:2] 84 | gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh 85 | gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks 86 | det = pred[image_id].cpu() 87 | scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() 88 | scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() 89 | 90 | for j in range(det.size()[0]): 91 | box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() 92 | box = list( 93 | map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) 94 | ) 95 | if box[3] - box[1] < self.min_face: 96 | continue 97 | lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() 98 | lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) 99 | lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] 100 | bboxes[image_id].append(box) 101 | landmarks[image_id].append(lm) 102 | return bboxes, landmarks 103 | 104 | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): 105 | """ 106 | Get bbox coordinates and keypoints of faces on original image. 
107 | Params: 108 | imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) 109 | conf_thres: confidence threshold for each prediction 110 | iou_thres: threshold for NMS (filter of intersecting bboxes) 111 | Returns: 112 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 113 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 114 | """ 115 | # Pass input images through face detector 116 | images = imgs if isinstance(imgs, list) else [imgs] 117 | images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] 118 | origimgs = copy.deepcopy(images) 119 | 120 | images = self._preprocess(images) 121 | 122 | if IS_HIGH_VERSION: 123 | with torch.inference_mode(): # for pytorch>=1.9 124 | pred = self.detector(images)[0] 125 | else: 126 | with torch.no_grad(): # for pytorch<1.9 127 | pred = self.detector(images)[0] 128 | 129 | bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) 130 | 131 | # return bboxes, points 132 | if not isListempty(points): 133 | bboxes = np.array(bboxes).reshape(-1,4) 134 | points = np.array(points).reshape(-1,10) 135 | padding = bboxes[:,0].reshape(-1,1) 136 | return np.concatenate((bboxes, padding, points), axis=1) 137 | else: 138 | return None 139 | 140 | def __call__(self, *args): 141 | return self.predict(*args) 142 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/facelib/detection/yolov5face/models/__init__.py -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/experimental.py: -------------------------------------------------------------------------------- 1 | # # This file contains experimental modules 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from facelib.detection.yolov5face.models.common import Conv 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super().__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class MixConv2d(nn.Module): 25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 27 | super().__init__() 28 | groups = len(k) 29 | if equal_ch: # equal c_ per group 30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices 31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 32 | else: # equal weight.numel() per group 33 | b = [c2] + [0] * groups 34 | a = np.eye(groups + 1, groups, k=-1) 35 | a -= np.roll(a, 1, axis=1) 36 | a *= np.array(k) ** 2 37 | a[0] = 1 38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 39 | 40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 41 | self.bn = nn.BatchNorm2d(c2) 42 | self.act = nn.LeakyReLU(0.1, 
inplace=True) 43 | 44 | def forward(self, x): 45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5l.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 16 | [-1, 3, C3, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 18 | [-1, 9, C3, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 20 | [-1, 9, C3, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 22 | [-1, 1, SPP, [1024, [3,5,7]]], 23 | [-1, 3, C3, [1024, False]], # 8 24 | ] 25 | 26 | # YOLOv5 head 27 | head: 28 | [[-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4 31 | [-1, 3, C3, [512, False]], # 12 32 | 33 | [-1, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3 36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small) 37 | 38 | [-1, 1, Conv, [256, 3, 2]], 39 | [[-1, 13], 1, Concat, [1]], # cat head P4 40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium) 41 | 42 | [-1, 1, Conv, [512, 3, 2]], 43 | [[-1, 9], 1, Concat, [1]], # cat head P5 44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large) 45 | 46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 47 | ] -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5n.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2 18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4 20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6 22 | ] 23 | 24 | # YOLOv5 head 25 | head: 26 | [[-1, 1, Conv, [128, 1, 1]], 27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4 29 | [-1, 1, C3, [128, False]], # 10 30 | 31 | [-1, 1, Conv, [128, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3 34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small) 35 | 36 | [-1, 1, Conv, [128, 3, 2]], 37 | [[-1, 11], 1, Concat, [1]], # cat head P4 38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium) 39 | 40 | [-1, 1, Conv, [128, 3, 2]], 41 | [[-1, 7], 1, Concat, [1]], # cat head P5 42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large) 43 | 44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 45 | ] 46 | -------------------------------------------------------------------------------- 
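The YOLOv5 face-detection pieces above (the `YoloDetector` wrapper and the `yolov5l.yaml`/`yolov5n.yaml` model configs) are normally reached through `init_detection_model` from `facelib/detection/__init__.py`. A minimal sketch of that path, assuming it is run from the repository root (the `.yaml` config paths are relative), a CUDA device is present, and the pretrained weights can be downloaded into `weights/facelib`:

```python
import cv2
from facelib.detection import init_detection_model

# Downloads yolov5n-face.pth on first use and loads it into the YoloDetector wrapper.
detector = init_detection_model('YOLOv5n', device='cuda')

# detect_faces expects BGR images (as returned by cv2.imread) and converts to RGB internally.
img = cv2.imread('inputs/whole_imgs/00.jpg')
dets = detector.detect_faces(img, conf_thres=0.7, iou_thres=0.5)

# dets is None when no face passes the thresholds; otherwise each row is
# [x1, y1, x2, y2, <copy of x1>, lm1_x, lm1_y, ..., lm5_x, lm5_y] for one detected face.
```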
/facelib/detection/yolov5face/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/facelib/detection/yolov5face/utils/__init__.py -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | 4 | def check_anchor_order(m): 5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 7 | da = a[-1] - a[0] # delta a 8 | ds = m.stride[-1] - m.stride[0] # delta s 9 | if da.sign() != ds.sign(): # same order 10 | print("Reversing anchor order") 11 | m.anchors[:] = m.anchors.flip(0) 12 | m.anchor_grid[:] = m.anchor_grid.flip(0) 13 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): 6 | # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 7 | shape = img.shape[:2] # current shape [height, width] 8 | if isinstance(new_shape, int): 9 | new_shape = (new_shape, new_shape) 10 | 11 | # Scale ratio (new / old) 12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 13 | if not scaleup: # only scale down, do not scale up (for better test mAP) 14 | r = min(r, 1.0) 15 | 16 | # Compute padding 17 | ratio = r, r # width, height ratios 18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 20 | if auto: # minimum rectangle 21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding 22 | elif scale_fill: # stretch 23 | dw, dh = 0.0, 0.0 24 | new_unpad = (new_shape[1], new_shape[0]) 25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 26 | 27 | dw /= 2 # divide padding into 2 sides 28 | dh /= 2 29 | 30 | if shape[::-1] != new_unpad: # resize 31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 35 | return img, ratio, (dw, dh) 36 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/extract_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.insert(0,'./facelib/detection/yolov5face') 4 | model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] 5 | torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fuse_conv_and_bn(conv, bn): 6 | # Fuse convolution and batchnorm layers 
https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 7 | fusedconv = ( 8 | nn.Conv2d( 9 | conv.in_channels, 10 | conv.out_channels, 11 | kernel_size=conv.kernel_size, 12 | stride=conv.stride, 13 | padding=conv.padding, 14 | groups=conv.groups, 15 | bias=True, 16 | ) 17 | .requires_grad_(False) 18 | .to(conv.weight.device) 19 | ) 20 | 21 | # prepare filters 22 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 25 | 26 | # prepare spatial bias 27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 30 | 31 | return fusedconv 32 | 33 | 34 | def copy_attr(a, b, include=(), exclude=()): 35 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 36 | for k, v in b.__dict__.items(): 37 | if (include and k not in include) or k.startswith("_") or k in exclude: 38 | continue 39 | 40 | setattr(a, k, v) 41 | -------------------------------------------------------------------------------- /facelib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facelib.utils import load_file_from_url 4 | from .bisenet import BiSeNet 5 | from .parsenet import ParseNet 6 | 7 | 8 | def init_parsing_model(model_name='bisenet', half=False, device='cuda'): 9 | if model_name == 'bisenet': 10 | model = BiSeNet(num_class=19) 11 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' 12 | elif model_name == 'parsenet': 13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 14 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' 15 | else: 16 | raise NotImplementedError(f'{model_name} is not implemented.') 17 | 18 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 19 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 20 | model.load_state_dict(load_net, strict=True) 21 | model.eval() 22 | model = model.to(device) 23 | return model 24 | -------------------------------------------------------------------------------- /facelib/parsing/bisenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnet import ResNet18 6 | 7 | 8 | class ConvBNReLU(nn.Module): 9 | 10 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): 11 | super(ConvBNReLU, self).__init__() 12 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 13 | self.bn = nn.BatchNorm2d(out_chan) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = F.relu(self.bn(x)) 18 | return x 19 | 20 | 21 | class BiSeNetOutput(nn.Module): 22 | 23 | def __init__(self, in_chan, mid_chan, num_class): 24 | super(BiSeNetOutput, self).__init__() 25 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 26 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) 27 | 28 | def forward(self, x): 29 | feat = self.conv(x) 30 | out = self.conv_out(feat) 31 | return out, feat 32 | 33 | 34 | 
class AttentionRefinementModule(nn.Module): 35 | 36 | def __init__(self, in_chan, out_chan): 37 | super(AttentionRefinementModule, self).__init__() 38 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 39 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 40 | self.bn_atten = nn.BatchNorm2d(out_chan) 41 | self.sigmoid_atten = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | feat = self.conv(x) 45 | atten = F.avg_pool2d(feat, feat.size()[2:]) 46 | atten = self.conv_atten(atten) 47 | atten = self.bn_atten(atten) 48 | atten = self.sigmoid_atten(atten) 49 | out = torch.mul(feat, atten) 50 | return out 51 | 52 | 53 | class ContextPath(nn.Module): 54 | 55 | def __init__(self): 56 | super(ContextPath, self).__init__() 57 | self.resnet = ResNet18() 58 | self.arm16 = AttentionRefinementModule(256, 128) 59 | self.arm32 = AttentionRefinementModule(512, 128) 60 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 61 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 62 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 63 | 64 | def forward(self, x): 65 | feat8, feat16, feat32 = self.resnet(x) 66 | h8, w8 = feat8.size()[2:] 67 | h16, w16 = feat16.size()[2:] 68 | h32, w32 = feat32.size()[2:] 69 | 70 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 71 | avg = self.conv_avg(avg) 72 | avg_up = F.interpolate(avg, (h32, w32), mode='nearest') 73 | 74 | feat32_arm = self.arm32(feat32) 75 | feat32_sum = feat32_arm + avg_up 76 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') 77 | feat32_up = self.conv_head32(feat32_up) 78 | 79 | feat16_arm = self.arm16(feat16) 80 | feat16_sum = feat16_arm + feat32_up 81 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') 82 | feat16_up = self.conv_head16(feat16_up) 83 | 84 | return feat8, feat16_up, feat32_up # x8, x8, x16 85 | 86 | 87 | class FeatureFusionModule(nn.Module): 88 | 89 | def __init__(self, in_chan, out_chan): 90 | super(FeatureFusionModule, self).__init__() 91 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 92 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 93 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.sigmoid = nn.Sigmoid() 96 | 97 | def forward(self, fsp, fcp): 98 | fcat = torch.cat([fsp, fcp], dim=1) 99 | feat = self.convblk(fcat) 100 | atten = F.avg_pool2d(feat, feat.size()[2:]) 101 | atten = self.conv1(atten) 102 | atten = self.relu(atten) 103 | atten = self.conv2(atten) 104 | atten = self.sigmoid(atten) 105 | feat_atten = torch.mul(feat, atten) 106 | feat_out = feat_atten + feat 107 | return feat_out 108 | 109 | 110 | class BiSeNet(nn.Module): 111 | 112 | def __init__(self, num_class): 113 | super(BiSeNet, self).__init__() 114 | self.cp = ContextPath() 115 | self.ffm = FeatureFusionModule(256, 256) 116 | self.conv_out = BiSeNetOutput(256, 256, num_class) 117 | self.conv_out16 = BiSeNetOutput(128, 64, num_class) 118 | self.conv_out32 = BiSeNetOutput(128, 64, num_class) 119 | 120 | def forward(self, x, return_feat=False): 121 | h, w = x.size()[2:] 122 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature 123 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature 124 | feat_fuse = self.ffm(feat_sp, feat_cp8) 125 | 126 | out, feat = self.conv_out(feat_fuse) 127 | out16, feat16 = self.conv_out16(feat_cp8) 128 | out32, feat32 = 
self.conv_out32(feat_cp16) 129 | 130 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 131 | out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) 132 | out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) 133 | 134 | if return_feat: 135 | feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) 136 | feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) 137 | feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) 138 | return out, out16, out32, feat, feat16, feat32 139 | else: 140 | return out, out16, out32 141 | -------------------------------------------------------------------------------- /facelib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_chan, out_chan, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_chan, out_chan, stride) 15 | self.bn1 = nn.BatchNorm2d(out_chan) 16 | self.conv2 = conv3x3(out_chan, out_chan) 17 | self.bn2 = nn.BatchNorm2d(out_chan) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = None 20 | if in_chan != out_chan or stride != 1: 21 | self.downsample = nn.Sequential( 22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_chan), 24 | ) 25 | 26 | def forward(self, x): 27 | residual = self.conv1(x) 28 | residual = F.relu(self.bn1(residual)) 29 | residual = self.conv2(residual) 30 | residual = self.bn2(residual) 31 | 32 | shortcut = x 33 | if self.downsample is not None: 34 | shortcut = self.downsample(x) 35 | 36 | out = shortcut + residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 43 | for i in range(bnum - 1): 44 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class ResNet18(nn.Module): 49 | 50 | def __init__(self): 51 | super(ResNet18, self).__init__() 52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.relu(self.bn1(x)) 63 | x = self.maxpool(x) 64 | 65 | x = self.layer1(x) 66 | feat8 = self.layer2(x) # 1/8 67 | feat16 = self.layer3(feat8) # 1/16 68 | feat32 = self.layer4(feat16) # 1/32 69 | return feat8, feat16, feat32 70 | -------------------------------------------------------------------------------- /facelib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir 3 | 4 | __all__ = [ 5 | 
'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 6 | 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' 7 | ] 8 | -------------------------------------------------------------------------------- /inference_colorization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import glob 5 | import torch 6 | from torchvision.transforms.functional import normalize 7 | from basicsr.utils import imwrite, img2tensor, tensor2img 8 | from basicsr.utils.download_util import load_file_from_url 9 | from basicsr.utils.misc import get_device 10 | from basicsr.utils.registry import ARCH_REGISTRY 11 | 12 | pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_colorization.pth' 13 | 14 | if __name__ == '__main__': 15 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | device = get_device() 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('-i', '--input_path', type=str, default='./inputs/gray_faces', 20 | help='Input image or folder. Default: inputs/gray_faces') 21 | parser.add_argument('-o', '--output_path', type=str, default=None, 22 | help='Output folder. Default: results/') 23 | parser.add_argument('--suffix', type=str, default=None, 24 | help='Suffix of the restored faces. Default: None') 25 | args = parser.parse_args() 26 | 27 | # ------------------------ input & output ------------------------ 28 | print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.') 29 | if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path 30 | input_img_list = [args.input_path] 31 | result_root = f'results/test_colorization_img' 32 | else: # input img folder 33 | if args.input_path.endswith('/'): # solve when path ends with / 34 | args.input_path = args.input_path[:-1] 35 | # scan all the jpg and png images 36 | input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]'))) 37 | result_root = f'results/{os.path.basename(args.input_path)}' 38 | 39 | if not args.output_path is None: # set output path 40 | result_root = args.output_path 41 | 42 | test_img_num = len(input_img_list) 43 | 44 | # ------------------ set up CodeFormer restorer ------------------- 45 | net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 46 | connect_list=['32', '64', '128']).to(device) 47 | 48 | # ckpt_path = 'weights/CodeFormer/codeformer.pth' 49 | ckpt_path = load_file_from_url(url=pretrain_model_url, 50 | model_dir='weights/CodeFormer', progress=True, file_name=None) 51 | checkpoint = torch.load(ckpt_path)['params_ema'] 52 | net.load_state_dict(checkpoint) 53 | net.eval() 54 | 55 | # -------------------- start to processing --------------------- 56 | for i, img_path in enumerate(input_img_list): 57 | img_name = os.path.basename(img_path) 58 | basename, ext = os.path.splitext(img_name) 59 | print(f'[{i+1}/{test_img_num}] Processing: {img_name}') 60 | input_face = cv2.imread(img_path) 61 | assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for colorization.' 
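        # The preprocessing below maps the BGR uint8 face to an RGB float tensor normalized to [-1, 1] (the range that tensor2img later inverts with min_max=(-1, 1)) before batching it onto the device.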
62 | # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR) 63 | input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True) 64 | normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 65 | input_face = input_face.unsqueeze(0).to(device) 66 | try: 67 | with torch.no_grad(): 68 | # w is fixed to 0 since we didn't train the Stage III for colorization 69 | output_face = net(input_face, w=0, adain=True)[0] 70 | save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1)) 71 | del output_face 72 | torch.cuda.empty_cache() 73 | except Exception as error: 74 | print(f'\tFailed inference for CodeFormer: {error}') 75 | save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1)) 76 | 77 | save_face = save_face.astype('uint8') 78 | 79 | # save face 80 | if args.suffix is not None: 81 | basename = f'{basename}_{args.suffix}' 82 | save_restore_path = os.path.join(result_root, f'{basename}.png') 83 | imwrite(save_face, save_restore_path) 84 | 85 | print(f'\nAll results are saved in {result_root}') 86 | 87 | -------------------------------------------------------------------------------- /inference_inpainting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import glob 5 | import torch 6 | from torchvision.transforms.functional import normalize 7 | from basicsr.utils import imwrite, img2tensor, tensor2img 8 | from basicsr.utils.download_util import load_file_from_url 9 | from basicsr.utils.misc import get_device 10 | from basicsr.utils.registry import ARCH_REGISTRY 11 | 12 | pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_inpainting.pth' 13 | 14 | if __name__ == '__main__': 15 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | device = get_device() 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('-i', '--input_path', type=str, default='./inputs/masked_faces', 20 | help='Input image or folder. Default: inputs/masked_faces') 21 | parser.add_argument('-o', '--output_path', type=str, default=None, 22 | help='Output folder. Default: results/') 23 | parser.add_argument('--suffix', type=str, default=None, 24 | help='Suffix of the restored faces. 
Default: None') 25 | args = parser.parse_args() 26 | 27 | # ------------------------ input & output ------------------------ 28 | print('[NOTE] The input face images should be aligned and cropped to a resolution of 512x512.') 29 | if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path 30 | input_img_list = [args.input_path] 31 | result_root = f'results/test_inpainting_img' 32 | else: # input img folder 33 | if args.input_path.endswith('/'): # solve when path ends with / 34 | args.input_path = args.input_path[:-1] 35 | # scan all the jpg and png images 36 | input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]'))) 37 | result_root = f'results/{os.path.basename(args.input_path)}' 38 | 39 | if not args.output_path is None: # set output path 40 | result_root = args.output_path 41 | 42 | test_img_num = len(input_img_list) 43 | 44 | # ------------------ set up CodeFormer restorer ------------------- 45 | net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=512, n_head=8, n_layers=9, 46 | connect_list=['32', '64', '128']).to(device) 47 | 48 | # ckpt_path = 'weights/CodeFormer/codeformer.pth' 49 | ckpt_path = load_file_from_url(url=pretrain_model_url, 50 | model_dir='weights/CodeFormer', progress=True, file_name=None) 51 | checkpoint = torch.load(ckpt_path)['params_ema'] 52 | net.load_state_dict(checkpoint) 53 | net.eval() 54 | 55 | # -------------------- start to processing --------------------- 56 | for i, img_path in enumerate(input_img_list): 57 | img_name = os.path.basename(img_path) 58 | basename, ext = os.path.splitext(img_name) 59 | print(f'[{i+1}/{test_img_num}] Processing: {img_name}') 60 | input_face = cv2.imread(img_path) 61 | assert input_face.shape[:2] == (512, 512), 'Input resolution must be 512x512 for inpainting.' 
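        # The preprocessing below mirrors the colorization script: BGR uint8 -> RGB float tensor normalized to [-1, 1]. Masked regions are assumed to be pure white, so after normalization they become (1, 1, 1) and the channel-sum == 3 check further down recovers the inpainting mask.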
62 | # input_face = cv2.resize(input_face, (512, 512), interpolation=cv2.INTER_LINEAR) 63 | input_face = img2tensor(input_face / 255., bgr2rgb=True, float32=True) 64 | normalize(input_face, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 65 | input_face = input_face.unsqueeze(0).to(device) 66 | try: 67 | with torch.no_grad(): 68 | mask = torch.zeros(512, 512) 69 | m_ind = torch.sum(input_face[0], dim=0) 70 | mask[m_ind==3] = 1.0 71 | mask = mask.view(1, 1, 512, 512).to(device) 72 | # w is fixed to 1, adain=False for inpainting 73 | output_face = net(input_face, w=1, adain=False)[0] 74 | output_face = (1-mask)*input_face + mask*output_face 75 | save_face = tensor2img(output_face, rgb2bgr=True, min_max=(-1, 1)) 76 | del output_face 77 | torch.cuda.empty_cache() 78 | except Exception as error: 79 | print(f'\tFailed inference for CodeFormer: {error}') 80 | save_face = tensor2img(input_face, rgb2bgr=True, min_max=(-1, 1)) 81 | 82 | save_face = save_face.astype('uint8') 83 | 84 | # save face 85 | if args.suffix is not None: 86 | basename = f'{basename}_{args.suffix}' 87 | save_restore_path = os.path.join(result_root, f'{basename}.png') 88 | imwrite(save_face, save_restore_path) 89 | 90 | print(f'\nAll results are saved in {result_root}') 91 | 92 | -------------------------------------------------------------------------------- /inputs/cropped_faces/0143.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0143.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0240.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0342.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0342.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0345.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0345.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0368.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0368.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0412.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0412.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0444.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0444.png 
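A note on the mask construction in inference_inpainting.py above: after normalize(..., (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), a pure-white pixel becomes (1, 1, 1), so its channel sum is exactly 3. Such pixels are treated as the hole to inpaint, and the network output is pasted back only inside that region while the original pixels are kept elsewhere. The following is a minimal sketch of the same idea, assuming (as with the bundled inputs/masked_faces images) that the occluded region is pure white; the to_inpaint tensor below is a synthetic stand-in, not a real input.

import torch

# Synthetic stand-in for a preprocessed 512x512 face: values in [-1, 1]
# (i.e. already normalized with mean 0.5 and std 0.5), shape (1, 3, 512, 512).
to_inpaint = torch.rand(1, 3, 512, 512) * 2 - 1
to_inpaint[:, :, 100:200, 150:300] = 1.0  # simulate a pure-white masked region

# A pixel counts as masked iff all three channels equal 1.0, i.e. they sum to 3.
channel_sum = to_inpaint[0].sum(dim=0)                  # shape (512, 512)
mask = (channel_sum == 3).float().view(1, 1, 512, 512)

# Blend: keep the original pixels outside the hole, the model output inside it.
# model_output is a placeholder for net(to_inpaint, w=1, adain=False)[0].
model_output = torch.zeros_like(to_inpaint)
blended = (1 - mask) * to_inpaint + mask * model_output

print(int(mask.sum()))  # 100 * 150 = 15000 masked pixels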
-------------------------------------------------------------------------------- /inputs/cropped_faces/0478.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0478.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0500.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0599.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0599.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0717.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0717.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0720.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0720.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0729.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0729.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0763.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0763.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0770.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0770.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0777.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0777.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0885.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0885.png -------------------------------------------------------------------------------- /inputs/cropped_faces/0934.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/0934.png 
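The inference scripts above (and the VQGAN scripts further down) all share one tensor convention: images are read by OpenCV as BGR uint8, converted to RGB float tensors in [0, 1] with img2tensor, normalized to [-1, 1] with mean/std 0.5, and written back with tensor2img(..., rgb2bgr=True, min_max=(-1, 1)). A minimal round-trip sketch of that convention, using a random array as a hypothetical stand-in for a cv2.imread result:

import numpy as np
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor, tensor2img

# Hypothetical stand-in for cv2.imread(path): a 512x512 BGR uint8 image.
bgr = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)

# Forward: BGR uint8 -> RGB float tensor in [0, 1] -> normalized to [-1, 1].
tensor = img2tensor(bgr / 255., bgr2rgb=True, float32=True)
normalize(tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
tensor = tensor.unsqueeze(0)  # add a batch dimension: (1, 3, 512, 512)

# Backward: the same mapping the scripts use when saving results.
restored = tensor2img(tensor, rgb2bgr=True, min_max=(-1, 1)).astype('uint8')

# Up to uint8 rounding, the round trip reproduces the original image.
assert np.abs(restored.astype(np.int16) - bgr.astype(np.int16)).max() <= 1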
-------------------------------------------------------------------------------- /inputs/cropped_faces/Solvay_conference_1927_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/Solvay_conference_1927_0018.png -------------------------------------------------------------------------------- /inputs/cropped_faces/Solvay_conference_1927_2_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/cropped_faces/Solvay_conference_1927_2_16.png -------------------------------------------------------------------------------- /inputs/gray_faces/067_David_Beckham_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/067_David_Beckham_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/089_Miley_Cyrus_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/089_Miley_Cyrus_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/099_Victoria_Beckham_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/099_Victoria_Beckham_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/111_Alexa_Chung_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/111_Alexa_Chung_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/132_Robert_Downey_Jr_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/132_Robert_Downey_Jr_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/158_Jimmy_Fallon_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/158_Jimmy_Fallon_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/161_Zac_Efron_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/161_Zac_Efron_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/169_John_Lennon_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/169_John_Lennon_00.png -------------------------------------------------------------------------------- 
/inputs/gray_faces/170_Marilyn_Monroe_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/170_Marilyn_Monroe_00.png -------------------------------------------------------------------------------- /inputs/gray_faces/Einstein01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/Einstein01.png -------------------------------------------------------------------------------- /inputs/gray_faces/Einstein02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/Einstein02.png -------------------------------------------------------------------------------- /inputs/gray_faces/Hepburn01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/Hepburn01.png -------------------------------------------------------------------------------- /inputs/gray_faces/Hepburn02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/gray_faces/Hepburn02.png -------------------------------------------------------------------------------- /inputs/masked_faces/00105.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/masked_faces/00105.png -------------------------------------------------------------------------------- /inputs/masked_faces/00108.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/masked_faces/00108.png -------------------------------------------------------------------------------- /inputs/masked_faces/00169.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/masked_faces/00169.png -------------------------------------------------------------------------------- /inputs/masked_faces/00588.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/masked_faces/00588.png -------------------------------------------------------------------------------- /inputs/masked_faces/00664.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/masked_faces/00664.png -------------------------------------------------------------------------------- /inputs/whole_imgs/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/00.jpg -------------------------------------------------------------------------------- /inputs/whole_imgs/01.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/01.jpg -------------------------------------------------------------------------------- /inputs/whole_imgs/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/02.png -------------------------------------------------------------------------------- /inputs/whole_imgs/03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/03.jpg -------------------------------------------------------------------------------- /inputs/whole_imgs/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/04.jpg -------------------------------------------------------------------------------- /inputs/whole_imgs/05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/05.jpg -------------------------------------------------------------------------------- /inputs/whole_imgs/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/inputs/whole_imgs/06.png -------------------------------------------------------------------------------- /options/CodeFormer_colorization.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CodeFormer_colorization 3 | model_type: CodeFormerIdxModel 4 | num_gpu: 8 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQBlindDataset 12 | dataroot_gt: datasets/ffhq/ffhq_512 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: true 23 | 24 | # large degradation in stageII 25 | blur_kernel_size: 41 26 | use_motion_kernel: false 27 | motion_kernel_prob: 0.001 28 | kernel_list: ['iso', 'aniso'] 29 | kernel_prob: [0.5, 0.5] 30 | blur_sigma: [1, 15] 31 | downsample_range: [4, 30] 32 | noise_range: [0, 20] 33 | jpeg_range: [30, 80] 34 | 35 | # color jitter and gray 36 | color_jitter_prob: 0.3 37 | color_jitter_shift: 20 38 | color_jitter_pt_prob: 0.3 39 | gray_prob: 0.01 40 | 41 | latent_gt_path: ~ # without pre-calculated latent code 42 | # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' 43 | 44 | # data loader 45 | num_worker_per_gpu: 2 46 | batch_size_per_gpu: 4 47 | dataset_enlarge_ratio: 100 48 | prefetch_mode: ~ 49 | 50 | # val: 51 | # name: CelebA-HQ-512 52 | # type: PairedImageDataset 53 | # dataroot_lq: datasets/faces/validation/lq 54 | # dataroot_gt: datasets/faces/validation/gt 55 | # io_backend: 56 | # type: disk 57 | # mean: [0.5, 0.5, 0.5] 58 | # std: [0.5, 0.5, 0.5] 59 | # scale: 1 60 | 61 | # network structures 62 | network_g: 63 | type: CodeFormer 64 | dim_embd: 512 65 | n_head: 8 66 | n_layers: 9 67 | 
codebook_size: 1024 68 | connect_list: ['32', '64', '128', '256'] 69 | fix_modules: ['quantize','generator'] 70 | vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN 71 | 72 | network_vqgan: # this config is needed if no pre-calculated latent 73 | type: VQAutoEncoder 74 | img_size: 512 75 | nf: 64 76 | ch_mult: [1, 2, 2, 4, 4, 8] 77 | quantizer: 'nearest' 78 | codebook_size: 1024 79 | 80 | # path 81 | path: 82 | pretrain_network_g: ~ 83 | param_key_g: params_ema 84 | strict_load_g: false 85 | pretrain_network_d: ~ 86 | strict_load_d: true 87 | resume_state: ~ 88 | 89 | # base_lr(4.5e-6)*bach_size(4) 90 | train: 91 | use_hq_feat_loss: true 92 | feat_loss_weight: 1.0 93 | cross_entropy_loss: true 94 | entropy_loss_weight: 0.5 95 | fidelity_weight: 0 96 | 97 | optim_g: 98 | type: Adam 99 | lr: !!float 1e-4 100 | weight_decay: 0 101 | betas: [0.9, 0.99] 102 | 103 | scheduler: 104 | type: MultiStepLR 105 | milestones: [400000, 450000] 106 | gamma: 0.5 107 | 108 | total_iter: 500000 109 | 110 | warmup_iter: -1 # no warm up 111 | ema_decay: 0.995 112 | 113 | use_adaptive_weight: true 114 | 115 | net_g_start_iter: 0 116 | net_d_iters: 1 117 | net_d_start_iter: 0 118 | manual_seed: 0 119 | 120 | # validation settings 121 | val: 122 | val_freq: !!float 5e10 # no validation 123 | save_img: true 124 | 125 | metrics: 126 | psnr: # metric name, can be arbitrary 127 | type: calculate_psnr 128 | crop_border: 4 129 | test_y_channel: false 130 | 131 | # logging settings 132 | logger: 133 | print_freq: 100 134 | save_checkpoint_freq: !!float 1e4 135 | use_tb_logger: true 136 | wandb: 137 | project: ~ 138 | resume_id: ~ 139 | 140 | # dist training settings 141 | dist_params: 142 | backend: nccl 143 | port: 29419 144 | 145 | find_unused_parameters: true 146 | -------------------------------------------------------------------------------- /options/CodeFormer_inpainting.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CodeFormer_inpainting 3 | model_type: CodeFormerModel 4 | num_gpu: 4 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQBlindDataset 12 | dataroot_gt: datasets/ffhq/ffhq_512 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: false 23 | gen_inpaint_mask: true 24 | 25 | latent_gt_path: ~ # without pre-calculated latent code 26 | # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' 27 | 28 | # data loader 29 | num_worker_per_gpu: 2 30 | batch_size_per_gpu: 3 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | # val: 35 | # name: CelebA-HQ-512 36 | # type: PairedImageDataset 37 | # dataroot_lq: datasets/faces/validation/lq 38 | # dataroot_gt: datasets/faces/validation/gt 39 | # io_backend: 40 | # type: disk 41 | # mean: [0.5, 0.5, 0.5] 42 | # std: [0.5, 0.5, 0.5] 43 | # scale: 1 44 | 45 | # network structures 46 | network_g: 47 | type: CodeFormer 48 | dim_embd: 512 49 | n_head: 8 50 | n_layers: 9 51 | codebook_size: 1024 52 | connect_list: ['32', '64', '128'] 53 | fix_modules: ['quantize','generator'] 54 | vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN 55 | 56 | network_vqgan: # this config is needed if no pre-calculated latent 57 | type: VQAutoEncoder 58 | img_size: 512 59 | nf: 64 60 | ch_mult: [1, 2, 2, 4, 
4, 8] 61 | quantizer: 'nearest' 62 | codebook_size: 1024 63 | 64 | network_d: 65 | type: VQGANDiscriminator 66 | nc: 3 67 | ndf: 64 68 | n_layers: 4 69 | model_path: ~ 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | param_key_g: params_ema 75 | strict_load_g: true 76 | pretrain_network_d: ~ 77 | strict_load_d: true 78 | resume_state: ~ 79 | 80 | # base_lr(4.5e-6)*bach_size(4) 81 | train: 82 | use_hq_feat_loss: true 83 | feat_loss_weight: 1.0 84 | cross_entropy_loss: true 85 | entropy_loss_weight: 0.5 86 | scale_adaptive_gan_weight: 0.1 87 | fidelity_weight: 1.0 88 | 89 | optim_g: 90 | type: Adam 91 | lr: !!float 7e-5 92 | weight_decay: 0 93 | betas: [0.9, 0.99] 94 | optim_d: 95 | type: Adam 96 | lr: !!float 7e-5 97 | weight_decay: 0 98 | betas: [0.9, 0.99] 99 | 100 | scheduler: 101 | type: MultiStepLR 102 | milestones: [250000, 300000] 103 | gamma: 0.5 104 | 105 | total_iter: 300000 106 | 107 | warmup_iter: -1 # no warm up 108 | ema_decay: 0.997 109 | 110 | pixel_opt: 111 | type: L1Loss 112 | loss_weight: 1.0 113 | reduction: mean 114 | 115 | perceptual_opt: 116 | type: LPIPSLoss 117 | loss_weight: 1.0 118 | use_input_norm: true 119 | range_norm: true 120 | 121 | gan_opt: 122 | type: GANLoss 123 | gan_type: hinge 124 | loss_weight: !!float 1.0 # adaptive_weighting 125 | 126 | 127 | use_adaptive_weight: true 128 | 129 | net_g_start_iter: 0 130 | net_d_iters: 1 131 | net_d_start_iter: 296001 132 | manual_seed: 0 133 | 134 | # validation settings 135 | val: 136 | val_freq: !!float 5e10 # no validation 137 | save_img: true 138 | 139 | metrics: 140 | psnr: # metric name, can be arbitrary 141 | type: calculate_psnr 142 | crop_border: 4 143 | test_y_channel: false 144 | 145 | # logging settings 146 | logger: 147 | print_freq: 100 148 | save_checkpoint_freq: !!float 1e4 149 | use_tb_logger: true 150 | wandb: 151 | project: ~ 152 | resume_id: ~ 153 | 154 | # dist training settings 155 | dist_params: 156 | backend: nccl 157 | port: 29420 158 | 159 | find_unused_parameters: true 160 | -------------------------------------------------------------------------------- /options/CodeFormer_stage2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CodeFormer_stage2 3 | model_type: CodeFormerIdxModel 4 | num_gpu: 8 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQBlindDataset 12 | dataroot_gt: datasets/ffhq/ffhq_512 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: true 23 | 24 | # large degradation in stageII 25 | blur_kernel_size: 41 26 | use_motion_kernel: false 27 | motion_kernel_prob: 0.001 28 | kernel_list: ['iso', 'aniso'] 29 | kernel_prob: [0.5, 0.5] 30 | blur_sigma: [1, 15] 31 | downsample_range: [4, 30] 32 | noise_range: [0, 20] 33 | jpeg_range: [30, 80] 34 | 35 | latent_gt_path: ~ # without pre-calculated latent code 36 | # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' 37 | 38 | # data loader 39 | num_worker_per_gpu: 2 40 | batch_size_per_gpu: 4 41 | dataset_enlarge_ratio: 100 42 | prefetch_mode: ~ 43 | 44 | # val: 45 | # name: CelebA-HQ-512 46 | # type: PairedImageDataset 47 | # dataroot_lq: datasets/faces/validation/lq 48 | # dataroot_gt: datasets/faces/validation/gt 49 | # io_backend: 50 | # type: disk 51 | # mean: [0.5, 0.5, 0.5] 52 | # std: [0.5, 0.5, 0.5] 53 | # scale: 1 54 
| 55 | # network structures 56 | network_g: 57 | type: CodeFormer 58 | dim_embd: 512 59 | n_head: 8 60 | n_layers: 9 61 | codebook_size: 1024 62 | connect_list: ['32', '64', '128', '256'] 63 | fix_modules: ['quantize','generator'] 64 | vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN 65 | 66 | network_vqgan: # this config is needed if no pre-calculated latent 67 | type: VQAutoEncoder 68 | img_size: 512 69 | nf: 64 70 | ch_mult: [1, 2, 2, 4, 4, 8] 71 | quantizer: 'nearest' 72 | codebook_size: 1024 73 | 74 | # path 75 | path: 76 | pretrain_network_g: ~ 77 | param_key_g: params_ema 78 | strict_load_g: false 79 | pretrain_network_d: ~ 80 | strict_load_d: true 81 | resume_state: ~ 82 | 83 | # base_lr(4.5e-6)*bach_size(4) 84 | train: 85 | use_hq_feat_loss: true 86 | feat_loss_weight: 1.0 87 | cross_entropy_loss: true 88 | entropy_loss_weight: 0.5 89 | fidelity_weight: 0 90 | 91 | optim_g: 92 | type: Adam 93 | lr: !!float 1e-4 94 | weight_decay: 0 95 | betas: [0.9, 0.99] 96 | 97 | scheduler: 98 | type: MultiStepLR 99 | milestones: [400000, 450000] 100 | gamma: 0.5 101 | 102 | # scheduler: 103 | # type: CosineAnnealingRestartLR 104 | # periods: [500000] 105 | # restart_weights: [1] 106 | # eta_min: !!float 2e-5 # no lr reduce in official vqgan code 107 | 108 | total_iter: 500000 109 | 110 | warmup_iter: -1 # no warm up 111 | ema_decay: 0.995 112 | 113 | use_adaptive_weight: true 114 | 115 | net_g_start_iter: 0 116 | net_d_iters: 1 117 | net_d_start_iter: 0 118 | manual_seed: 0 119 | 120 | # validation settings 121 | val: 122 | val_freq: !!float 5e10 # no validation 123 | save_img: true 124 | 125 | metrics: 126 | psnr: # metric name, can be arbitrary 127 | type: calculate_psnr 128 | crop_border: 4 129 | test_y_channel: false 130 | 131 | # logging settings 132 | logger: 133 | print_freq: 100 134 | save_checkpoint_freq: !!float 1e4 135 | use_tb_logger: true 136 | wandb: 137 | project: ~ 138 | resume_id: ~ 139 | 140 | # dist training settings 141 | dist_params: 142 | backend: nccl 143 | port: 29412 144 | 145 | find_unused_parameters: true 146 | -------------------------------------------------------------------------------- /options/CodeFormer_stage3.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CodeFormer_stage3 3 | model_type: CodeFormerJointModel 4 | num_gpu: 8 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQBlindJointDataset 12 | dataroot_gt: datasets/ffhq/ffhq_512 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: true 23 | 24 | blur_kernel_size: 41 25 | use_motion_kernel: false 26 | motion_kernel_prob: 0.001 27 | kernel_list: ['iso', 'aniso'] 28 | kernel_prob: [0.5, 0.5] 29 | # small degradation in stageIII 30 | blur_sigma: [0.1, 10] 31 | downsample_range: [1, 12] 32 | noise_range: [0, 15] 33 | jpeg_range: [60, 100] 34 | # large degradation in stageII 35 | blur_sigma_large: [1, 15] 36 | downsample_range_large: [4, 30] 37 | noise_range_large: [0, 20] 38 | jpeg_range_large: [30, 80] 39 | 40 | latent_gt_path: ~ # without pre-calculated latent code 41 | # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth' 42 | 43 | # data loader 44 | num_worker_per_gpu: 1 45 | batch_size_per_gpu: 3 46 | dataset_enlarge_ratio: 100 47 | prefetch_mode: ~ 48 | 49 | # val: 50 
| # name: CelebA-HQ-512 51 | # type: PairedImageDataset 52 | # dataroot_lq: datasets/faces/validation/lq 53 | # dataroot_gt: datasets/faces/validation/gt 54 | # io_backend: 55 | # type: disk 56 | # mean: [0.5, 0.5, 0.5] 57 | # std: [0.5, 0.5, 0.5] 58 | # scale: 1 59 | 60 | # network structures 61 | network_g: 62 | type: CodeFormer 63 | dim_embd: 512 64 | n_head: 8 65 | n_layers: 9 66 | codebook_size: 1024 67 | connect_list: ['32', '64', '128', '256'] 68 | fix_modules: ['quantize','generator'] 69 | 70 | network_vqgan: # this config is needed if no pre-calculated latent 71 | type: VQAutoEncoder 72 | img_size: 512 73 | nf: 64 74 | ch_mult: [1, 2, 2, 4, 4, 8] 75 | quantizer: 'nearest' 76 | codebook_size: 1024 77 | 78 | network_d: 79 | type: VQGANDiscriminator 80 | nc: 3 81 | ndf: 64 82 | n_layers: 4 83 | 84 | # path 85 | path: 86 | pretrain_network_g: './experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth' # pretrained G model in StageII 87 | param_key_g: params_ema 88 | strict_load_g: false 89 | pretrain_network_d: './experiments/pretrained_models/CodeFormer_stage2/net_d_latest.pth' # pretrained D model in StageII 90 | resume_state: ~ 91 | 92 | # base_lr(4.5e-6)*bach_size(4) 93 | train: 94 | use_hq_feat_loss: true 95 | feat_loss_weight: 1.0 96 | cross_entropy_loss: true 97 | entropy_loss_weight: 0.5 98 | scale_adaptive_gan_weight: 0.1 99 | 100 | optim_g: 101 | type: Adam 102 | lr: !!float 5e-5 103 | weight_decay: 0 104 | betas: [0.9, 0.99] 105 | optim_d: 106 | type: Adam 107 | lr: !!float 5e-5 108 | weight_decay: 0 109 | betas: [0.9, 0.99] 110 | 111 | scheduler: 112 | type: CosineAnnealingRestartLR 113 | periods: [150000] 114 | restart_weights: [1] 115 | eta_min: !!float 2e-5 116 | 117 | 118 | total_iter: 150000 119 | 120 | warmup_iter: -1 # no warm up 121 | ema_decay: 0.997 122 | 123 | pixel_opt: 124 | type: L1Loss 125 | loss_weight: 1.0 126 | reduction: mean 127 | 128 | perceptual_opt: 129 | type: LPIPSLoss 130 | loss_weight: 1.0 131 | use_input_norm: true 132 | range_norm: true 133 | 134 | gan_opt: 135 | type: GANLoss 136 | gan_type: hinge 137 | loss_weight: !!float 1.0 # adaptive_weighting 138 | 139 | use_adaptive_weight: true 140 | 141 | net_g_start_iter: 0 142 | net_d_iters: 1 143 | net_d_start_iter: 5001 144 | manual_seed: 0 145 | 146 | # validation settings 147 | val: 148 | val_freq: !!float 5e10 # no validation 149 | save_img: true 150 | 151 | metrics: 152 | psnr: # metric name, can be arbitrary 153 | type: calculate_psnr 154 | crop_border: 4 155 | test_y_channel: false 156 | 157 | # logging settings 158 | logger: 159 | print_freq: 100 160 | save_checkpoint_freq: !!float 5e3 161 | use_tb_logger: true 162 | wandb: 163 | project: ~ 164 | resume_id: ~ 165 | 166 | # dist training settings 167 | dist_params: 168 | backend: nccl 169 | port: 29413 170 | 171 | find_unused_parameters: true 172 | -------------------------------------------------------------------------------- /options/VQGAN_512_ds32_nearest_stage1.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: VQGAN-512-ds32-nearest-stage1 3 | model_type: VQGANModel 4 | num_gpu: 8 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQBlindDataset 12 | dataroot_gt: datasets/ffhq/ffhq_512 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: false # for VQGAN 23 | 
24 | # data loader 25 | num_worker_per_gpu: 2 26 | batch_size_per_gpu: 4 27 | dataset_enlarge_ratio: 100 28 | 29 | prefetch_mode: cpu 30 | num_prefetch_queue: 4 31 | 32 | # val: 33 | # name: CelebA-HQ-512 34 | # type: PairedImageDataset 35 | # dataroot_lq: datasets/faces/validation/gt 36 | # dataroot_gt: datasets/faces/validation/gt 37 | # io_backend: 38 | # type: disk 39 | # mean: [0.5, 0.5, 0.5] 40 | # std: [0.5, 0.5, 0.5] 41 | # scale: 1 42 | 43 | # network structures 44 | network_g: 45 | type: VQAutoEncoder 46 | img_size: 512 47 | nf: 64 48 | ch_mult: [1, 2, 2, 4, 4, 8] 49 | quantizer: 'nearest' 50 | codebook_size: 1024 51 | 52 | network_d: 53 | type: VQGANDiscriminator 54 | nc: 3 55 | ndf: 64 56 | 57 | # path 58 | path: 59 | pretrain_network_g: ~ 60 | param_key_g: params_ema 61 | strict_load_g: true 62 | pretrain_network_d: ~ 63 | strict_load_d: true 64 | resume_state: ~ 65 | 66 | # base_lr(4.5e-6)*bach_size(4) 67 | train: 68 | optim_g: 69 | type: Adam 70 | lr: !!float 7e-5 71 | weight_decay: 0 72 | betas: [0.9, 0.99] 73 | optim_d: 74 | type: Adam 75 | lr: !!float 7e-5 76 | weight_decay: 0 77 | betas: [0.9, 0.99] 78 | 79 | scheduler: 80 | type: CosineAnnealingRestartLR 81 | periods: [1600000] 82 | restart_weights: [1] 83 | eta_min: !!float 6e-5 # no lr reduce in official vqgan code 84 | 85 | total_iter: 1600000 86 | 87 | warmup_iter: -1 # no warm up 88 | ema_decay: 0.995 # GFPGAN: 0.5**(32 / (10 * 1000) == 0.998; Unleashing: 0.995 89 | 90 | pixel_opt: 91 | type: L1Loss 92 | loss_weight: 1.0 93 | reduction: mean 94 | 95 | perceptual_opt: 96 | type: LPIPSLoss 97 | loss_weight: 1.0 98 | use_input_norm: true 99 | range_norm: true 100 | 101 | gan_opt: 102 | type: GANLoss 103 | gan_type: hinge 104 | loss_weight: !!float 1.0 # adaptive_weighting 105 | 106 | net_g_start_iter: 0 107 | net_d_iters: 1 108 | net_d_start_iter: 30001 109 | manual_seed: 0 110 | 111 | # validation settings 112 | val: 113 | val_freq: !!float 5e10 # no validation 114 | save_img: true 115 | 116 | metrics: 117 | psnr: # metric name, can be arbitrary 118 | type: calculate_psnr 119 | crop_border: 4 120 | test_y_channel: false 121 | 122 | # logging settings 123 | logger: 124 | print_freq: 100 125 | save_checkpoint_freq: !!float 1e4 126 | use_tb_logger: true 127 | wandb: 128 | project: ~ 129 | resume_id: ~ 130 | 131 | # dist training settings 132 | dist_params: 133 | backend: nccl 134 | port: 29411 135 | 136 | find_unused_parameters: true 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | torch>=1.7.1 13 | torchvision 14 | tqdm 15 | yapf 16 | lpips 17 | gdown # supports downloading the large file from Google Drive -------------------------------------------------------------------------------- /scripts/download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | 8 | def download_pretrained_models(method, file_urls): 9 | if method == 'CodeFormer_train': 10 | method = 'CodeFormer' 11 | save_path_root = f'./weights/{method}' 12 | os.makedirs(save_path_root, exist_ok=True) 13 | 14 | for file_name, file_url in file_urls.items(): 15 | save_path = 
load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name) 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | 21 | parser.add_argument( 22 | 'method', 23 | type=str, 24 | help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models.")) 25 | args = parser.parse_args() 26 | 27 | file_urls = { 28 | 'CodeFormer': { 29 | 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 30 | }, 31 | 'CodeFormer_train': { 32 | 'vqgan_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/vqgan_code1024.pth', 33 | 'latent_gt_code1024.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/latent_gt_code1024.pth', 34 | 'codeformer_stage2.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer_stage2.pth', 35 | 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 36 | }, 37 | 'facelib': { 38 | # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth', 39 | 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', 40 | 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' 41 | }, 42 | 'dlib': { 43 | 'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat', 44 | 'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat' 45 | } 46 | } 47 | 48 | if args.method == 'all': 49 | for method in file_urls.keys(): 50 | download_pretrained_models(method, file_urls[method]) 51 | else: 52 | download_pretrained_models(args.method, file_urls[args.method]) -------------------------------------------------------------------------------- /scripts/download_pretrained_models_from_gdrive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | # from basicsr.utils.download_util import download_file_from_google_drive 6 | import gdown 7 | 8 | 9 | def download_pretrained_models(method, file_ids): 10 | save_path_root = f'./weights/{method}' 11 | os.makedirs(save_path_root, exist_ok=True) 12 | 13 | for file_name, file_id in file_ids.items(): 14 | file_url = 'https://drive.google.com/uc?id='+file_id 15 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 16 | if osp.exists(save_path): 17 | user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') 18 | if user_response.lower() == 'y': 19 | print(f'Covering {file_name} to {save_path}') 20 | gdown.download(file_url, save_path, quiet=False) 21 | # download_file_from_google_drive(file_id, save_path) 22 | elif user_response.lower() == 'n': 23 | print(f'Skipping {file_name}') 24 | else: 25 | raise ValueError('Wrong input. Only accepts Y/N.') 26 | else: 27 | print(f'Downloading {file_name} to {save_path}') 28 | gdown.download(file_url, save_path, quiet=False) 29 | # download_file_from_google_drive(file_id, save_path) 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument( 35 | 'method', 36 | type=str, 37 | help=("Options: 'CodeFormer' 'facelib'. 
Set to 'all' to download all the models.")) 38 | args = parser.parse_args() 39 | 40 | # file name: file id 41 | # 'dlib': { 42 | # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX', 43 | # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg', 44 | # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq' 45 | # } 46 | file_ids = { 47 | 'CodeFormer': { 48 | 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB' 49 | }, 50 | 'facelib': { 51 | 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV', 52 | 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK' 53 | } 54 | } 55 | 56 | if args.method == 'all': 57 | for method in file_ids.keys(): 58 | download_pretrained_models(method, file_ids[method]) 59 | else: 60 | download_pretrained_models(args.method, file_ids[args.method]) -------------------------------------------------------------------------------- /scripts/generate_latent_gt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import numpy as np 4 | import os 5 | import cv2 6 | import torch 7 | from torchvision.transforms.functional import normalize 8 | from basicsr.utils import imwrite, img2tensor, tensor2img 9 | 10 | from basicsr.utils.registry import ARCH_REGISTRY 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512') 15 | parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan') 16 | parser.add_argument('--codebook_size', type=int, default=1024) 17 | parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth') 18 | args = parser.parse_args() 19 | 20 | if args.save_root.endswith('/'): # solve when path ends with / 21 | args.save_root = args.save_root[:-1] 22 | dir_name = os.path.abspath(args.save_root) 23 | os.makedirs(dir_name, exist_ok=True) 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | test_path = args.test_path 27 | save_root = args.save_root 28 | ckpt_path = args.ckpt_path 29 | codebook_size = args.codebook_size 30 | 31 | vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 32 | codebook_size=codebook_size).to(device) 33 | checkpoint = torch.load(ckpt_path)['params_ema'] 34 | 35 | vqgan.load_state_dict(checkpoint) 36 | vqgan.eval() 37 | 38 | sum_latent = np.zeros((codebook_size)).astype('float64') 39 | size_latent = 16 40 | latent = {} 41 | latent['orig'] = {} 42 | latent['hflip'] = {} 43 | for i in ['orig', 'hflip']: 44 | # for i in ['hflip']: 45 | for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))): 46 | img_name = os.path.basename(img_path) 47 | img = cv2.imread(img_path) 48 | if i == 'hflip': 49 | cv2.flip(img, 1, img) 50 | img = img2tensor(img / 255., bgr2rgb=True, float32=True) 51 | normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 52 | img = img.unsqueeze(0).to(device) 53 | with torch.no_grad(): 54 | # output = net(img)[0] 55 | x, feat_dict = vqgan.encoder(img, True) 56 | x, _, log = vqgan.quantize(x) 57 | # del output 58 | torch.cuda.empty_cache() 59 | 60 | min_encoding_indices = log['min_encoding_indices'] 61 | min_encoding_indices = min_encoding_indices.view(size_latent,size_latent) 62 | latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy() 63 | print(img_name, latent[i][img_name[:-4]].shape) 64 | 65 | 
latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth') 66 | torch.save(latent, latent_save_path) 67 | print(f'\nLatent GT code are saved in {save_root}') 68 | -------------------------------------------------------------------------------- /scripts/inference_vqgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import numpy as np 4 | import os 5 | import cv2 6 | import torch 7 | from torchvision.transforms.functional import normalize 8 | from basicsr.utils import imwrite, img2tensor, tensor2img 9 | 10 | from basicsr.utils.registry import ARCH_REGISTRY 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512') 15 | parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec') 16 | parser.add_argument('--codebook_size', type=int, default=1024) 17 | parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth') 18 | args = parser.parse_args() 19 | 20 | if args.save_root.endswith('/'): # solve when path ends with / 21 | args.save_root = args.save_root[:-1] 22 | dir_name = os.path.abspath(args.save_root) 23 | os.makedirs(dir_name, exist_ok=True) 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | test_path = args.test_path 27 | save_root = args.save_root 28 | ckpt_path = args.ckpt_path 29 | codebook_size = args.codebook_size 30 | 31 | vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 32 | codebook_size=codebook_size).to(device) 33 | checkpoint = torch.load(ckpt_path)['params_ema'] 34 | 35 | vqgan.load_state_dict(checkpoint) 36 | vqgan.eval() 37 | 38 | for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))): 39 | img_name = os.path.basename(img_path) 40 | print(img_name) 41 | img = cv2.imread(img_path) 42 | img = img2tensor(img / 255., bgr2rgb=True, float32=True) 43 | normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 44 | img = img.unsqueeze(0).to(device) 45 | with torch.no_grad(): 46 | output = vqgan(img)[0] 47 | output = tensor2img(output, min_max=[-1,1]) 48 | img = tensor2img(img, min_max=[-1,1]) 49 | restored_img = np.concatenate([img, output], axis=1) 50 | restored_img = output 51 | del output 52 | torch.cuda.empty_cache() 53 | 54 | path = os.path.splitext(os.path.join(save_root, img_name))[0] 55 | save_path = f'{path}.png' 56 | imwrite(restored_img, save_path) 57 | 58 | print(f'\nAll results are saved in {save_root}') 59 | 60 | -------------------------------------------------------------------------------- /web-demos/replicate/cog.yaml: -------------------------------------------------------------------------------- 1 | """ 2 | This file is used for deploying replicate demo: 3 | https://replicate.com/sczhou/codeformer 4 | """ 5 | 6 | build: 7 | gpu: true 8 | cuda: "11.3" 9 | python_version: "3.8" 10 | system_packages: 11 | - "libgl1-mesa-glx" 12 | - "libglib2.0-0" 13 | python_packages: 14 | - "ipython==8.4.0" 15 | - "future==0.18.2" 16 | - "lmdb==1.3.0" 17 | - "scikit-image==0.19.3" 18 | - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" 19 | - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" 20 | - "scipy==1.9.0" 21 | - "gdown==4.5.1" 22 | - "pyyaml==6.0" 23 | - "tb-nightly==2.11.0a20220906" 24 | - "tqdm==4.64.1" 25 | - "yapf==0.32.0" 26 | - "lpips==0.1.4" 27 | - 
"Pillow==9.2.0" 28 | - "opencv-python==4.6.0.66" 29 | 30 | predict: "predict.py:Predictor" 31 | -------------------------------------------------------------------------------- /weights/CodeFormer/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/weights/CodeFormer/.gitkeep -------------------------------------------------------------------------------- /weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded pre-trained models to this folder. -------------------------------------------------------------------------------- /weights/facelib/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sczhou/CodeFormer/e878192ee253cfcc8f19e29d3307c181501f53ae/weights/facelib/.gitkeep --------------------------------------------------------------------------------