├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── codeformer
│   ├── __init__.py
│   ├── app.py
│   ├── basicsr
│   │   ├── VERSION
│   │   ├── __init__.py
│   │   ├── archs
│   │   │   ├── __init__.py
│   │   │   ├── arcface_arch.py
│   │   │   ├── arch_util.py
│   │   │   ├── codeformer_arch.py
│   │   │   ├── rrdbnet_arch.py
│   │   │   ├── vgg_arch.py
│   │   │   └── vqgan_arch.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data_sampler.py
│   │   │   ├── data_util.py
│   │   │   ├── prefetch_dataloader.py
│   │   │   └── transforms.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── metric_util.py
│   │   │   └── psnr_ssim.py
│   │   ├── models
│   │   │   └── __init__.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── dcn
│   │   │   │   ├── __init__.py
│   │   │   │   ├── deform_conv.py
│   │   │   │   └── src
│   │   │   │       ├── deform_conv_cuda.cpp
│   │   │   │       ├── deform_conv_cuda_kernel.cu
│   │   │   │       └── deform_conv_ext.cpp
│   │   │   ├── fused_act
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fused_act.py
│   │   │   │   └── src
│   │   │   │       ├── fused_bias_act.cpp
│   │   │   │       └── fused_bias_act_kernel.cu
│   │   │   └── upfirdn2d
│   │   │       ├── __init__.py
│   │   │       ├── src
│   │   │       │   ├── upfirdn2d.cpp
│   │   │       │   └── upfirdn2d_kernel.cu
│   │   │       └── upfirdn2d.py
│   │   ├── setup.py
│   │   ├── train.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dist_util.py
│   │       ├── download_util.py
│   │       ├── file_client.py
│   │       ├── img_util.py
│   │       ├── lmdb_util.py
│   │       ├── logger.py
│   │       ├── matlab_functions.py
│   │       ├── misc.py
│   │       ├── options.py
│   │       ├── realesrgan_utils.py
│   │       ├── registry.py
│   │       └── video_util.py
│   ├── facelib
│   │   ├── __init__.py
│   │   ├── detection
│   │   │   ├── __init__.py
│   │   │   ├── align_trans.py
│   │   │   ├── matlab_cp2tform.py
│   │   │   ├── retinaface
│   │   │   │   ├── __init__.py
│   │   │   │   ├── retinaface.py
│   │   │   │   ├── retinaface_net.py
│   │   │   │   └── retinaface_utils.py
│   │   │   └── yolov5face
│   │   │       ├── __init__.py
│   │   │       ├── face_detector.py
│   │   │       ├── models
│   │   │       │   ├── __init__.py
│   │   │       │   ├── common.py
│   │   │       │   ├── experimental.py
│   │   │       │   ├── yolo.py
│   │   │       │   ├── yolov5l.yaml
│   │   │       │   └── yolov5n.yaml
│   │   │       └── utils
│   │   │           ├── __init__.py
│   │   │           ├── autoanchor.py
│   │   │           ├── datasets.py
│   │   │           ├── extract_ckpt.py
│   │   │           ├── general.py
│   │   │           └── torch_utils.py
│   │   ├── parsing
│   │   │   ├── __init__.py
│   │   │   ├── bisenet.py
│   │   │   ├── parsenet.py
│   │   │   └── resnet.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── face_restoration_helper.py
│   │       ├── face_utils.py
│   │       └── misc.py
│   ├── inference_codeformer.py
│   └── scripts
│       ├── __init__.py
│       ├── code_format.sh
│       ├── crop_align_face.py
│       ├── download_pretrained_models.py
│       ├── download_pretrained_models_from_gdrive.py
│       └── package.sh
├── pyproject.toml
├── requirements.txt
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2022 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Towards Robust Blind Face Restoration with Codebook Lookup Transformer
3 |
4 |
5 |
6 |
7 |
8 | This repo is a PyTorch implementation of the paper [CodeFormer](https://arxiv.org/abs/2206.11253).
9 |
10 | ### Installation
11 | ```bash
12 | pip install codeformer-pip
13 | ```
14 | ### Usage
15 | ```python
16 | from codeformer.app import inference_app
17 |
18 | inference_app(
19 | image="test.jpg",
20 | background_enhance=True,
21 | face_upsample=True,
22 | upscale=2,
23 | codeformer_fidelity=0.5,
24 | )
25 | ```
26 | ### Citation
27 | ```bibtex
28 | @inproceedings{zhou2022codeformer,
29 | author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
30 | title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
31 | booktitle = {NeurIPS},
32 | year = {2022}
33 | }
34 | ```
35 |
36 | ### License
37 |
38 | This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
--------------------------------------------------------------------------------
/codeformer/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.4"
2 |
--------------------------------------------------------------------------------
/codeformer/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 | from torchvision.transforms.functional import normalize
6 |
7 | from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet
8 | from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img
9 | from codeformer.basicsr.utils.download_util import load_file_from_url
10 | from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer
11 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY
12 | from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
13 | from codeformer.facelib.utils.misc import is_gray
14 |
15 | pretrain_model_url = {
16 | "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
17 | "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
18 | "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
19 | "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
20 | }
21 |
22 | # download weights
23 | if not os.path.exists("CodeFormer/weights/CodeFormer/codeformer.pth"):
24 | load_file_from_url(
25 | url=pretrain_model_url["codeformer"], model_dir="CodeFormer/weights/CodeFormer", progress=True, file_name=None
26 | )
27 | if not os.path.exists("CodeFormer/weights/facelib/detection_Resnet50_Final.pth"):
28 | load_file_from_url(
29 | url=pretrain_model_url["detection"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None
30 | )
31 | if not os.path.exists("CodeFormer/weights/facelib/parsing_parsenet.pth"):
32 | load_file_from_url(
33 | url=pretrain_model_url["parsing"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None
34 | )
35 | if not os.path.exists("CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth"):
36 | load_file_from_url(
37 | url=pretrain_model_url["realesrgan"], model_dir="CodeFormer/weights/realesrgan", progress=True, file_name=None
38 | )
39 |
40 |
41 | def imread(img_path):
42 | img = cv2.imread(img_path)
43 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
44 | return img
45 |
46 |
47 | # set enhancer with RealESRGAN
48 | def set_realesrgan():
49 | half = True if torch.cuda.is_available() else False
50 | model = RRDBNet(
51 | num_in_ch=3,
52 | num_out_ch=3,
53 | num_feat=64,
54 | num_block=23,
55 | num_grow_ch=32,
56 | scale=2,
57 | )
58 | upsampler = RealESRGANer(
59 | scale=2,
60 | model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
61 | model=model,
62 | tile=400,
63 | tile_pad=40,
64 | pre_pad=0,
65 | half=half,
66 | )
67 | return upsampler
68 |
69 |
70 | upsampler = set_realesrgan()
71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72 | codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
73 | dim_embd=512,
74 | codebook_size=1024,
75 | n_head=8,
76 | n_layers=9,
77 | connect_list=["32", "64", "128", "256"],
78 | ).to(device)
79 | ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
80 | checkpoint = torch.load(ckpt_path)["params_ema"]
81 | codeformer_net.load_state_dict(checkpoint)
82 | codeformer_net.eval()
83 |
84 | os.makedirs("output", exist_ok=True)
85 |
86 |
87 | def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
88 | # take the default setting for the demo
89 | has_aligned = False
90 | only_center_face = False
91 | draw_box = False
92 | detection_model = "retinaface_resnet50"
93 | print("Inp:", type(image), background_enhance, face_upsample, upscale, codeformer_fidelity)
94 | if isinstance(image, str):
95 | img = cv2.imread(str(image), cv2.IMREAD_COLOR)
96 | if isinstance(image, np.ndarray):
97 | img = image
98 | print("\timage size:", img.shape)
99 |
100 | upscale = int(upscale) # convert type to int
101 | if upscale > 4: # avoid memory exceeded due to too large upscale
102 | upscale = 4
103 | if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution
104 | upscale = 2
105 | if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
106 | upscale = 1
107 | background_enhance = False
108 | face_upsample = False
109 |
110 | face_helper = FaceRestoreHelper(
111 | upscale,
112 | face_size=512,
113 | crop_ratio=(1, 1),
114 | det_model=detection_model,
115 | save_ext="png",
116 | use_parse=True,
117 | device=device,
118 | )
119 | bg_upsampler = upsampler if background_enhance else None
120 | face_upsampler = upsampler if face_upsample else None
121 |
122 | if has_aligned:
123 | # the input faces are already cropped and aligned
124 | img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
125 | face_helper.is_gray = is_gray(img, threshold=5)
126 | if face_helper.is_gray:
127 | print("\tgrayscale input: True")
128 | face_helper.cropped_faces = [img]
129 | else:
130 | face_helper.read_image(img)
131 | # get face landmarks for each face
132 | num_det_faces = face_helper.get_face_landmarks_5(
133 | only_center_face=only_center_face, resize=640, eye_dist_threshold=5
134 | )
135 | print(f"\tdetect {num_det_faces} faces")
136 | # align and warp each face
137 | face_helper.align_warp_face()
138 |
139 | # face restoration for each cropped face
140 | for idx, cropped_face in enumerate(face_helper.cropped_faces):
141 | # prepare data
142 | cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
143 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
144 | cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
145 |
146 | try:
147 | with torch.no_grad():
148 | output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
149 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
150 | del output
151 | torch.cuda.empty_cache()
152 | except RuntimeError as error:
153 | print(f"Failed inference for CodeFormer: {error}")
154 | restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
155 |
156 | restored_face = restored_face.astype("uint8")
157 | face_helper.add_restored_face(restored_face)
158 |
159 | # paste_back
160 | if not has_aligned:
161 | # upsample the background
162 | if bg_upsampler is not None:
163 | # Now only support RealESRGAN for upsampling background
164 | bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
165 | else:
166 | bg_img = None
167 | face_helper.get_inverse_affine(None)
168 | # paste each restored face to the input image
169 | if face_upsample and face_upsampler is not None:
170 | restored_img = face_helper.paste_faces_to_input_image(
171 | upsample_img=bg_img,
172 | draw_box=draw_box,
173 | face_upsampler=face_upsampler,
174 | )
175 | else:
176 | restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
177 |
178 | return restored_img
179 |
--------------------------------------------------------------------------------
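
A minimal usage sketch for the `inference_app` entry point above; the input path and output filename are hypothetical. The function returns the full restored image as a BGR ndarray, so it can be saved directly with OpenCV:

```python
import cv2

from codeformer.app import inference_app

# Restore all detected faces; upscale is capped internally (see above) to avoid
# running out of memory on very large inputs.
restored_img = inference_app(
    image="test.jpg",             # hypothetical input path
    background_enhance=True,
    face_upsample=True,
    upscale=2,
    codeformer_fidelity=0.5,      # lower = better quality, higher = closer to the input
)

# The result is a BGR ndarray, matching the cv2-based pipeline above.
cv2.imwrite("output/restored.png", restored_img)
```
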
/codeformer/basicsr/VERSION:
--------------------------------------------------------------------------------
1 | 1.3.2
2 |
--------------------------------------------------------------------------------
/codeformer/basicsr/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/xinntao/BasicSR
2 | # flake8: noqa
3 | from .archs import *
4 | from .data import *
5 | from .losses import *
6 | from .metrics import *
7 | from .models import *
8 | from .ops import *
9 | from .train import *
10 | from .utils import *
11 |
--------------------------------------------------------------------------------
/codeformer/basicsr/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from codeformer.basicsr.utils import get_root_logger, scandir
6 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY
7 |
8 | __all__ = ["build_network"]
9 |
10 | # automatically scan and import arch modules for registry
11 | # scan all the files under the 'archs' folder and collect files ending with
12 | # '_arch.py'
13 | arch_folder = osp.dirname(osp.abspath(__file__))
14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith("_arch.py")]
15 | # import all the arch modules
16 | _arch_modules = [importlib.import_module(f"codeformer.basicsr.archs.{file_name}") for file_name in arch_filenames]
17 |
18 |
19 | def build_network(opt):
20 | opt = deepcopy(opt)
21 | network_type = opt.pop("type")
22 | net = ARCH_REGISTRY.get(network_type)(**opt)
23 | logger = get_root_logger()
24 | logger.info(f"Network [{net.__class__.__name__}] is created.")
25 | return net
26 |
--------------------------------------------------------------------------------
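
A hedged sketch of how `build_network` resolves a registered architecture from an options dict; `RRDBNet` (registered in the next file) is used here with illustrative option values:

```python
from codeformer.basicsr.archs import build_network

# 'type' selects the class from ARCH_REGISTRY; the remaining keys become constructor kwargs.
opt = {"type": "RRDBNet", "num_in_ch": 3, "num_out_ch": 3, "scale": 2}
net = build_network(opt)
print(net.__class__.__name__)  # RRDBNet
```
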
/codeformer/basicsr/archs/rrdbnet_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 | from codeformer.basicsr.archs.arch_util import default_init_weights, make_layer, pixel_unshuffle
6 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY
7 |
8 |
9 | class ResidualDenseBlock(nn.Module):
10 | """Residual Dense Block.
11 |
12 | Used in RRDB block in ESRGAN.
13 |
14 | Args:
15 | num_feat (int): Channel number of intermediate features.
16 | num_grow_ch (int): Channels for each growth.
17 | """
18 |
19 | def __init__(self, num_feat=64, num_grow_ch=32):
20 | super(ResidualDenseBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
26 |
27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
28 |
29 | # initialization
30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
31 |
32 | def forward(self, x):
33 | x1 = self.lrelu(self.conv1(x))
34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
38 | # Empirically, we use 0.2 to scale the residual for better performance
39 | return x5 * 0.2 + x
40 |
41 |
42 | class RRDB(nn.Module):
43 | """Residual in Residual Dense Block.
44 |
45 | Used in RRDB-Net in ESRGAN.
46 |
47 | Args:
48 | num_feat (int): Channel number of intermediate features.
49 | num_grow_ch (int): Channels for each growth.
50 | """
51 |
52 | def __init__(self, num_feat, num_grow_ch=32):
53 | super(RRDB, self).__init__()
54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
57 |
58 | def forward(self, x):
59 | out = self.rdb1(x)
60 | out = self.rdb2(out)
61 | out = self.rdb3(out)
62 | # Empirically, we use 0.2 to scale the residual for better performance
63 | return out * 0.2 + x
64 |
65 |
66 | @ARCH_REGISTRY.register()
67 | class RRDBNet(nn.Module):
68 | """Networks consisting of Residual in Residual Dense Block, which is used
69 | in ESRGAN.
70 |
71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
72 |
73 | We extend ESRGAN for scale x2 and scale x1.
74 | Note: This is one option for scale 1, scale 2 in RRDBNet.
75 | We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
77 |
78 | Args:
79 | num_in_ch (int): Channel number of inputs.
80 | num_out_ch (int): Channel number of outputs.
81 | num_feat (int): Channel number of intermediate features.
82 | Default: 64
83 | num_block (int): Block number in the trunk network. Defaults: 23
84 | num_grow_ch (int): Channels for each growth. Default: 32.
85 | """
86 |
87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
88 | super(RRDBNet, self).__init__()
89 | self.scale = scale
90 | if scale == 2:
91 | num_in_ch = num_in_ch * 4
92 | elif scale == 1:
93 | num_in_ch = num_in_ch * 16
94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97 | # upsample
98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102 |
103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104 |
105 | def forward(self, x):
106 | if self.scale == 2:
107 | feat = pixel_unshuffle(x, scale=2)
108 | elif self.scale == 1:
109 | feat = pixel_unshuffle(x, scale=4)
110 | else:
111 | feat = x
112 | feat = self.conv_first(feat)
113 | body_feat = self.conv_body(self.body(feat))
114 | feat = feat + body_feat
115 | # upsample
116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")))
117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")))
118 | out = self.conv_last(self.lrelu(self.conv_hr(feat)))
119 | return out
120 |
--------------------------------------------------------------------------------
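
A small sanity-check sketch for the x2 configuration with an illustrative input size: the input is pixel-unshuffled before the trunk and then upsampled twice, so the output is 2x the input resolution.

```python
import torch

from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet

# scale=2: pixel_unshuffle quarters the spatial area (channels x4) before conv_first,
# then two nearest-neighbour x2 upsamples restore and double the resolution.
net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2, num_feat=64, num_block=23, num_grow_ch=32)
with torch.no_grad():
    out = net(torch.rand(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 128, 128])
```
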
/codeformer/basicsr/archs/vgg_arch.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 |
4 | import torch
5 | from torch import nn as nn
6 | from torchvision.models import vgg as vgg
7 |
8 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY
9 |
10 | VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth"
11 | NAMES = {
12 | "vgg11": [
13 | "conv1_1",
14 | "relu1_1",
15 | "pool1",
16 | "conv2_1",
17 | "relu2_1",
18 | "pool2",
19 | "conv3_1",
20 | "relu3_1",
21 | "conv3_2",
22 | "relu3_2",
23 | "pool3",
24 | "conv4_1",
25 | "relu4_1",
26 | "conv4_2",
27 | "relu4_2",
28 | "pool4",
29 | "conv5_1",
30 | "relu5_1",
31 | "conv5_2",
32 | "relu5_2",
33 | "pool5",
34 | ],
35 | "vgg13": [
36 | "conv1_1",
37 | "relu1_1",
38 | "conv1_2",
39 | "relu1_2",
40 | "pool1",
41 | "conv2_1",
42 | "relu2_1",
43 | "conv2_2",
44 | "relu2_2",
45 | "pool2",
46 | "conv3_1",
47 | "relu3_1",
48 | "conv3_2",
49 | "relu3_2",
50 | "pool3",
51 | "conv4_1",
52 | "relu4_1",
53 | "conv4_2",
54 | "relu4_2",
55 | "pool4",
56 | "conv5_1",
57 | "relu5_1",
58 | "conv5_2",
59 | "relu5_2",
60 | "pool5",
61 | ],
62 | "vgg16": [
63 | "conv1_1",
64 | "relu1_1",
65 | "conv1_2",
66 | "relu1_2",
67 | "pool1",
68 | "conv2_1",
69 | "relu2_1",
70 | "conv2_2",
71 | "relu2_2",
72 | "pool2",
73 | "conv3_1",
74 | "relu3_1",
75 | "conv3_2",
76 | "relu3_2",
77 | "conv3_3",
78 | "relu3_3",
79 | "pool3",
80 | "conv4_1",
81 | "relu4_1",
82 | "conv4_2",
83 | "relu4_2",
84 | "conv4_3",
85 | "relu4_3",
86 | "pool4",
87 | "conv5_1",
88 | "relu5_1",
89 | "conv5_2",
90 | "relu5_2",
91 | "conv5_3",
92 | "relu5_3",
93 | "pool5",
94 | ],
95 | "vgg19": [
96 | "conv1_1",
97 | "relu1_1",
98 | "conv1_2",
99 | "relu1_2",
100 | "pool1",
101 | "conv2_1",
102 | "relu2_1",
103 | "conv2_2",
104 | "relu2_2",
105 | "pool2",
106 | "conv3_1",
107 | "relu3_1",
108 | "conv3_2",
109 | "relu3_2",
110 | "conv3_3",
111 | "relu3_3",
112 | "conv3_4",
113 | "relu3_4",
114 | "pool3",
115 | "conv4_1",
116 | "relu4_1",
117 | "conv4_2",
118 | "relu4_2",
119 | "conv4_3",
120 | "relu4_3",
121 | "conv4_4",
122 | "relu4_4",
123 | "pool4",
124 | "conv5_1",
125 | "relu5_1",
126 | "conv5_2",
127 | "relu5_2",
128 | "conv5_3",
129 | "relu5_3",
130 | "conv5_4",
131 | "relu5_4",
132 | "pool5",
133 | ],
134 | }
135 |
136 |
137 | def insert_bn(names):
138 | """Insert bn layer after each conv.
139 |
140 | Args:
141 | names (list): The list of layer names.
142 |
143 | Returns:
144 | list: The list of layer names with bn layers.
145 | """
146 | names_bn = []
147 | for name in names:
148 | names_bn.append(name)
149 | if "conv" in name:
150 | position = name.replace("conv", "")
151 | names_bn.append("bn" + position)
152 | return names_bn
153 |
154 |
155 | @ARCH_REGISTRY.register()
156 | class VGGFeatureExtractor(nn.Module):
157 | """VGG network for feature extraction.
158 |
159 | In this implementation, we allow users to choose whether to use normalization
160 | in the input feature and the type of vgg network. Note that the pretrained
161 | path must fit the vgg type.
162 |
163 | Args:
164 | layer_name_list (list[str]): Forward function returns the corresponding
165 | features according to the layer_name_list.
166 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
167 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
168 | use_input_norm (bool): If True, normalize the input image. Importantly,
169 | the input feature must be in the range [0, 1]. Default: True.
170 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
171 | Default: False.
172 | requires_grad (bool): If true, the parameters of VGG network will be
173 | optimized. Default: False.
174 | remove_pooling (bool): If true, the max pooling operations in VGG net
175 | will be removed. Default: False.
176 | pooling_stride (int): The stride of max pooling operation. Default: 2.
177 | """
178 |
179 | def __init__(
180 | self,
181 | layer_name_list,
182 | vgg_type="vgg19",
183 | use_input_norm=True,
184 | range_norm=False,
185 | requires_grad=False,
186 | remove_pooling=False,
187 | pooling_stride=2,
188 | ):
189 | super(VGGFeatureExtractor, self).__init__()
190 |
191 | self.layer_name_list = layer_name_list
192 | self.use_input_norm = use_input_norm
193 | self.range_norm = range_norm
194 |
195 | self.names = NAMES[vgg_type.replace("_bn", "")]
196 | if "bn" in vgg_type:
197 | self.names = insert_bn(self.names)
198 |
199 | # only borrow layers that will be used to avoid unused params
200 | max_idx = 0
201 | for v in layer_name_list:
202 | idx = self.names.index(v)
203 | if idx > max_idx:
204 | max_idx = idx
205 |
206 | if os.path.exists(VGG_PRETRAIN_PATH):
207 | vgg_net = getattr(vgg, vgg_type)(pretrained=False)
208 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
209 | vgg_net.load_state_dict(state_dict)
210 | else:
211 | vgg_net = getattr(vgg, vgg_type)(pretrained=True)
212 |
213 | features = vgg_net.features[: max_idx + 1]
214 |
215 | modified_net = OrderedDict()
216 | for k, v in zip(self.names, features):
217 | if "pool" in k:
218 | # if remove_pooling is true, pooling operation will be removed
219 | if remove_pooling:
220 | continue
221 | else:
222 | # in some cases, we may want to change the default stride
223 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
224 | else:
225 | modified_net[k] = v
226 |
227 | self.vgg_net = nn.Sequential(modified_net)
228 |
229 | if not requires_grad:
230 | self.vgg_net.eval()
231 | for param in self.parameters():
232 | param.requires_grad = False
233 | else:
234 | self.vgg_net.train()
235 | for param in self.parameters():
236 | param.requires_grad = True
237 |
238 | if self.use_input_norm:
239 | # the mean is for image with range [0, 1]
240 | self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
241 | # the std is for image with range [0, 1]
242 | self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
243 |
244 | def forward(self, x):
245 | """Forward function.
246 |
247 | Args:
248 | x (Tensor): Input tensor with shape (n, c, h, w).
249 |
250 | Returns:
251 | Tensor: Forward results.
252 | """
253 | if self.range_norm:
254 | x = (x + 1) / 2
255 | if self.use_input_norm:
256 | x = (x - self.mean) / self.std
257 | output = {}
258 |
259 | for key, layer in self.vgg_net._modules.items():
260 | x = layer(x)
261 | if key in self.layer_name_list:
262 | output[key] = x.clone()
263 |
264 | return output
265 |
--------------------------------------------------------------------------------
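
A hedged usage sketch for `VGGFeatureExtractor`; the layer names and input size are examples. The forward pass returns a dict of the requested activations and expects inputs in [0, 1] when `use_input_norm=True`:

```python
import torch

from codeformer.basicsr.archs.vgg_arch import VGGFeatureExtractor

# Weights come from torchvision unless the local VGG_PRETRAIN_PATH file exists.
extractor = VGGFeatureExtractor(layer_name_list=["relu1_1", "relu2_1", "relu3_1"], vgg_type="vgg19")
with torch.no_grad():
    feats = extractor(torch.rand(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, feat.shape)
```
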
/codeformer/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import random
3 | from copy import deepcopy
4 | from functools import partial
5 | from os import path as osp
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 |
11 | from codeformer.basicsr.data.prefetch_dataloader import PrefetchDataLoader
12 | from codeformer.basicsr.utils import get_root_logger, scandir
13 | from codeformer.basicsr.utils.dist_util import get_dist_info
14 | from codeformer.basicsr.utils.registry import DATASET_REGISTRY
15 |
16 | __all__ = ["build_dataset", "build_dataloader"]
17 |
18 | # automatically scan and import dataset modules for registry
19 | # scan all the files under the data folder with '_dataset' in file names
20 | data_folder = osp.dirname(osp.abspath(__file__))
21 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith("_dataset.py")]
22 | # import all the dataset modules
23 | _dataset_modules = [importlib.import_module(f"codeformer.basicsr.data.{file_name}") for file_name in dataset_filenames]
24 |
25 |
26 | def build_dataset(dataset_opt):
27 | """Build dataset from options.
28 |
29 | Args:
30 | dataset_opt (dict): Configuration for dataset. It must contain:
31 | name (str): Dataset name.
32 | type (str): Dataset type.
33 | """
34 | dataset_opt = deepcopy(dataset_opt)
35 | dataset = DATASET_REGISTRY.get(dataset_opt["type"])(dataset_opt)
36 | logger = get_root_logger()
37 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' "is built.")
38 | return dataset
39 |
40 |
41 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
42 | """Build dataloader.
43 |
44 | Args:
45 | dataset (torch.utils.data.Dataset): Dataset.
46 | dataset_opt (dict): Dataset options. It contains the following keys:
47 | phase (str): 'train' or 'val'.
48 | num_worker_per_gpu (int): Number of workers for each GPU.
49 | batch_size_per_gpu (int): Training batch size for each GPU.
50 | num_gpu (int): Number of GPUs. Used only in the train phase.
51 | Default: 1.
52 | dist (bool): Whether in distributed training. Used only in the train
53 | phase. Default: False.
54 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
55 | seed (int | None): Seed. Default: None
56 | """
57 | phase = dataset_opt["phase"]
58 | rank, _ = get_dist_info()
59 | if phase == "train":
60 | if dist: # distributed training
61 | batch_size = dataset_opt["batch_size_per_gpu"]
62 | num_workers = dataset_opt["num_worker_per_gpu"]
63 | else: # non-distributed training
64 | multiplier = 1 if num_gpu == 0 else num_gpu
65 | batch_size = dataset_opt["batch_size_per_gpu"] * multiplier
66 | num_workers = dataset_opt["num_worker_per_gpu"] * multiplier
67 | dataloader_args = dict(
68 | dataset=dataset,
69 | batch_size=batch_size,
70 | shuffle=False,
71 | num_workers=num_workers,
72 | sampler=sampler,
73 | drop_last=True,
74 | )
75 | if sampler is None:
76 | dataloader_args["shuffle"] = True
77 | dataloader_args["worker_init_fn"] = (
78 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
79 | )
80 | elif phase in ["val", "test"]: # validation
81 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
82 | else:
83 | raise ValueError(f"Wrong dataset phase: {phase}. " "Supported ones are 'train', 'val' and 'test'.")
84 |
85 | dataloader_args["pin_memory"] = dataset_opt.get("pin_memory", False)
86 |
87 | prefetch_mode = dataset_opt.get("prefetch_mode")
88 | if prefetch_mode == "cpu": # CPUPrefetcher
89 | num_prefetch_queue = dataset_opt.get("num_prefetch_queue", 1)
90 | logger = get_root_logger()
91 | logger.info(f"Use {prefetch_mode} prefetch dataloader: " f"num_prefetch_queue = {num_prefetch_queue}")
92 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
93 | else:
94 | # prefetch_mode=None: Normal dataloader
95 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
96 | return torch.utils.data.DataLoader(**dataloader_args)
97 |
98 |
99 | def worker_init_fn(worker_id, num_workers, rank, seed):
100 | # Set the worker seed to num_workers * rank + worker_id + seed
101 | worker_seed = num_workers * rank + worker_id + seed
102 | np.random.seed(worker_seed)
103 | random.seed(worker_seed)
104 |
--------------------------------------------------------------------------------
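
A minimal sketch of the non-distributed training branch of `build_dataloader`; the dataset and option values are placeholders. Batch size and worker count come from the options and are scaled by `num_gpu`:

```python
import torch
from torch.utils.data import TensorDataset

from codeformer.basicsr.data import build_dataloader

dataset = TensorDataset(torch.arange(16))
dataset_opt = {"phase": "train", "batch_size_per_gpu": 4, "num_worker_per_gpu": 0}
# Without an explicit sampler, shuffling is enabled; drop_last=True keeps batches full.
loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
for batch in loader:
    print(batch)
```
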
/codeformer/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 | class EnlargedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | Modified from torch.utils.data.distributed.DistributedSampler
11 | Support enlarging the dataset for iteration-based training, for saving
12 | time when restarting the dataloader after each epoch.
13 |
14 | Args:
15 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
16 | num_replicas (int | None): Number of processes participating in
17 | the training. It is usually the world_size.
18 | rank (int | None): Rank of the current process within num_replicas.
19 | ratio (int): Enlarging ratio. Default: 1.
20 | """
21 |
22 | def __init__(self, dataset, num_replicas, rank, ratio=1):
23 | self.dataset = dataset
24 | self.num_replicas = num_replicas
25 | self.rank = rank
26 | self.epoch = 0
27 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
28 | self.total_size = self.num_samples * self.num_replicas
29 |
30 | def __iter__(self):
31 | # deterministically shuffle based on epoch
32 | g = torch.Generator()
33 | g.manual_seed(self.epoch)
34 | indices = torch.randperm(self.total_size, generator=g).tolist()
35 |
36 | dataset_size = len(self.dataset)
37 | indices = [v % dataset_size for v in indices]
38 |
39 | # subsample
40 | indices = indices[self.rank : self.total_size : self.num_replicas]
41 | assert len(indices) == self.num_samples
42 |
43 | return iter(indices)
44 |
45 | def __len__(self):
46 | return self.num_samples
47 |
48 | def set_epoch(self, epoch):
49 | self.epoch = epoch
50 |
--------------------------------------------------------------------------------
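
A toy sketch of `EnlargedSampler`; the tiny dataset is illustrative. With `ratio=2` and two replicas, rank 0 draws a full enlarged epoch of indices that wrap around the real dataset:

```python
import torch
from torch.utils.data import TensorDataset

from codeformer.basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(10))
sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=2)
sampler.set_epoch(0)   # reseed the deterministic shuffle each epoch
print(len(sampler))    # 10 = ceil(10 * 2 / 2)
print(list(sampler))   # indices in [0, 10), possibly repeated
```
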
/codeformer/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class PrefetchGenerator(threading.Thread):
9 | """A general prefetch generator.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
13 |
14 | Args:
15 | generator: Python generator.
16 | num_prefetch_queue (int): Number of prefetch queue.
17 | """
18 |
19 | def __init__(self, generator, num_prefetch_queue):
20 | threading.Thread.__init__(self)
21 | self.queue = Queue.Queue(num_prefetch_queue)
22 | self.generator = generator
23 | self.daemon = True
24 | self.start()
25 |
26 | def run(self):
27 | for item in self.generator:
28 | self.queue.put(item)
29 | self.queue.put(None)
30 |
31 | def __next__(self):
32 | next_item = self.queue.get()
33 | if next_item is None:
34 | raise StopIteration
35 | return next_item
36 |
37 | def __iter__(self):
38 | return self
39 |
40 |
41 | class PrefetchDataLoader(DataLoader):
42 | """Prefetch version of dataloader.
43 |
44 | Ref:
45 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
46 |
47 | TODO:
48 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
49 | ddp.
50 |
51 | Args:
52 | num_prefetch_queue (int): Number of prefetch queue.
53 | kwargs (dict): Other arguments for dataloader.
54 | """
55 |
56 | def __init__(self, num_prefetch_queue, **kwargs):
57 | self.num_prefetch_queue = num_prefetch_queue
58 | super(PrefetchDataLoader, self).__init__(**kwargs)
59 |
60 | def __iter__(self):
61 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
62 |
63 |
64 | class CPUPrefetcher:
65 | """CPU prefetcher.
66 |
67 | Args:
68 | loader: Dataloader.
69 | """
70 |
71 | def __init__(self, loader):
72 | self.ori_loader = loader
73 | self.loader = iter(loader)
74 |
75 | def next(self):
76 | try:
77 | return next(self.loader)
78 | except StopIteration:
79 | return None
80 |
81 | def reset(self):
82 | self.loader = iter(self.ori_loader)
83 |
84 |
85 | class CUDAPrefetcher:
86 | """CUDA prefetcher.
87 |
88 | Ref:
89 | https://github.com/NVIDIA/apex/issues/304#
90 |
91 | It may consume more GPU memory.
92 |
93 | Args:
94 | loader: Dataloader.
95 | opt (dict): Options.
96 | """
97 |
98 | def __init__(self, loader, opt):
99 | self.ori_loader = loader
100 | self.loader = iter(loader)
101 | self.opt = opt
102 | self.stream = torch.cuda.Stream()
103 | self.device = torch.device("cuda" if opt["num_gpu"] != 0 else "cpu")
104 | self.preload()
105 |
106 | def preload(self):
107 | try:
108 | self.batch = next(self.loader) # self.batch is a dict
109 | except StopIteration:
110 | self.batch = None
111 | return None
112 | # put tensors to gpu
113 | with torch.cuda.stream(self.stream):
114 | for k, v in self.batch.items():
115 | if torch.is_tensor(v):
116 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
117 |
118 | def next(self):
119 | torch.cuda.current_stream().wait_stream(self.stream)
120 | batch = self.batch
121 | self.preload()
122 | return batch
123 |
124 | def reset(self):
125 | self.loader = iter(self.ori_loader)
126 | self.preload()
127 |
--------------------------------------------------------------------------------
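
A short sketch of the `CPUPrefetcher` loop with a placeholder dataset: `next()` returns `None` at the end of the epoch and `reset()` rewinds the underlying loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from codeformer.basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)
prefetcher = CPUPrefetcher(loader)

batch = prefetcher.next()
while batch is not None:
    print(batch)               # process the batch here
    batch = prefetcher.next()

prefetcher.reset()             # start the next epoch
```
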
/codeformer/basicsr/data/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 |
5 |
6 | def mod_crop(img, scale):
7 | """Mod crop images, used during testing.
8 |
9 | Args:
10 | img (ndarray): Input image.
11 | scale (int): Scale factor.
12 |
13 | Returns:
14 | ndarray: Result image.
15 | """
16 | img = img.copy()
17 | if img.ndim in (2, 3):
18 | h, w = img.shape[0], img.shape[1]
19 | h_remainder, w_remainder = h % scale, w % scale
20 | img = img[: h - h_remainder, : w - w_remainder, ...]
21 | else:
22 | raise ValueError(f"Wrong img ndim: {img.ndim}.")
23 | return img
24 |
25 |
26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
27 | """Paired random crop.
28 |
29 | It crops lists of lq and gt images with corresponding locations.
30 |
31 | Args:
32 | img_gts (list[ndarray] | ndarray): GT images. Note that all images
33 | should have the same shape. If the input is an ndarray, it will
34 | be transformed to a list containing itself.
35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
36 | should have the same shape. If the input is an ndarray, it will
37 | be transformed to a list containing itself.
38 | gt_patch_size (int): GT patch size.
39 | scale (int): Scale factor.
40 | gt_path (str): Path to ground-truth.
41 |
42 | Returns:
43 | list[ndarray] | ndarray: GT images and LQ images. If returned results
44 | only have one element, just return ndarray.
45 | """
46 |
47 | if not isinstance(img_gts, list):
48 | img_gts = [img_gts]
49 | if not isinstance(img_lqs, list):
50 | img_lqs = [img_lqs]
51 |
52 | h_lq, w_lq, _ = img_lqs[0].shape
53 | h_gt, w_gt, _ = img_gts[0].shape
54 | lq_patch_size = gt_patch_size // scale
55 |
56 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
57 | raise ValueError(
58 | f"Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x " f"multiplication of LQ ({h_lq}, {w_lq})."
59 | )
60 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
61 | raise ValueError(
62 | f"LQ ({h_lq}, {w_lq}) is smaller than patch size "
63 | f"({lq_patch_size}, {lq_patch_size}). "
64 | f"Please remove {gt_path}."
65 | )
66 |
67 | # randomly choose top and left coordinates for lq patch
68 | top = random.randint(0, h_lq - lq_patch_size)
69 | left = random.randint(0, w_lq - lq_patch_size)
70 |
71 | # crop lq patch
72 | img_lqs = [v[top : top + lq_patch_size, left : left + lq_patch_size, ...] for v in img_lqs]
73 |
74 | # crop corresponding gt patch
75 | top_gt, left_gt = int(top * scale), int(left * scale)
76 | img_gts = [v[top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size, ...] for v in img_gts]
77 | if len(img_gts) == 1:
78 | img_gts = img_gts[0]
79 | if len(img_lqs) == 1:
80 | img_lqs = img_lqs[0]
81 | return img_gts, img_lqs
82 |
83 |
84 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
85 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
86 |
87 | We use vertical flip and transpose for rotation implementation.
88 | All the images in the list use the same augmentation.
89 |
90 | Args:
91 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
92 | is an ndarray, it will be transformed to a list.
93 | hflip (bool): Horizontal flip. Default: True.
94 | rotation (bool): Rotation. Default: True.
95 | flows (list[ndarray]): Flows to be augmented. If the input is an
96 | ndarray, it will be transformed to a list.
97 | Dimension is (h, w, 2). Default: None.
98 | return_status (bool): Return the status of flip and rotation.
99 | Default: False.
100 |
101 | Returns:
102 | list[ndarray] | ndarray: Augmented images and flows. If returned
103 | results only have one element, just return ndarray.
104 |
105 | """
106 | hflip = hflip and random.random() < 0.5
107 | vflip = rotation and random.random() < 0.5
108 | rot90 = rotation and random.random() < 0.5
109 |
110 | def _augment(img):
111 | if hflip: # horizontal
112 | cv2.flip(img, 1, img)
113 | if vflip: # vertical
114 | cv2.flip(img, 0, img)
115 | if rot90:
116 | img = img.transpose(1, 0, 2)
117 | return img
118 |
119 | def _augment_flow(flow):
120 | if hflip: # horizontal
121 | cv2.flip(flow, 1, flow)
122 | flow[:, :, 0] *= -1
123 | if vflip: # vertical
124 | cv2.flip(flow, 0, flow)
125 | flow[:, :, 1] *= -1
126 | if rot90:
127 | flow = flow.transpose(1, 0, 2)
128 | flow = flow[:, :, [1, 0]]
129 | return flow
130 |
131 | if not isinstance(imgs, list):
132 | imgs = [imgs]
133 | imgs = [_augment(img) for img in imgs]
134 | if len(imgs) == 1:
135 | imgs = imgs[0]
136 |
137 | if flows is not None:
138 | if not isinstance(flows, list):
139 | flows = [flows]
140 | flows = [_augment_flow(flow) for flow in flows]
141 | if len(flows) == 1:
142 | flows = flows[0]
143 | return imgs, flows
144 | else:
145 | if return_status:
146 | return imgs, (hflip, vflip, rot90)
147 | else:
148 | return imgs
149 |
150 |
151 | def img_rotate(img, angle, center=None, scale=1.0):
152 | """Rotate image.
153 |
154 | Args:
155 | img (ndarray): Image to be rotated.
156 | angle (float): Rotation angle in degrees. Positive values mean
157 | counter-clockwise rotation.
158 | center (tuple[int]): Rotation center. If the center is None,
159 | initialize it as the center of the image. Default: None.
160 | scale (float): Isotropic scale factor. Default: 1.0.
161 | """
162 | (h, w) = img.shape[:2]
163 |
164 | if center is None:
165 | center = (w // 2, h // 2)
166 |
167 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
168 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
169 | return rotated_img
170 |
--------------------------------------------------------------------------------
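
A quick sketch of the paired crop plus augmentation flow used to prepare training pairs; random arrays stand in for real GT/LQ images:

```python
import numpy as np

from codeformer.basicsr.data.transforms import augment, paired_random_crop

gt = np.random.rand(64, 64, 3).astype(np.float32)   # ground truth, 4x the LQ size
lq = np.random.rand(16, 16, 3).astype(np.float32)   # low-quality input

# Crop a 32x32 GT patch and the aligned 8x8 LQ patch (scale 4), then apply the same
# random flip/rotation to both.
gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=32, scale=4, gt_path="demo")
gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
print(gt_patch.shape, lq_patch.shape)  # (32, 32, 3) (8, 8, 3)
```
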
/codeformer/basicsr/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | from codeformer.basicsr.losses.losses import (
4 | CharbonnierLoss,
5 | GANLoss,
6 | L1Loss,
7 | MSELoss,
8 | PerceptualLoss,
9 | WeightedTVLoss,
10 | g_path_regularize,
11 | gradient_penalty_loss,
12 | r1_penalty,
13 | )
14 | from codeformer.basicsr.utils import get_root_logger
15 | from codeformer.basicsr.utils.registry import LOSS_REGISTRY
16 |
17 | __all__ = [
18 | "L1Loss",
19 | "MSELoss",
20 | "CharbonnierLoss",
21 | "WeightedTVLoss",
22 | "PerceptualLoss",
23 | "GANLoss",
24 | "gradient_penalty_loss",
25 | "r1_penalty",
26 | "g_path_regularize",
27 | ]
28 |
29 |
30 | def build_loss(opt):
31 | """Build loss from options.
32 |
33 | Args:
34 | opt (dict): Configuration. It must contain:
35 | type (str): Loss type.
36 | """
37 | opt = deepcopy(opt)
38 | loss_type = opt.pop("type")
39 | loss = LOSS_REGISTRY.get(loss_type)(**opt)
40 | logger = get_root_logger()
41 | logger.info(f"Loss [{loss.__class__.__name__}] is created.")
42 | return loss
43 |
--------------------------------------------------------------------------------
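
A minimal sketch of `build_loss` with an illustrative config; the `loss_weight` and `reduction` kwargs follow the usual BasicSR `L1Loss` signature, which is assumed here since losses.py is not reproduced above:

```python
import torch

from codeformer.basicsr.losses import build_loss

# 'type' selects the registered loss class; the rest become constructor kwargs.
loss_fn = build_loss({"type": "L1Loss", "loss_weight": 1.0, "reduction": "mean"})
print(loss_fn(torch.zeros(2, 3), torch.ones(2, 3)))  # tensor(1.)
```
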
/codeformer/basicsr/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from torch.nn import functional as F
4 |
5 |
6 | def reduce_loss(loss, reduction):
7 | """Reduce loss as specified.
8 |
9 | Args:
10 | loss (Tensor): Elementwise loss tensor.
11 | reduction (str): Options are 'none', 'mean' and 'sum'.
12 |
13 | Returns:
14 | Tensor: Reduced loss tensor.
15 | """
16 | reduction_enum = F._Reduction.get_enum(reduction)
17 | # none: 0, elementwise_mean:1, sum: 2
18 | if reduction_enum == 0:
19 | return loss
20 | elif reduction_enum == 1:
21 | return loss.mean()
22 | else:
23 | return loss.sum()
24 |
25 |
26 | def weight_reduce_loss(loss, weight=None, reduction="mean"):
27 | """Apply element-wise weight and reduce loss.
28 |
29 | Args:
30 | loss (Tensor): Element-wise loss.
31 | weight (Tensor): Element-wise weights. Default: None.
32 | reduction (str): Same as built-in losses of PyTorch. Options are
33 | 'none', 'mean' and 'sum'. Default: 'mean'.
34 |
35 | Returns:
36 | Tensor: Loss values.
37 | """
38 | # if weight is specified, apply element-wise weight
39 | if weight is not None:
40 | assert weight.dim() == loss.dim()
41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
42 | loss = loss * weight
43 |
44 | # if weight is not specified or reduction is sum, just reduce the loss
45 | if weight is None or reduction == "sum":
46 | loss = reduce_loss(loss, reduction)
47 | # if reduction is mean, then compute mean over weight region
48 | elif reduction == "mean":
49 | if weight.size(1) > 1:
50 | weight = weight.sum()
51 | else:
52 | weight = weight.sum() * loss.size(1)
53 | loss = loss.sum() / weight
54 |
55 | return loss
56 |
57 |
58 | def weighted_loss(loss_func):
59 | """Create a weighted version of a given loss function.
60 |
61 | To use this decorator, the loss function must have the signature like
62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
63 | element-wise loss without any reduction. This decorator will add weight
64 | and reduction arguments to the function. The decorated function will have
65 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
66 | **kwargs)`.
67 |
68 | :Example:
69 |
70 | >>> import torch
71 | >>> @weighted_loss
72 | >>> def l1_loss(pred, target):
73 | >>> return (pred - target).abs()
74 |
75 | >>> pred = torch.Tensor([0, 2, 3])
76 | >>> target = torch.Tensor([1, 1, 1])
77 | >>> weight = torch.Tensor([1, 0, 1])
78 |
79 | >>> l1_loss(pred, target)
80 | tensor(1.3333)
81 | >>> l1_loss(pred, target, weight)
82 | tensor(1.5000)
83 | >>> l1_loss(pred, target, reduction='none')
84 | tensor([1., 1., 2.])
85 | >>> l1_loss(pred, target, weight, reduction='sum')
86 | tensor(3.)
87 | """
88 |
89 | @functools.wraps(loss_func)
90 | def wrapper(pred, target, weight=None, reduction="mean", **kwargs):
91 | # get element-wise loss
92 | loss = loss_func(pred, target, **kwargs)
93 | loss = weight_reduce_loss(loss, weight, reduction)
94 | return loss
95 |
96 | return wrapper
97 |
--------------------------------------------------------------------------------
/codeformer/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | from codeformer.basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim
4 | from codeformer.basicsr.utils.registry import METRIC_REGISTRY
5 |
6 | __all__ = ["calculate_psnr", "calculate_ssim"]
7 |
8 |
9 | def calculate_metric(data, opt):
10 | """Calculate metric from data and options.
11 |
12 | Args:
13 | opt (dict): Configuration. It must contain:
14 | type (str): Metric type.
15 | """
16 | opt = deepcopy(opt)
17 | metric_type = opt.pop("type")
18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
19 | return metric
20 |
--------------------------------------------------------------------------------
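
A sketch of the registry-driven entry point with placeholder arrays: `data` supplies the images and `opt` selects the metric plus its keyword arguments.

```python
import numpy as np

from codeformer.basicsr.metrics import calculate_metric

img = (np.random.rand(64, 64, 3) * 255).astype(np.float64)
data = {"img1": img, "img2": img}
# PSNR of an image against itself is infinite.
print(calculate_metric(data, {"type": "calculate_psnr", "crop_border": 0}))  # inf
```
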
/codeformer/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from codeformer.basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order="HWC"):
7 | """Reorder images to 'HWC' order.
8 |
9 | If the input_order is (h, w), return (h, w, 1);
10 | If the input_order is (c, h, w), return (h, w, c);
11 | If the input_order is (h, w, c), return as it is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ["HWC", "CHW"]:
24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " "'HWC' and 'CHW'")
25 | if len(img.shape) == 2:
26 | img = img[..., None]
27 | if input_order == "CHW":
28 | img = img.transpose(1, 2, 0)
29 | return img
30 |
31 |
32 | def to_y_channel(img):
33 | """Change to Y channel of YCbCr.
34 |
35 | Args:
36 | img (ndarray): Images with range [0, 255].
37 |
38 | Returns:
39 | (ndarray): Images with range [0, 255] (float type) without round.
40 | """
41 | img = img.astype(np.float32) / 255.0
42 | if img.ndim == 3 and img.shape[2] == 3:
43 | img = bgr2ycbcr(img, y_only=True)
44 | img = img[..., None]
45 | return img * 255.0
46 |
--------------------------------------------------------------------------------
/codeformer/basicsr/metrics/psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from codeformer.basicsr.metrics.metric_util import reorder_image, to_y_channel
5 | from codeformer.basicsr.utils.registry import METRIC_REGISTRY
6 |
7 |
8 | @METRIC_REGISTRY.register()
9 | def calculate_psnr(img1, img2, crop_border, input_order="HWC", test_y_channel=False):
10 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
11 |
12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13 |
14 | Args:
15 | img1 (ndarray): Images with range [0, 255].
16 | img2 (ndarray): Images with range [0, 255].
17 | crop_border (int): Cropped pixels in each edge of an image. These
18 | pixels are not involved in the PSNR calculation.
19 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
20 | Default: 'HWC'.
21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22 |
23 | Returns:
24 | float: psnr result.
25 | """
26 |
27 | assert img1.shape == img2.shape, f"Image shapes are different: {img1.shape}, {img2.shape}."
28 | if input_order not in ["HWC", "CHW"]:
29 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"')
30 | img1 = reorder_image(img1, input_order=input_order)
31 | img2 = reorder_image(img2, input_order=input_order)
32 | img1 = img1.astype(np.float64)
33 | img2 = img2.astype(np.float64)
34 |
35 | if crop_border != 0:
36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38 |
39 | if test_y_channel:
40 | img1 = to_y_channel(img1)
41 | img2 = to_y_channel(img2)
42 |
43 | mse = np.mean((img1 - img2) ** 2)
44 | if mse == 0:
45 | return float("inf")
46 | return 20.0 * np.log10(255.0 / np.sqrt(mse))
47 |
48 |
49 | def _ssim(img1, img2):
50 | """Calculate SSIM (structural similarity) for one channel images.
51 |
52 | It is called by func:`calculate_ssim`.
53 |
54 | Args:
55 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
56 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57 |
58 | Returns:
59 | float: ssim result.
60 | """
61 |
62 | C1 = (0.01 * 255) ** 2
63 | C2 = (0.03 * 255) ** 2
64 |
65 | img1 = img1.astype(np.float64)
66 | img2 = img2.astype(np.float64)
67 | kernel = cv2.getGaussianKernel(11, 1.5)
68 | window = np.outer(kernel, kernel.transpose())
69 |
70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72 | mu1_sq = mu1 ** 2
73 | mu2_sq = mu2 ** 2
74 | mu1_mu2 = mu1 * mu2
75 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
76 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78 |
79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
80 | return ssim_map.mean()
81 |
82 |
83 | @METRIC_REGISTRY.register()
84 | def calculate_ssim(img1, img2, crop_border, input_order="HWC", test_y_channel=False):
85 | """Calculate SSIM (structural similarity).
86 |
87 | Ref:
88 | Image quality assessment: From error visibility to structural similarity
89 |
90 | The results are the same as that of the official released MATLAB code in
91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92 |
93 | For three-channel images, SSIM is calculated for each channel and then
94 | averaged.
95 |
96 | Args:
97 | img1 (ndarray): Images with range [0, 255].
98 | img2 (ndarray): Images with range [0, 255].
99 | crop_border (int): Cropped pixels in each edge of an image. These
100 | pixels are not involved in the SSIM calculation.
101 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
102 | Default: 'HWC'.
103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104 |
105 | Returns:
106 | float: ssim result.
107 | """
108 |
109 | assert img1.shape == img2.shape, f"Image shapes are different: {img1.shape}, {img2.shape}."
110 | if input_order not in ["HWC", "CHW"]:
111 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"')
112 | img1 = reorder_image(img1, input_order=input_order)
113 | img2 = reorder_image(img2, input_order=input_order)
114 | img1 = img1.astype(np.float64)
115 | img2 = img2.astype(np.float64)
116 |
117 | if crop_border != 0:
118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120 |
121 | if test_y_channel:
122 | img1 = to_y_channel(img1)
123 | img2 = to_y_channel(img2)
124 |
125 | ssims = []
126 | for i in range(img1.shape[2]):
127 | ssims.append(_ssim(img1[..., i], img2[..., i]))
128 | return np.array(ssims).mean()
129 |
--------------------------------------------------------------------------------
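
Calling the metrics directly, with synthetic data standing in for restored/reference images:

```python
import numpy as np

from codeformer.basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

gt = (np.random.rand(128, 128, 3) * 255).astype(np.float64)
noisy = np.clip(gt + np.random.randn(*gt.shape) * 5, 0, 255)

# crop_border trims the image edges before scoring; test_y_channel scores on the Y channel only.
print(calculate_psnr(gt, noisy, crop_border=4, test_y_channel=True))
print(calculate_ssim(gt, noisy, crop_border=4))
```
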
/codeformer/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from codeformer.basicsr.utils import get_root_logger, scandir
6 | from codeformer.basicsr.utils.registry import MODEL_REGISTRY
7 |
8 | __all__ = ["build_model"]
9 |
10 | # automatically scan and import model modules for registry
11 | # scan all the files under the 'models' folder and collect files ending with
12 | # '_model.py'
13 | model_folder = osp.dirname(osp.abspath(__file__))
14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith("_model.py")]
15 | # import all the model modules
16 | _model_modules = [importlib.import_module(f"codeformer.basicsr.models.{file_name}") for file_name in model_filenames]
17 |
18 |
19 | def build_model(opt):
20 | """Build model from options.
21 |
22 | Args:
23 | opt (dict): Configuration. It must contain:
24 | model_type (str): Model type.
25 | """
26 | opt = deepcopy(opt)
27 | model = MODEL_REGISTRY.get(opt["model_type"])(opt)
28 | logger = get_root_logger()
29 | logger.info(f"Model [{model.__class__.__name__}] is created.")
30 | return model
31 |
--------------------------------------------------------------------------------
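build_model only resolves a name once some *_model.py module has registered a class with @MODEL_REGISTRY.register(); the sketch below assumes a hypothetical registered name "SRModel".

    from codeformer.basicsr.models import build_model

    opt = {
        "model_type": "SRModel",  # hypothetical registered model name
        "is_train": False,
        # ... whatever options the chosen model class expects ...
    }
    model = build_model(opt)      # looked up in MODEL_REGISTRY and instantiated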
/codeformer/basicsr/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/basicsr/ops/__init__.py
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from codeformer.basicsr.ops.dcn.deform_conv import (
2 | DeformConv,
3 | DeformConvPack,
4 | ModulatedDeformConv,
5 | ModulatedDeformConvPack,
6 | deform_conv,
7 | modulated_deform_conv,
8 | )
9 |
10 | __all__ = [
11 | "DeformConv",
12 | "DeformConvPack",
13 | "ModulatedDeformConv",
14 | "ModulatedDeformConvPack",
15 | "deform_conv",
16 | "modulated_deform_conv",
17 | ]
18 |
--------------------------------------------------------------------------------
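For orientation, DeformConvPack predicts its own offsets from the input and otherwise behaves like a regular convolution. A minimal sketch, assuming the deform_conv_ext CUDA extension has been built and a GPU is available:

    import torch
    from codeformer.basicsr.ops.dcn import DeformConvPack

    conv = DeformConvPack(64, 64, kernel_size=3, stride=1, padding=1).cuda()
    x = torch.randn(1, 64, 32, 32, device="cuda")
    y = conv(x)  # same spatial size as the input for this stride/padding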
/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp:
--------------------------------------------------------------------------------
1 | // modify from
2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3 |
4 | #include <torch/extension.h>
5 | #include <ATen/DeviceGuard.h>
6 |
7 | #include <cmath>
8 | #include <vector>
9 |
10 | #define WITH_CUDA // always use cuda
11 | #ifdef WITH_CUDA
12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
13 | at::Tensor offset, at::Tensor output,
14 | at::Tensor columns, at::Tensor ones, int kW,
15 | int kH, int dW, int dH, int padW, int padH,
16 | int dilationW, int dilationH, int group,
17 | int deformable_group, int im2col_step);
18 |
19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
20 | at::Tensor gradOutput, at::Tensor gradInput,
21 | at::Tensor gradOffset, at::Tensor weight,
22 | at::Tensor columns, int kW, int kH, int dW,
23 | int dH, int padW, int padH, int dilationW,
24 | int dilationH, int group,
25 | int deformable_group, int im2col_step);
26 |
27 | int deform_conv_backward_parameters_cuda(
28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
29 | at::Tensor gradWeight, // at::Tensor gradBias,
30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
31 | int padW, int padH, int dilationW, int dilationH, int group,
32 | int deformable_group, float scale, int im2col_step);
33 |
34 | void modulated_deform_conv_cuda_forward(
35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w,
38 | const int pad_h, const int pad_w, const int dilation_h,
39 | const int dilation_w, const int group, const int deformable_group,
40 | const bool with_bias);
41 |
42 | void modulated_deform_conv_cuda_backward(
43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
44 | at::Tensor offset, at::Tensor mask, at::Tensor columns,
45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
48 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
49 | const bool with_bias);
50 | #endif
51 |
52 | int deform_conv_forward(at::Tensor input, at::Tensor weight,
53 | at::Tensor offset, at::Tensor output,
54 | at::Tensor columns, at::Tensor ones, int kW,
55 | int kH, int dW, int dH, int padW, int padH,
56 | int dilationW, int dilationH, int group,
57 | int deformable_group, int im2col_step) {
58 | if (input.device().is_cuda()) {
59 | #ifdef WITH_CUDA
60 | return deform_conv_forward_cuda(input, weight, offset, output, columns,
61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
62 | deformable_group, im2col_step);
63 | #else
64 | AT_ERROR("deform conv is not compiled with GPU support");
65 | #endif
66 | }
67 | AT_ERROR("deform conv is not implemented on CPU");
68 | }
69 |
70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
71 | at::Tensor gradOutput, at::Tensor gradInput,
72 | at::Tensor gradOffset, at::Tensor weight,
73 | at::Tensor columns, int kW, int kH, int dW,
74 | int dH, int padW, int padH, int dilationW,
75 | int dilationH, int group,
76 | int deformable_group, int im2col_step) {
77 | if (input.device().is_cuda()) {
78 | #ifdef WITH_CUDA
79 | return deform_conv_backward_input_cuda(input, offset, gradOutput,
80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
81 | dilationW, dilationH, group, deformable_group, im2col_step);
82 | #else
83 | AT_ERROR("deform conv is not compiled with GPU support");
84 | #endif
85 | }
86 | AT_ERROR("deform conv is not implemented on CPU");
87 | }
88 |
89 | int deform_conv_backward_parameters(
90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
91 | at::Tensor gradWeight, // at::Tensor gradBias,
92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
93 | int padW, int padH, int dilationW, int dilationH, int group,
94 | int deformable_group, float scale, int im2col_step) {
95 | if (input.device().is_cuda()) {
96 | #ifdef WITH_CUDA
97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
99 | dilationH, group, deformable_group, scale, im2col_step);
100 | #else
101 | AT_ERROR("deform conv is not compiled with GPU support");
102 | #endif
103 | }
104 | AT_ERROR("deform conv is not implemented on CPU");
105 | }
106 |
107 | void modulated_deform_conv_forward(
108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w,
111 | const int pad_h, const int pad_w, const int dilation_h,
112 | const int dilation_w, const int group, const int deformable_group,
113 | const bool with_bias) {
114 | if (input.device().is_cuda()) {
115 | #ifdef WITH_CUDA
116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h,
118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
119 | deformable_group, with_bias);
120 | #else
121 | AT_ERROR("modulated deform conv is not compiled with GPU support");
122 | #endif
123 | }
124 | AT_ERROR("modulated deform conv is not implemented on CPU");
125 | }
126 |
127 | void modulated_deform_conv_backward(
128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
129 | at::Tensor offset, at::Tensor mask, at::Tensor columns,
130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
131 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
134 | const bool with_bias) {
135 | if (input.device().is_cuda()) {
136 | #ifdef WITH_CUDA
137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
141 | with_bias);
142 | #else
143 | AT_ERROR("modulated deform conv is not compiled with GPU support");
144 | #endif
145 | }
146 | AT_ERROR("modulated deform conv is not implemented on CPU");
147 | }
148 |
149 |
150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
151 | m.def("deform_conv_forward", &deform_conv_forward,
152 | "deform forward");
153 | m.def("deform_conv_backward_input", &deform_conv_backward_input,
154 | "deform_conv_backward_input");
155 | m.def("deform_conv_backward_parameters",
156 | &deform_conv_backward_parameters,
157 | "deform_conv_backward_parameters");
158 | m.def("modulated_deform_conv_forward",
159 | &modulated_deform_conv_forward,
160 | "modulated deform conv forward");
161 | m.def("modulated_deform_conv_backward",
162 | &modulated_deform_conv_backward,
163 | "modulated deform conv backward");
164 | }
165 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/fused_act/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 |
3 | __all__ = ["FusedLeakyReLU", "fused_leaky_relu"]
4 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/fused_act/fused_act.py:
--------------------------------------------------------------------------------
1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 |
7 | try:
8 | from . import fused_act_ext
9 | except ImportError:
10 | import os
11 |
12 | BASICSR_JIT = os.getenv("BASICSR_JIT")
13 | if BASICSR_JIT == "True":
14 | from torch.utils.cpp_extension import load
15 |
16 | module_path = os.path.dirname(__file__)
17 | fused_act_ext = load(
18 | "fused",
19 | sources=[
20 | os.path.join(module_path, "src", "fused_bias_act.cpp"),
21 | os.path.join(module_path, "src", "fused_bias_act_kernel.cu"),
22 | ],
23 | )
24 |
25 |
26 | class FusedLeakyReLUFunctionBackward(Function):
27 | @staticmethod
28 | def forward(ctx, grad_output, out, negative_slope, scale):
29 | ctx.save_for_backward(out)
30 | ctx.negative_slope = negative_slope
31 | ctx.scale = scale
32 |
33 | empty = grad_output.new_empty(0)
34 |
35 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
36 |
37 | dim = [0]
38 |
39 | if grad_input.ndim > 2:
40 | dim += list(range(2, grad_input.ndim))
41 |
42 | grad_bias = grad_input.sum(dim).detach()
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | (out,) = ctx.saved_tensors
49 | gradgrad_out = fused_act_ext.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
61 | ctx.save_for_backward(out)
62 | ctx.negative_slope = negative_slope
63 | ctx.scale = scale
64 |
65 | return out
66 |
67 | @staticmethod
68 | def backward(ctx, grad_output):
69 | (out,) = ctx.saved_tensors
70 |
71 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
72 |
73 | return grad_input, grad_bias, None, None
74 |
75 |
76 | class FusedLeakyReLU(nn.Module):
77 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
78 | super().__init__()
79 |
80 | self.bias = nn.Parameter(torch.zeros(channel))
81 | self.negative_slope = negative_slope
82 | self.scale = scale
83 |
84 | def forward(self, input):
85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
86 |
87 |
88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
90 |
--------------------------------------------------------------------------------
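A short sketch of both entry points exported above, assuming the fused_act_ext CUDA extension is importable (pre-built, or JIT-built with BASICSR_JIT="True"): FusedLeakyReLU adds a learnable per-channel bias, applies a leaky ReLU with slope 0.2, and rescales by sqrt(2).

    import torch
    from codeformer.basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

    act = FusedLeakyReLU(channel=64).cuda()
    x = torch.randn(8, 64, 16, 16, device="cuda")
    y = act(x)

    # functional form with an explicit bias tensor
    bias = torch.zeros(64, device="cuda")
    y2 = fused_leaky_relu(x, bias)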
/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
2 | #include <torch/extension.h>
3 |
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input,
6 | const torch::Tensor& bias,
7 | const torch::Tensor& refer,
8 | int act, int grad, float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13 |
14 | torch::Tensor fused_bias_act(const torch::Tensor& input,
15 | const torch::Tensor& bias,
16 | const torch::Tensor& refer,
17 | int act, int grad, float alpha, float scale) {
18 | CHECK_CUDA(input);
19 | CHECK_CUDA(bias);
20 |
21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
22 | }
23 |
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
26 | }
27 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
3 | //
4 | // This work is made available under the Nvidia Source Code License-NC.
5 | // To view a copy of this license, visit
6 | // https://nvlabs.github.io/stylegan2/license.html
7 |
8 | #include <torch/types.h>
9 |
10 | #include <ATen/ATen.h>
11 | #include <ATen/AccumulateType.h>
12 | #include <ATen/cuda/CUDAContext.h>
13 | #include <ATen/cuda/CUDAApplyUtils.cuh>
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 |
19 | template <typename scalar_t>
20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
23 |
24 | scalar_t zero = 0.0;
25 |
26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
27 | scalar_t x = p_x[xi];
28 |
29 | if (use_bias) {
30 | x += p_b[(xi / step_b) % size_b];
31 | }
32 |
33 | scalar_t ref = use_ref ? p_ref[xi] : zero;
34 |
35 | scalar_t y;
36 |
37 | switch (act * 10 + grad) {
38 | default:
39 | case 10: y = x; break;
40 | case 11: y = x; break;
41 | case 12: y = 0.0; break;
42 |
43 | case 30: y = (x > 0.0) ? x : x * alpha; break;
44 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
45 | case 32: y = 0.0; break;
46 | }
47 |
48 | out[xi] = y * scale;
49 | }
50 | }
51 |
52 |
53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
54 | int act, int grad, float alpha, float scale) {
55 | int curDevice = -1;
56 | cudaGetDevice(&curDevice);
57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
58 |
59 | auto x = input.contiguous();
60 | auto b = bias.contiguous();
61 | auto ref = refer.contiguous();
62 |
63 | int use_bias = b.numel() ? 1 : 0;
64 | int use_ref = ref.numel() ? 1 : 0;
65 |
66 | int size_x = x.numel();
67 | int size_b = b.numel();
68 | int step_b = 1;
69 |
70 | for (int i = 1 + 1; i < x.dim(); i++) {
71 | step_b *= x.size(i);
72 | }
73 |
74 | int loop_x = 4;
75 | int block_size = 4 * 32;
76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
77 |
78 | auto y = torch::empty_like(x);
79 |
80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
81 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
82 | y.data_ptr<scalar_t>(),
83 | x.data_ptr<scalar_t>(),
84 | b.data_ptr<scalar_t>(),
85 | ref.data_ptr<scalar_t>(),
86 | act,
87 | grad,
88 | alpha,
89 | scale,
90 | loop_x,
91 | size_x,
92 | step_b,
93 | size_b,
94 | use_bias,
95 | use_ref
96 | );
97 | });
98 |
99 | return y;
100 | }
101 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/upfirdn2d/__init__.py:
--------------------------------------------------------------------------------
1 | from .upfirdn2d import upfirdn2d
2 |
3 | __all__ = ["upfirdn2d"]
4 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
2 | #include <torch/extension.h>
3 |
4 |
5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
6 | int up_x, int up_y, int down_x, int down_y,
7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
12 |
13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
14 | int up_x, int up_y, int down_x, int down_y,
15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
16 | CHECK_CUDA(input);
17 | CHECK_CUDA(kernel);
18 |
19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
20 | }
21 |
22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
24 | }
25 |
--------------------------------------------------------------------------------
/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
2 |
3 | import torch
4 | from torch.autograd import Function
5 | from torch.nn import functional as F
6 |
7 | try:
8 | from . import upfirdn2d_ext
9 | except ImportError:
10 | import os
11 |
12 | BASICSR_JIT = os.getenv("BASICSR_JIT")
13 | if BASICSR_JIT == "True":
14 | from torch.utils.cpp_extension import load
15 |
16 | module_path = os.path.dirname(__file__)
17 | upfirdn2d_ext = load(
18 | "upfirdn2d",
19 | sources=[
20 | os.path.join(module_path, "src", "upfirdn2d.cpp"),
21 | os.path.join(module_path, "src", "upfirdn2d_kernel.cu"),
22 | ],
23 | )
24 |
25 |
26 | class UpFirDn2dBackward(Function):
27 | @staticmethod
28 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
29 |
30 | up_x, up_y = up
31 | down_x, down_y = down
32 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
33 |
34 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
35 |
36 | grad_input = upfirdn2d_ext.upfirdn2d(
37 | grad_output,
38 | grad_kernel,
39 | down_x,
40 | down_y,
41 | up_x,
42 | up_y,
43 | g_pad_x0,
44 | g_pad_x1,
45 | g_pad_y0,
46 | g_pad_y1,
47 | )
48 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
49 |
50 | ctx.save_for_backward(kernel)
51 |
52 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
53 |
54 | ctx.up_x = up_x
55 | ctx.up_y = up_y
56 | ctx.down_x = down_x
57 | ctx.down_y = down_y
58 | ctx.pad_x0 = pad_x0
59 | ctx.pad_x1 = pad_x1
60 | ctx.pad_y0 = pad_y0
61 | ctx.pad_y1 = pad_y1
62 | ctx.in_size = in_size
63 | ctx.out_size = out_size
64 |
65 | return grad_input
66 |
67 | @staticmethod
68 | def backward(ctx, gradgrad_input):
69 | (kernel,) = ctx.saved_tensors
70 |
71 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
72 |
73 | gradgrad_out = upfirdn2d_ext.upfirdn2d(
74 | gradgrad_input,
75 | kernel,
76 | ctx.up_x,
77 | ctx.up_y,
78 | ctx.down_x,
79 | ctx.down_y,
80 | ctx.pad_x0,
81 | ctx.pad_x1,
82 | ctx.pad_y0,
83 | ctx.pad_y1,
84 | )
85 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
86 | # ctx.out_size[1], ctx.in_size[3])
87 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
88 |
89 | return gradgrad_out, None, None, None, None, None, None, None, None
90 |
91 |
92 | class UpFirDn2d(Function):
93 | @staticmethod
94 | def forward(ctx, input, kernel, up, down, pad):
95 | up_x, up_y = up
96 | down_x, down_y = down
97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
98 |
99 | kernel_h, kernel_w = kernel.shape
100 | batch, channel, in_h, in_w = input.shape
101 | ctx.in_size = input.shape
102 |
103 | input = input.reshape(-1, in_h, in_w, 1)
104 |
105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
106 |
107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
109 | ctx.out_size = (out_h, out_w)
110 |
111 | ctx.up = (up_x, up_y)
112 | ctx.down = (down_x, down_y)
113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
114 |
115 | g_pad_x0 = kernel_w - pad_x0 - 1
116 | g_pad_y0 = kernel_h - pad_y0 - 1
117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
119 |
120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
121 |
122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
123 | # out = out.view(major, out_h, out_w, minor)
124 | out = out.view(-1, channel, out_h, out_w)
125 |
126 | return out
127 |
128 | @staticmethod
129 | def backward(ctx, grad_output):
130 | kernel, grad_kernel = ctx.saved_tensors
131 |
132 | grad_input = UpFirDn2dBackward.apply(
133 | grad_output,
134 | kernel,
135 | grad_kernel,
136 | ctx.up,
137 | ctx.down,
138 | ctx.pad,
139 | ctx.g_pad,
140 | ctx.in_size,
141 | ctx.out_size,
142 | )
143 |
144 | return grad_input, None, None, None, None
145 |
146 |
147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
148 | if input.device.type == "cpu":
149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
150 | else:
151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
152 |
153 | return out
154 |
155 |
156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
157 | _, channel, in_h, in_w = input.shape
158 | input = input.reshape(-1, in_h, in_w, 1)
159 |
160 | _, in_h, in_w, minor = input.shape
161 | kernel_h, kernel_w = kernel.shape
162 |
163 | out = input.view(-1, in_h, 1, in_w, 1, minor)
164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
166 |
167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
168 | out = out[
169 | :,
170 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
171 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
172 | :,
173 | ]
174 |
175 | out = out.permute(0, 3, 1, 2)
176 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
177 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
178 | out = F.conv2d(out, w)
179 | out = out.reshape(
180 | -1,
181 | minor,
182 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
183 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
184 | )
185 | out = out.permute(0, 2, 3, 1)
186 | out = out[:, ::down_y, ::down_x, :]
187 |
188 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
189 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
190 |
191 | return out.view(-1, channel, out_h, out_w)
192 |
--------------------------------------------------------------------------------
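upfirdn2d performs upsample, FIR filter, and downsample in one pass; on CPU it falls back to upfirdn2d_native, on GPU it needs the upfirdn2d_ext extension. A minimal sketch with an illustrative separable blur kernel (the "* 4" gain for 2x upsampling follows the StyleGAN2 convention):

    import torch
    from codeformer.basicsr.ops.upfirdn2d import upfirdn2d

    k = torch.tensor([1.0, 3.0, 3.0, 1.0])
    kernel = torch.outer(k, k)
    kernel = kernel / kernel.sum()  # normalized 4x4 blur kernel

    x = torch.randn(1, 3, 64, 64)
    out = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))  # -> (1, 3, 128, 128)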
/codeformer/basicsr/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import subprocess
5 | import sys
6 | import time
7 |
8 | from setuptools import find_packages, setup
9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
10 | from utils.misc import gpu_is_available
11 |
12 | version_file = ".codeformer/basicsr/version.py"
13 |
14 |
15 | def readme():
16 | with open("README.md", encoding="utf-8") as f:
17 | content = f.read()
18 | return content
19 |
20 |
21 | def get_git_hash():
22 | def _minimal_ext_cmd(cmd):
23 | # construct minimal environment
24 | env = {}
25 | for k in ["SYSTEMROOT", "PATH", "HOME"]:
26 | v = os.environ.get(k)
27 | if v is not None:
28 | env[k] = v
29 | # LANGUAGE is used on win32
30 | env["LANGUAGE"] = "C"
31 | env["LANG"] = "C"
32 | env["LC_ALL"] = "C"
33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
34 | return out
35 |
36 | try:
37 | out = _minimal_ext_cmd(["git", "rev-parse", "HEAD"])
38 | sha = out.strip().decode("ascii")
39 | except OSError:
40 | sha = "unknown"
41 |
42 | return sha
43 |
44 |
45 | def get_hash():
46 | if os.path.exists(".git"):
47 | sha = get_git_hash()[:7]
48 | elif os.path.exists(version_file):
49 | try:
50 | from version import __version__
51 |
52 | sha = __version__.split("+")[-1]
53 | except ImportError:
54 | raise ImportError("Unable to get git version")
55 | else:
56 | sha = "unknown"
57 |
58 | return sha
59 |
60 |
61 | def write_version_py():
62 | content = """# GENERATED VERSION FILE
63 | # TIME: {}
64 | __version__ = '{}'
65 | __gitsha__ = '{}'
66 | version_info = ({})
67 | """
68 | sha = get_hash()
69 | with open("./basicsr/VERSION", "r") as f:
70 | SHORT_VERSION = f.read().strip()
71 | VERSION_INFO = ", ".join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split(".")])
72 |
73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
74 | with open(version_file, "w") as f:
75 | f.write(version_file_str)
76 |
77 |
78 | def get_version():
79 | with open(version_file, "r") as f:
80 | exec(compile(f.read(), version_file, "exec"))
81 | return locals()["__version__"]
82 |
83 |
84 | def make_cuda_ext(name, module, sources, sources_cuda=None):
85 | if sources_cuda is None:
86 | sources_cuda = []
87 | define_macros = []
88 | extra_compile_args = {"cxx": []}
89 |
90 | # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91 | if gpu_is_available() or os.getenv("FORCE_CUDA", "0") == "1":
92 | define_macros += [("WITH_CUDA", None)]
93 | extension = CUDAExtension
94 | extra_compile_args["nvcc"] = [
95 | "-D__CUDA_NO_HALF_OPERATORS__",
96 | "-D__CUDA_NO_HALF_CONVERSIONS__",
97 | "-D__CUDA_NO_HALF2_OPERATORS__",
98 | ]
99 | sources += sources_cuda
100 | else:
101 | print(f"Compiling {name} without CUDA")
102 | extension = CppExtension
103 |
104 | return extension(
105 | name=f"{module}.{name}",
106 | sources=[os.path.join(*module.split("."), p) for p in sources],
107 | define_macros=define_macros,
108 | extra_compile_args=extra_compile_args,
109 | )
110 |
111 |
112 | def get_requirements(filename="requirements.txt"):
113 | with open(os.path.join(".", filename), "r") as f:
114 | requires = [line.replace("\n", "") for line in f.readlines()]
115 | return requires
116 |
117 |
118 | if __name__ == "__main__":
119 | if "--cuda_ext" in sys.argv:
120 | ext_modules = [
121 | make_cuda_ext(
122 | name="deform_conv_ext",
123 | module="ops.dcn",
124 | sources=["src/deform_conv_ext.cpp"],
125 | sources_cuda=["src/deform_conv_cuda.cpp", "src/deform_conv_cuda_kernel.cu"],
126 | ),
127 | make_cuda_ext(
128 | name="fused_act_ext",
129 | module="ops.fused_act",
130 | sources=["src/fused_bias_act.cpp"],
131 | sources_cuda=["src/fused_bias_act_kernel.cu"],
132 | ),
133 | make_cuda_ext(
134 | name="upfirdn2d_ext",
135 | module="ops.upfirdn2d",
136 | sources=["src/upfirdn2d.cpp"],
137 | sources_cuda=["src/upfirdn2d_kernel.cu"],
138 | ),
139 | ]
140 | sys.argv.remove("--cuda_ext")
141 | else:
142 | ext_modules = []
143 |
144 | write_version_py()
145 | setup(
146 | name="basicsr",
147 | version=get_version(),
148 | description="Open Source Image and Video Super-Resolution Toolbox",
149 | long_description=readme(),
150 | long_description_content_type="text/markdown",
151 | author="Xintao Wang",
152 | author_email="xintao.wang@outlook.com",
153 | keywords="computer vision, restoration, super resolution",
154 | url="https://github.com/xinntao/BasicSR",
155 | include_package_data=True,
156 | packages=find_packages(exclude=("options", "datasets", "experiments", "results", "tb_logger", "wandb")),
157 | classifiers=[
158 | "Development Status :: 4 - Beta",
159 | "License :: OSI Approved :: Apache Software License",
160 | "Operating System :: OS Independent",
161 | "Programming Language :: Python :: 3",
162 | "Programming Language :: Python :: 3.7",
163 | "Programming Language :: Python :: 3.8",
164 | ],
165 | license="Apache License 2.0",
166 | setup_requires=["cython", "numpy"],
167 | install_requires=get_requirements(),
168 | ext_modules=ext_modules,
169 | cmdclass={"build_ext": BuildExtension},
170 | zip_safe=False,
171 | )
172 |
--------------------------------------------------------------------------------
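Two ways the extensions above get built, for orientation: ahead of time by passing --cuda_ext to this setup.py (the branch in __main__ above), or lazily via the BASICSR_JIT fallback inside the op modules. A sketch of the JIT route, assuming a local CUDA toolchain:

    import os
    os.environ["BASICSR_JIT"] = "True"  # must be set before the first import

    # the ImportError fallback in fused_act.py / upfirdn2d.py now compiles the
    # extension with torch.utils.cpp_extension.load on first use
    from codeformer.basicsr.ops.fused_act import fused_leaky_relu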
/codeformer/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from codeformer.basicsr.utils.file_client import FileClient
2 | from codeformer.basicsr.utils.img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
3 | from codeformer.basicsr.utils.logger import (
4 | MessageLogger,
5 | get_env_info,
6 | get_root_logger,
7 | init_tb_logger,
8 | init_wandb_logger,
9 | )
10 | from codeformer.basicsr.utils.misc import *
11 | from codeformer.basicsr.utils.misc import (
12 | check_resume,
13 | get_time_str,
14 | make_exp_dirs,
15 | mkdir_and_rename,
16 | scandir,
17 | set_random_seed,
18 | sizeof_fmt,
19 | )
20 | from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer
21 |
22 | __all__ = [
23 | # file_client.py
24 | "FileClient",
25 | # img_util.py
26 | "img2tensor",
27 | "tensor2img",
28 | "imfrombytes",
29 | "imwrite",
30 | "crop_border",
31 | # logger.py
32 | "MessageLogger",
33 | "init_tb_logger",
34 | "init_wandb_logger",
35 | "get_root_logger",
36 | "get_env_info",
37 | # misc.py
38 | "set_random_seed",
39 | "get_time_str",
40 | "mkdir_and_rename",
41 | "make_exp_dirs",
42 | "scandir",
43 | "check_resume",
44 | "sizeof_fmt",
45 | ]
46 |
--------------------------------------------------------------------------------
/codeformer/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 |
6 | import torch
7 | import torch.distributed as dist
8 | import torch.multiprocessing as mp
9 |
10 |
11 | def init_dist(launcher, backend="nccl", **kwargs):
12 | if mp.get_start_method(allow_none=True) is None:
13 | mp.set_start_method("spawn")
14 | if launcher == "pytorch":
15 | _init_dist_pytorch(backend, **kwargs)
16 | elif launcher == "slurm":
17 | _init_dist_slurm(backend, **kwargs)
18 | else:
19 | raise ValueError(f"Invalid launcher type: {launcher}")
20 |
21 |
22 | def _init_dist_pytorch(backend, **kwargs):
23 | rank = int(os.environ["RANK"])
24 | num_gpus = torch.cuda.device_count()
25 | torch.cuda.set_device(rank % num_gpus)
26 | dist.init_process_group(backend=backend, **kwargs)
27 |
28 |
29 | def _init_dist_slurm(backend, port=None):
30 | """Initialize slurm distributed training environment.
31 |
32 | If argument ``port`` is not specified, then the master port will be taken
33 | from the system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is
34 | not set in the system environment, then a default port ``29500`` will be used.
35 |
36 | Args:
37 | backend (str): Backend of torch.distributed.
38 | port (int, optional): Master port. Defaults to None.
39 | """
40 | proc_id = int(os.environ["SLURM_PROCID"])
41 | ntasks = int(os.environ["SLURM_NTASKS"])
42 | node_list = os.environ["SLURM_NODELIST"]
43 | num_gpus = torch.cuda.device_count()
44 | torch.cuda.set_device(proc_id % num_gpus)
45 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
46 | # specify master port
47 | if port is not None:
48 | os.environ["MASTER_PORT"] = str(port)
49 | elif "MASTER_PORT" in os.environ:
50 | pass # use MASTER_PORT in the environment variable
51 | else:
52 | # 29500 is torch.distributed default port
53 | os.environ["MASTER_PORT"] = "29500"
54 | os.environ["MASTER_ADDR"] = addr
55 | os.environ["WORLD_SIZE"] = str(ntasks)
56 | os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
57 | os.environ["RANK"] = str(proc_id)
58 | dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 | if dist.is_available():
63 | initialized = dist.is_initialized()
64 | else:
65 | initialized = False
66 | if initialized:
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 | else:
70 | rank = 0
71 | world_size = 1
72 | return rank, world_size
73 |
74 |
75 | def master_only(func):
76 | @functools.wraps(func)
77 | def wrapper(*args, **kwargs):
78 | rank, _ = get_dist_info()
79 | if rank == 0:
80 | return func(*args, **kwargs)
81 |
82 | return wrapper
83 |
--------------------------------------------------------------------------------
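For orientation, get_dist_info degrades to (0, 1) when torch.distributed has not been initialized, and master_only turns a function into a no-op on every rank except 0, so the same code runs unchanged in single-process mode. A small sketch (the path is illustrative):

    from codeformer.basicsr.utils.dist_util import get_dist_info, master_only

    @master_only
    def save_checkpoint(path):
        print(f"rank 0 writes {path}")

    rank, world_size = get_dist_info()
    print(f"rank {rank} of {world_size}")
    save_checkpoint("ckpt.pth")  # executes only on rank 0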
/codeformer/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from urllib.parse import urlparse
4 |
5 | import requests
6 | from torch.hub import download_url_to_file, get_dir
7 | from tqdm import tqdm
8 |
9 | from codeformer.basicsr.utils.misc import sizeof_fmt
10 |
11 |
12 | def download_file_from_google_drive(file_id, save_path):
13 | """Download files from google drive.
14 | Ref:
15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
16 | Args:
17 | file_id (str): File id.
18 | save_path (str): Save path.
19 | """
20 |
21 | session = requests.Session()
22 | URL = "https://docs.google.com/uc?export=download"
23 | params = {"id": file_id}
24 |
25 | response = session.get(URL, params=params, stream=True)
26 | token = get_confirm_token(response)
27 | if token:
28 | params["confirm"] = token
29 | response = session.get(URL, params=params, stream=True)
30 |
31 | # get file size
32 | response_file_size = session.get(URL, params=params, stream=True, headers={"Range": "bytes=0-2"})
33 | print(response_file_size)
34 | if "Content-Range" in response_file_size.headers:
35 | file_size = int(response_file_size.headers["Content-Range"].split("/")[1])
36 | else:
37 | file_size = None
38 |
39 | save_response_content(response, save_path, file_size)
40 |
41 |
42 | def get_confirm_token(response):
43 | for key, value in response.cookies.items():
44 | if key.startswith("download_warning"):
45 | return value
46 | return None
47 |
48 |
49 | def save_response_content(response, destination, file_size=None, chunk_size=32768):
50 | if file_size is not None:
51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit="chunk")
52 |
53 | readable_file_size = sizeof_fmt(file_size)
54 | else:
55 | pbar = None
56 |
57 | with open(destination, "wb") as f:
58 | downloaded_size = 0
59 | for chunk in response.iter_content(chunk_size):
60 | downloaded_size += chunk_size
61 | if pbar is not None:
62 | pbar.update(1)
63 | pbar.set_description(f"Download {sizeof_fmt(downloaded_size)} / {readable_file_size}")
64 | if chunk: # filter out keep-alive new chunks
65 | f.write(chunk)
66 | if pbar is not None:
67 | pbar.close()
68 |
69 |
70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
71 | """Load file form http url, will download models if necessary.
72 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
73 | Args:
74 | url (str): URL to be downloaded.
75 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
76 | Default: None.
77 | progress (bool): Whether to show the download progress. Default: True.
78 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
79 | Returns:
80 | str: The path to the downloaded file.
81 | """
82 | if model_dir is None: # use the pytorch hub_dir
83 | hub_dir = get_dir()
84 | model_dir = os.path.join(hub_dir, "checkpoints")
85 |
86 | os.makedirs(model_dir, exist_ok=True)
87 |
88 | parts = urlparse(url)
89 | filename = os.path.basename(parts.path)
90 | if file_name is not None:
91 | filename = file_name
92 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
93 | if not os.path.exists(cached_file):
94 | print(f'Downloading: "{url}" to {cached_file}\n')
95 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
96 | return cached_file
97 |
--------------------------------------------------------------------------------
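A minimal sketch of the usual call (the URL is a placeholder): the file is downloaded only if it is not already cached, and the local path is returned either way.

    from codeformer.basicsr.utils.download_util import load_file_from_url

    ckpt_path = load_file_from_url(
        url="https://example.com/codeformer.pth",  # placeholder URL
        model_dir="weights",
        progress=True,
    )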
/codeformer/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as text.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 |
36 | sys.path.append(sys_path)
37 | try:
38 | import mc
39 | except ImportError:
40 | raise ImportError("Please install memcached to enable MemcachedBackend.")
41 |
42 | self.server_list_cfg = server_list_cfg
43 | self.client_cfg = client_cfg
44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
45 | # mc.pyvector serves as a pointer to a memory cache
46 | self._mc_buffer = mc.pyvector()
47 |
48 | def get(self, filepath):
49 | filepath = str(filepath)
50 | import mc
51 |
52 | self._client.Get(filepath, self._mc_buffer)
53 | value_buf = mc.ConvertBuffer(self._mc_buffer)
54 | return value_buf
55 |
56 | def get_text(self, filepath):
57 | raise NotImplementedError
58 |
59 |
60 | class HardDiskBackend(BaseStorageBackend):
61 | """Raw hard disks storage backend."""
62 |
63 | def get(self, filepath):
64 | filepath = str(filepath)
65 | with open(filepath, "rb") as f:
66 | value_buf = f.read()
67 | return value_buf
68 |
69 | def get_text(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, "r") as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 |
76 | class LmdbBackend(BaseStorageBackend):
77 | """Lmdb storage backend.
78 |
79 | Args:
80 | db_paths (str | list[str]): Lmdb database paths.
81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82 | readonly (bool, optional): Lmdb environment parameter. If True,
83 | disallow any write operations. Default: True.
84 | lock (bool, optional): Lmdb environment parameter. If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 |
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 |
96 | def __init__(self, db_paths, client_keys="default", readonly=True, lock=False, readahead=False, **kwargs):
97 | try:
98 | import lmdb
99 | except ImportError:
100 | raise ImportError("Please install lmdb to enable LmdbBackend.")
101 |
102 | if isinstance(client_keys, str):
103 | client_keys = [client_keys]
104 |
105 | if isinstance(db_paths, list):
106 | self.db_paths = [str(v) for v in db_paths]
107 | elif isinstance(db_paths, str):
108 | self.db_paths = [str(db_paths)]
109 | assert len(client_keys) == len(self.db_paths), (
110 | "client_keys and db_paths should have the same length, "
111 | f"but received {len(client_keys)} and {len(self.db_paths)}."
112 | )
113 |
114 | self._client = {}
115 | for client, path in zip(client_keys, self.db_paths):
116 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
117 |
118 | def get(self, filepath, client_key):
119 | """Get values according to the filepath from one lmdb named client_key.
120 |
121 | Args:
122 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
123 | client_key (str): Used for distinguishing different lmdb envs.
124 | """
125 | filepath = str(filepath)
126 | assert client_key in self._client, f"client_key {client_key} is not " "in lmdb clients."
127 | client = self._client[client_key]
128 | with client.begin(write=False) as txn:
129 | value_buf = txn.get(filepath.encode("ascii"))
130 | return value_buf
131 |
132 | def get_text(self, filepath):
133 | raise NotImplementedError
134 |
135 |
136 | class FileClient(object):
137 | """A general file client to access files in different backend.
138 |
139 | The client loads a file or text in a specified backend from its path
140 | and returns it as a binary file. It can also register other backend
141 | accessors with a given name and backend class.
142 |
143 | Attributes:
144 | backend (str): The storage backend type. Options are "disk",
145 | "memcached" and "lmdb".
146 | client (:obj:`BaseStorageBackend`): The backend object.
147 | """
148 |
149 | _backends = {
150 | "disk": HardDiskBackend,
151 | "memcached": MemcachedBackend,
152 | "lmdb": LmdbBackend,
153 | }
154 |
155 | def __init__(self, backend="disk", **kwargs):
156 | if backend not in self._backends:
157 | raise ValueError(
158 | f"Backend {backend} is not supported. Currently supported ones" f" are {list(self._backends.keys())}"
159 | )
160 | self.backend = backend
161 | self.client = self._backends[backend](**kwargs)
162 |
163 | def get(self, filepath, client_key="default"):
164 | # client_key is used only for lmdb, where different fileclients have
165 | # different lmdb environments.
166 | if self.backend == "lmdb":
167 | return self.client.get(filepath, client_key)
168 | else:
169 | return self.client.get(filepath)
170 |
171 | def get_text(self, filepath):
172 | return self.client.get_text(filepath)
173 |
--------------------------------------------------------------------------------
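A short sketch of the two most common backends (paths and keys are illustrative); the returned bytes are typically passed to imfrombytes from img_util.py below.

    from codeformer.basicsr.utils.file_client import FileClient

    disk_client = FileClient(backend="disk")
    img_bytes = disk_client.get("datasets/0001.png")  # hypothetical path

    lmdb_client = FileClient(backend="lmdb", db_paths=["train_hq.lmdb"], client_keys=["hq"])
    img_bytes_hq = lmdb_client.get("000_00000000", client_key="hq")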
/codeformer/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from torchvision.utils import make_grid
8 |
9 |
10 | def img2tensor(imgs, bgr2rgb=True, float32=True):
11 | """Numpy array to tensor.
12 |
13 | Args:
14 | imgs (list[ndarray] | ndarray): Input images.
15 | bgr2rgb (bool): Whether to change bgr to rgb.
16 | float32 (bool): Whether to change to float32.
17 |
18 | Returns:
19 | list[tensor] | tensor: Tensor images. If returned results only have
20 | one element, just return tensor.
21 | """
22 |
23 | def _totensor(img, bgr2rgb, float32):
24 | if img.shape[2] == 3 and bgr2rgb:
25 | if img.dtype == "float64":
26 | img = img.astype("float32")
27 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28 | img = torch.from_numpy(img.transpose(2, 0, 1))
29 | if float32:
30 | img = img.float()
31 | return img
32 |
33 | if isinstance(imgs, list):
34 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
35 | else:
36 | return _totensor(imgs, bgr2rgb, float32)
37 |
38 |
39 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
40 | """Convert torch Tensors into image numpy arrays.
41 |
42 | After clamping to [min, max], values will be normalized to [0, 1].
43 |
44 | Args:
45 | tensor (Tensor or list[Tensor]): Accept shapes:
46 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
47 | 2) 3D Tensor of shape (3/1 x H x W);
48 | 3) 2D Tensor of shape (H x W).
49 | Tensor channel should be in RGB order.
50 | rgb2bgr (bool): Whether to change rgb to bgr.
51 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
52 | to uint8 type with range [0, 255]; otherwise, float type with
53 | range [0, 1]. Default: ``np.uint8``.
54 | min_max (tuple[int]): min and max values for clamp.
55 |
56 | Returns:
57 | (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
58 | shape (H x W). The channel order is BGR.
59 | """
60 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
61 | raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
62 |
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 |
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
73 | img_np = img_np.transpose(1, 2, 0)
74 | if rgb2bgr:
75 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
76 | elif n_dim == 3:
77 | img_np = _tensor.numpy()
78 | img_np = img_np.transpose(1, 2, 0)
79 | if img_np.shape[2] == 1: # gray image
80 | img_np = np.squeeze(img_np, axis=2)
81 | else:
82 | if rgb2bgr:
83 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
84 | elif n_dim == 2:
85 | img_np = _tensor.numpy()
86 | else:
87 | raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}")
88 | if out_type == np.uint8:
89 | # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
90 | img_np = (img_np * 255.0).round()
91 | img_np = img_np.astype(out_type)
92 | result.append(img_np)
93 | if len(result) == 1:
94 | result = result[0]
95 | return result
96 |
97 |
98 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
99 | """This implementation is slightly faster than tensor2img.
100 | It now only supports torch tensor with shape (1, c, h, w).
101 |
102 | Args:
103 | tensor (Tensor): Now only support torch tensor with (1, c, h, w).
104 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
105 | min_max (tuple[int]): min and max values for clamp.
106 | """
107 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
108 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
109 | output = output.type(torch.uint8).cpu().numpy()
110 | if rgb2bgr:
111 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
112 | return output
113 |
114 |
115 | def imfrombytes(content, flag="color", float32=False):
116 | """Read an image from bytes.
117 |
118 | Args:
119 | content (bytes): Image bytes got from files or other streams.
120 | flag (str): Flags specifying the color type of a loaded image,
121 | candidates are `color`, `grayscale` and `unchanged`.
122 | float32 (bool): Whether to change to float32. If True, will also normalize
123 | to [0, 1]. Default: False.
124 |
125 | Returns:
126 | ndarray: Loaded image array.
127 | """
128 | img_np = np.frombuffer(content, np.uint8)
129 | imread_flags = {"color": cv2.IMREAD_COLOR, "grayscale": cv2.IMREAD_GRAYSCALE, "unchanged": cv2.IMREAD_UNCHANGED}
130 | img = cv2.imdecode(img_np, imread_flags[flag])
131 | if float32:
132 | img = img.astype(np.float32) / 255.0
133 | return img
134 |
135 |
136 | def imwrite(img, file_path, params=None, auto_mkdir=True):
137 | """Write image to file.
138 |
139 | Args:
140 | img (ndarray): Image array to be written.
141 | file_path (str): Image file path.
142 | params (None or list): Same as opencv's :func:`imwrite` interface.
143 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
144 | whether to create it automatically.
145 |
146 | Returns:
147 | bool: Successful or not.
148 | """
149 | if auto_mkdir:
150 | dir_name = os.path.abspath(os.path.dirname(file_path))
151 | os.makedirs(dir_name, exist_ok=True)
152 | return cv2.imwrite(file_path, img, params)
153 |
154 |
155 | def crop_border(imgs, crop_border):
156 | """Crop borders of images.
157 |
158 | Args:
159 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
160 | crop_border (int): Crop border for each end of height and width.
161 |
162 | Returns:
163 | list[ndarray]: Cropped images.
164 | """
165 | if crop_border == 0:
166 | return imgs
167 | else:
168 | if isinstance(imgs, list):
169 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
170 | else:
171 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
172 |
--------------------------------------------------------------------------------
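A sketch of the read/convert round trip these helpers are built for (the input array here is random data standing in for a [0, 1] BGR image):

    import numpy as np
    from codeformer.basicsr.utils.img_util import img2tensor, imwrite, tensor2img

    img = np.random.rand(64, 64, 3).astype(np.float32)          # stand-in image
    tensor = img2tensor(img, bgr2rgb=True, float32=True)         # CHW, RGB, float32
    restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))  # HWC, uint8, BGR
    imwrite(restored, "out/result.png")                          # creates out/ if needed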
/codeformer/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from multiprocessing import Pool
3 | from os import path as osp
4 |
5 | import cv2
6 | import lmdb
7 | from tqdm import tqdm
8 |
9 |
10 | def make_lmdb_from_imgs(
11 | data_path,
12 | lmdb_path,
13 | img_path_list,
14 | keys,
15 | batch=5000,
16 | compress_level=1,
17 | multiprocessing_read=False,
18 | n_thread=40,
19 | map_size=None,
20 | ):
21 | """Make lmdb from images.
22 |
23 | Contents of lmdb. The file structure is:
24 | example.lmdb
25 | ├── data.mdb
26 | ├── lock.mdb
27 | ├── meta_info.txt
28 |
29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
30 | https://lmdb.readthedocs.io/en/release/ for more details.
31 |
32 | The meta_info.txt is a specified txt file to record the meta information
33 | of our datasets. It will be automatically created when preparing
34 | datasets by our provided dataset tools.
35 | Each line in the txt file records 1)image name (with extension),
36 | 2)image shape, and 3)compression level, separated by a white space.
37 |
38 | For example, the meta information could be:
39 | `000_00000000.png (720,1280,3) 1`, which means:
40 | 1) image name (with extension): 000_00000000.png;
41 | 2) image shape: (720,1280,3);
42 | 3) compression level: 1
43 |
44 | We use the image name without extension as the lmdb key.
45 |
46 | If `multiprocessing_read` is True, it will read all the images to memory
47 | using multiprocessing. Thus, your server needs to have enough memory.
48 |
49 | Args:
50 | data_path (str): Data path for reading images.
51 | lmdb_path (str): Lmdb save path.
52 | img_path_list (str): Image path list.
53 | keys (str): Used for lmdb keys.
54 | batch (int): After processing batch images, lmdb commits.
55 | Default: 5000.
56 | compress_level (int): Compress level when encoding images. Default: 1.
57 | multiprocessing_read (bool): Whether use multiprocessing to read all
58 | the images to memory. Default: False.
59 | n_thread (int): For multiprocessing.
60 | map_size (int | None): Map size for lmdb env. If None, use the
61 | estimated size from images. Default: None
62 | """
63 |
64 | assert len(img_path_list) == len(keys), (
65 | "img_path_list and keys should have the same length, " f"but got {len(img_path_list)} and {len(keys)}"
66 | )
67 | print(f"Create lmdb for {data_path}, save to {lmdb_path}...")
68 | print(f"Totoal images: {len(img_path_list)}")
69 | if not lmdb_path.endswith(".lmdb"):
70 | raise ValueError("lmdb_path must end with '.lmdb'.")
71 | if osp.exists(lmdb_path):
72 | print(f"Folder {lmdb_path} already exists. Exit.")
73 | sys.exit(1)
74 |
75 | if multiprocessing_read:
76 | # read all the images to memory (multiprocessing)
77 | dataset = {} # use dict to keep the order for multiprocessing
78 | shapes = {}
79 | print(f"Read images with multiprocessing, #thread: {n_thread} ...")
80 | pbar = tqdm(total=len(img_path_list), unit="image")
81 |
82 | def callback(arg):
83 | """get the image data and update pbar."""
84 | key, dataset[key], shapes[key] = arg
85 | pbar.update(1)
86 | pbar.set_description(f"Read {key}")
87 |
88 | pool = Pool(n_thread)
89 | for path, key in zip(img_path_list, keys):
90 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 | print(f"Finish reading {len(img_path_list)} images.")
95 |
96 | # create lmdb environment
97 | if map_size is None:
98 | # obtain data size for one image
99 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
100 | _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
101 | data_size_per_img = img_byte.nbytes
102 | print("Data size per image is: ", data_size_per_img)
103 | data_size = data_size_per_img * len(img_path_list)
104 | map_size = data_size * 10
105 |
106 | env = lmdb.open(lmdb_path, map_size=map_size)
107 |
108 | # write data to lmdb
109 | pbar = tqdm(total=len(img_path_list), unit="chunk")
110 | txn = env.begin(write=True)
111 | txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w")
112 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
113 | pbar.update(1)
114 | pbar.set_description(f"Write {key}")
115 | key_byte = key.encode("ascii")
116 | if multiprocessing_read:
117 | img_byte = dataset[key]
118 | h, w, c = shapes[key]
119 | else:
120 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
121 | h, w, c = img_shape
122 |
123 | txn.put(key_byte, img_byte)
124 | # write meta information
125 | txt_file.write(f"{key}.png ({h},{w},{c}) {compress_level}\n")
126 | if idx % batch == 0:
127 | txn.commit()
128 | txn = env.begin(write=True)
129 | pbar.close()
130 | txn.commit()
131 | env.close()
132 | txt_file.close()
133 | print("\nFinish writing lmdb.")
134 |
135 |
136 | def read_img_worker(path, key, compress_level):
137 | """Read image worker.
138 |
139 | Args:
140 | path (str): Image path.
141 | key (str): Image key.
142 | compress_level (int): Compress level when encoding images.
143 |
144 | Returns:
145 | str: Image key.
146 | byte: Image byte.
147 | tuple[int]: Image shape.
148 | """
149 |
150 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
151 | if img.ndim == 2:
152 | h, w = img.shape
153 | c = 1
154 | else:
155 | h, w, c = img.shape
156 | _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
157 | return (key, img_byte, (h, w, c))
158 |
159 |
160 | class LmdbMaker:
161 | """LMDB Maker.
162 |
163 | Args:
164 | lmdb_path (str): Lmdb save path.
165 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
166 | batch (int): After processing batch images, lmdb commits.
167 | Default: 5000.
168 | compress_level (int): Compress level when encoding images. Default: 1.
169 | """
170 |
171 | def __init__(self, lmdb_path, map_size=1024 ** 4, batch=5000, compress_level=1):
172 | if not lmdb_path.endswith(".lmdb"):
173 | raise ValueError("lmdb_path must end with '.lmdb'.")
174 | if osp.exists(lmdb_path):
175 | print(f"Folder {lmdb_path} already exists. Exit.")
176 | sys.exit(1)
177 |
178 | self.lmdb_path = lmdb_path
179 | self.batch = batch
180 | self.compress_level = compress_level
181 | self.env = lmdb.open(lmdb_path, map_size=map_size)
182 | self.txn = self.env.begin(write=True)
183 | self.txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w")
184 | self.counter = 0
185 |
186 | def put(self, img_byte, key, img_shape):
187 | self.counter += 1
188 | key_byte = key.encode("ascii")
189 | self.txn.put(key_byte, img_byte)
190 | # write meta information
191 | h, w, c = img_shape
192 | self.txt_file.write(f"{key}.png ({h},{w},{c}) {self.compress_level}\n")
193 | if self.counter % self.batch == 0:
194 | self.txn.commit()
195 | self.txn = self.env.begin(write=True)
196 |
197 | def close(self):
198 | self.txn.commit()
199 | self.env.close()
200 | self.txt_file.close()
201 |
--------------------------------------------------------------------------------
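For orientation, LmdbMaker is the incremental counterpart of make_lmdb_from_imgs; a minimal sketch with illustrative paths and keys:

    from codeformer.basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

    maker = LmdbMaker("faces_hq.lmdb", batch=5000, compress_level=1)
    for path, key in [("images/0001.png", "0001"), ("images/0002.png", "0002")]:
        _, img_byte, img_shape = read_img_worker(path, key, compress_level=1)
        maker.put(img_byte, key, img_shape)
    maker.close()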
/codeformer/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from codeformer.basicsr.utils.dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class MessageLogger:
11 | """Message logger for printing.
12 | Args:
13 | opt (dict): Config. It contains the following keys:
14 | name (str): Exp name.
15 | logger (dict): Contains 'print_freq' (str) for logger interval.
16 | train (dict): Contains 'total_iter' (int) for total iters.
17 | use_tb_logger (bool): Use tensorboard logger.
18 | start_iter (int): Start iter. Default: 1.
19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
20 | """
21 |
22 | def __init__(self, opt, start_iter=1, tb_logger=None):
23 | self.exp_name = opt["name"]
24 | self.interval = opt["logger"]["print_freq"]
25 | self.start_iter = start_iter
26 | self.max_iters = opt["train"]["total_iter"]
27 | self.use_tb_logger = opt["logger"]["use_tb_logger"]
28 | self.tb_logger = tb_logger
29 | self.start_time = time.time()
30 | self.logger = get_root_logger()
31 |
32 | @master_only
33 | def __call__(self, log_vars):
34 | """Format logging message.
35 | Args:
36 | log_vars (dict): It contains the following keys:
37 | epoch (int): Epoch number.
38 | iter (int): Current iter.
39 | lrs (list): List for learning rates.
40 | time (float): Iter time.
41 | data_time (float): Data time for each iter.
42 | """
43 | # epoch, iter, learning rates
44 | epoch = log_vars.pop("epoch")
45 | current_iter = log_vars.pop("iter")
46 | lrs = log_vars.pop("lrs")
47 |
48 | message = f"[{self.exp_name[:5]}..][epoch:{epoch:3d}, " f"iter:{current_iter:8,d}, lr:("
49 | for v in lrs:
50 | message += f"{v:.3e},"
51 | message += ")] "
52 |
53 | # time and estimated time
54 | if "time" in log_vars.keys():
55 | iter_time = log_vars.pop("time")
56 | data_time = log_vars.pop("data_time")
57 |
58 | total_time = time.time() - self.start_time
59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
62 | message += f"[eta: {eta_str}, "
63 | message += f"time (data): {iter_time:.3f} ({data_time:.3f})] "
64 |
65 | # other items, especially losses
66 | for k, v in log_vars.items():
67 | message += f"{k}: {v:.4e} "
68 | # tensorboard logger
69 | if self.use_tb_logger:
70 | if k.startswith("l_"):
71 | self.tb_logger.add_scalar(f"losses/{k}", v, current_iter)
72 | else:
73 | self.tb_logger.add_scalar(k, v, current_iter)
74 | self.logger.info(message)
75 |
76 |
77 | @master_only
78 | def init_tb_logger(log_dir):
79 | from torch.utils.tensorboard import SummaryWriter
80 |
81 | tb_logger = SummaryWriter(log_dir=log_dir)
82 | return tb_logger
83 |
84 |
85 | @master_only
86 | def init_wandb_logger(opt):
87 | """We now only use wandb to sync tensorboard log."""
88 | import wandb
89 |
90 | logger = logging.getLogger("basicsr")
91 |
92 | project = opt["logger"]["wandb"]["project"]
93 | resume_id = opt["logger"]["wandb"].get("resume_id")
94 | if resume_id:
95 | wandb_id = resume_id
96 | resume = "allow"
97 | logger.warning(f"Resume wandb logger with id={wandb_id}.")
98 | else:
99 | wandb_id = wandb.util.generate_id()
100 | resume = "never"
101 |
102 | wandb.init(id=wandb_id, resume=resume, name=opt["name"], config=opt, project=project, sync_tensorboard=True)
103 |
104 | logger.info(f"Use wandb logger with id={wandb_id}; project={project}.")
105 |
106 |
107 | def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None):
108 | """Get the root logger.
109 | The logger will be initialized if it has not been initialized. By default a
110 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
111 | also be added.
112 | Args:
113 | logger_name (str): root logger name. Default: 'basicsr'.
114 | log_file (str | None): The log filename. If specified, a FileHandler
115 | will be added to the root logger.
116 | log_level (int): The root logger level. Note that only the process of
117 | rank 0 is affected, while other processes will set the level to
118 | "Error" and be silent most of the time.
119 | Returns:
120 | logging.Logger: The root logger.
121 | """
122 | logger = logging.getLogger(logger_name)
123 | # if the logger has been initialized, just return it
124 | if logger_name in initialized_logger:
125 | return logger
126 |
127 | format_str = "%(asctime)s %(levelname)s: %(message)s"
128 | stream_handler = logging.StreamHandler()
129 | stream_handler.setFormatter(logging.Formatter(format_str))
130 | logger.addHandler(stream_handler)
131 | logger.propagate = False
132 | rank, _ = get_dist_info()
133 | if rank != 0:
134 | logger.setLevel("ERROR")
135 | elif log_file is not None:
136 | logger.setLevel(log_level)
137 | # add file handler
138 | # file_handler = logging.FileHandler(log_file, 'w')
139 | file_handler = logging.FileHandler(log_file, "a") # Shangchen: keep the previous log
140 | file_handler.setFormatter(logging.Formatter(format_str))
141 | file_handler.setLevel(log_level)
142 | logger.addHandler(file_handler)
143 | initialized_logger[logger_name] = True
144 | return logger
145 |
146 |
147 | def get_env_info():
148 | """Get environment information.
149 | Currently, only log the software version.
150 | """
151 | import torch
152 | import torchvision
153 |     from codeformer.basicsr.version import __version__
154 |
155 | msg = r"""
156 | ____ _ _____ ____
157 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
158 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
159 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
160 | /_____/ \__,_//____//_/ \___//____//_/ |_|
161 | ______ __ __ __ __
162 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
163 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
164 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
165 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
166 | """
167 | msg += (
168 | "\nVersion Information: "
169 | f"\n\tBasicSR: {__version__}"
170 | f"\n\tPyTorch: {torch.__version__}"
171 | f"\n\tTorchVision: {torchvision.__version__}"
172 | )
173 | return msg
174 |
--------------------------------------------------------------------------------
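A short sketch of how get_root_logger and MessageLogger fit together on a single process; the option dict below is a minimal, illustrative subset of a real training config:

from codeformer.basicsr.utils.logger import MessageLogger, get_root_logger

opt = {
    "name": "demo_exp",
    "logger": {"print_freq": 100, "use_tb_logger": False},
    "train": {"total_iter": 1000},
}

logger = get_root_logger(log_file="demo_exp.log")   # StreamHandler plus FileHandler on rank 0
msg_logger = MessageLogger(opt, start_iter=1)
# Keys mirror the log_vars documented in MessageLogger.__call__; "l_pix" stands in for a loss term.
msg_logger({"epoch": 1, "iter": 100, "lrs": [1e-4], "time": 0.25, "data_time": 0.02, "l_pix": 1.23e-2})
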
/codeformer/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 | import time
5 | from os import path as osp
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from codeformer.basicsr.utils.dist_util import master_only
11 | from codeformer.basicsr.utils.logger import get_root_logger
12 |
13 | IS_HIGH_VERSION = [
14 | int(m)
15 | for m in list(
16 | re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$", torch.__version__)[0][:3]
17 | )
18 | ] >= [1, 12, 0]
19 |
20 |
21 | def gpu_is_available():
22 | if IS_HIGH_VERSION:
23 | if torch.backends.mps.is_available():
24 | return True
25 |     return torch.cuda.is_available() and torch.backends.cudnn.is_available()
26 |
27 |
28 | def get_device(gpu_id=None):
29 | if gpu_id is None:
30 | gpu_str = ""
31 | elif isinstance(gpu_id, int):
32 | gpu_str = f":{gpu_id}"
33 | else:
34 | raise TypeError("Input should be int value.")
35 |
36 | if IS_HIGH_VERSION:
37 | if torch.backends.mps.is_available():
38 | return torch.device("mps" + gpu_str)
39 | return torch.device(
40 | "cuda" + gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else "cpu"
41 | )
42 |
43 |
44 | def set_random_seed(seed):
45 | """Set random seeds."""
46 | random.seed(seed)
47 | np.random.seed(seed)
48 | torch.manual_seed(seed)
49 | torch.cuda.manual_seed(seed)
50 | torch.cuda.manual_seed_all(seed)
51 |
52 |
53 | def get_time_str():
54 | return time.strftime("%Y%m%d_%H%M%S", time.localtime())
55 |
56 |
57 | def mkdir_and_rename(path):
58 | """mkdirs. If path exists, rename it with timestamp and create a new one.
59 |
60 | Args:
61 | path (str): Folder path.
62 | """
63 | if osp.exists(path):
64 | new_name = path + "_archived_" + get_time_str()
65 | print(f"Path already exists. Rename it to {new_name}", flush=True)
66 | os.rename(path, new_name)
67 | os.makedirs(path, exist_ok=True)
68 |
69 |
70 | @master_only
71 | def make_exp_dirs(opt):
72 | """Make dirs for experiments."""
73 | path_opt = opt["path"].copy()
74 | if opt["is_train"]:
75 | mkdir_and_rename(path_opt.pop("experiments_root"))
76 | else:
77 | mkdir_and_rename(path_opt.pop("results_root"))
78 | for key, path in path_opt.items():
79 | if ("strict_load" not in key) and ("pretrain_network" not in key) and ("resume" not in key):
80 | os.makedirs(path, exist_ok=True)
81 |
82 |
83 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
84 |     """Scan a directory for files of interest.
85 |
86 | Args:
87 | dir_path (str): Path of the directory.
88 | suffix (str | tuple(str), optional): File suffix that we are
89 | interested in. Default: None.
90 | recursive (bool, optional): If set to True, recursively scan the
91 | directory. Default: False.
92 | full_path (bool, optional): If set to True, include the dir_path.
93 | Default: False.
94 |
95 | Returns:
96 |         A generator for all files of interest, yielded as relative paths.
97 | """
98 |
99 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
100 | raise TypeError('"suffix" must be a string or tuple of strings')
101 |
102 | root = dir_path
103 |
104 | def _scandir(dir_path, suffix, recursive):
105 | for entry in os.scandir(dir_path):
106 | if not entry.name.startswith(".") and entry.is_file():
107 | if full_path:
108 | return_path = entry.path
109 | else:
110 | return_path = osp.relpath(entry.path, root)
111 |
112 | if suffix is None:
113 | yield return_path
114 | elif return_path.endswith(suffix):
115 | yield return_path
116 | else:
117 | if recursive:
118 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
119 | else:
120 | continue
121 |
122 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
123 |
124 |
125 | def check_resume(opt, resume_iter):
126 | """Check resume states and pretrain_network paths.
127 |
128 | Args:
129 | opt (dict): Options.
130 | resume_iter (int): Resume iteration.
131 | """
132 | logger = get_root_logger()
133 | if opt["path"]["resume_state"]:
134 | # get all the networks
135 | networks = [key for key in opt.keys() if key.startswith("network_")]
136 | flag_pretrain = False
137 | for network in networks:
138 | if opt["path"].get(f"pretrain_{network}") is not None:
139 | flag_pretrain = True
140 | if flag_pretrain:
141 | logger.warning("pretrain_network path will be ignored during resuming.")
142 | # set pretrained model paths
143 | for network in networks:
144 | name = f"pretrain_{network}"
145 | basename = network.replace("network_", "")
146 | if opt["path"].get("ignore_resume_networks") is None or (
147 | basename not in opt["path"]["ignore_resume_networks"]
148 | ):
149 | opt["path"][name] = osp.join(opt["path"]["models"], f"net_{basename}_{resume_iter}.pth")
150 | logger.info(f"Set {name} to {opt['path'][name]}")
151 |
152 |
153 | def sizeof_fmt(size, suffix="B"):
154 | """Get human readable file size.
155 |
156 | Args:
157 | size (int): File size.
158 | suffix (str): Suffix. Default: 'B'.
159 |
160 | Return:
161 |         str: Formatted file size.
162 | """
163 | for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
164 | if abs(size) < 1024.0:
165 | return f"{size:3.1f} {unit}{suffix}"
166 | size /= 1024.0
167 | return f"{size:3.1f} Y{suffix}"
168 |
--------------------------------------------------------------------------------
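A small sketch exercising get_device, set_random_seed, scandir, and sizeof_fmt from the module above; the folder path is illustrative:

import os

from codeformer.basicsr.utils.misc import get_device, scandir, set_random_seed, sizeof_fmt

set_random_seed(42)
print(get_device())                      # cuda / mps / cpu depending on availability

folder = "datasets/example_gt"           # hypothetical folder
for rel_path in scandir(folder, suffix=".png", recursive=True):
    size = os.path.getsize(os.path.join(folder, rel_path))
    print(rel_path, sizeof_fmt(size))
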
/codeformer/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 | import yaml
6 |
7 | from codeformer.basicsr.utils.misc import get_time_str
8 |
9 |
10 | def ordered_yaml():
11 | """Support OrderedDict for yaml.
12 |
13 | Returns:
14 | yaml Loader and Dumper.
15 | """
16 | try:
17 | from yaml import CDumper as Dumper
18 | from yaml import CLoader as Loader
19 | except ImportError:
20 | from yaml import Dumper, Loader
21 |
22 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
23 |
24 | def dict_representer(dumper, data):
25 | return dumper.represent_dict(data.items())
26 |
27 | def dict_constructor(loader, node):
28 | return OrderedDict(loader.construct_pairs(node))
29 |
30 | Dumper.add_representer(OrderedDict, dict_representer)
31 | Loader.add_constructor(_mapping_tag, dict_constructor)
32 | return Loader, Dumper
33 |
34 |
35 | def parse(opt_path, root_path, is_train=True):
36 | """Parse option file.
37 |
38 | Args:
39 | opt_path (str): Option file path.
40 |         is_train (bool): Indicate whether in training or not. Default: True.
41 |
42 | Returns:
43 | (dict): Options.
44 | """
45 | with open(opt_path, mode="r") as f:
46 | Loader, _ = ordered_yaml()
47 | opt = yaml.load(f, Loader=Loader)
48 |
49 | opt["is_train"] = is_train
50 |
51 | # opt['name'] = f"{get_time_str()}_{opt['name']}"
52 | if opt["path"].get("resume_state", None): # Shangchen added
53 | resume_state_path = opt["path"].get("resume_state")
54 | opt["name"] = resume_state_path.split("/")[-3]
55 | else:
56 | opt["name"] = f"{get_time_str()}_{opt['name']}"
57 |
58 | # datasets
59 | for phase, dataset in opt["datasets"].items():
60 | # for several datasets, e.g., test_1, test_2
61 | phase = phase.split("_")[0]
62 | dataset["phase"] = phase
63 | if "scale" in opt:
64 | dataset["scale"] = opt["scale"]
65 | if dataset.get("dataroot_gt") is not None:
66 | dataset["dataroot_gt"] = osp.expanduser(dataset["dataroot_gt"])
67 | if dataset.get("dataroot_lq") is not None:
68 | dataset["dataroot_lq"] = osp.expanduser(dataset["dataroot_lq"])
69 |
70 | # paths
71 | for key, val in opt["path"].items():
72 | if (val is not None) and ("resume_state" in key or "pretrain_network" in key):
73 | opt["path"][key] = osp.expanduser(val)
74 |
75 | if is_train:
76 | experiments_root = osp.join(root_path, "experiments", opt["name"])
77 | opt["path"]["experiments_root"] = experiments_root
78 | opt["path"]["models"] = osp.join(experiments_root, "models")
79 | opt["path"]["training_states"] = osp.join(experiments_root, "training_states")
80 | opt["path"]["log"] = experiments_root
81 | opt["path"]["visualization"] = osp.join(experiments_root, "visualization")
82 |
83 | else: # test
84 | results_root = osp.join(root_path, "results", opt["name"])
85 | opt["path"]["results_root"] = results_root
86 | opt["path"]["log"] = results_root
87 | opt["path"]["visualization"] = osp.join(results_root, "visualization")
88 |
89 | return opt
90 |
91 |
92 | def dict2str(opt, indent_level=1):
93 | """dict to string for printing options.
94 |
95 | Args:
96 | opt (dict): Option dict.
97 | indent_level (int): Indent level. Default: 1.
98 |
99 | Return:
100 | (str): Option string for printing.
101 | """
102 | msg = "\n"
103 | for k, v in opt.items():
104 | if isinstance(v, dict):
105 | msg += " " * (indent_level * 2) + k + ":["
106 | msg += dict2str(v, indent_level + 1)
107 | msg += " " * (indent_level * 2) + "]\n"
108 | else:
109 | msg += " " * (indent_level * 2) + k + ": " + str(v) + "\n"
110 | return msg
111 |
--------------------------------------------------------------------------------
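A usage sketch for parse and dict2str; the YAML path is illustrative and is expected to contain the usual name, datasets, and path sections of a BasicSR-style option file:

from codeformer.basicsr.utils.options import dict2str, parse

opt = parse("options/train_codeformer.yml", root_path=".", is_train=True)
print(dict2str(opt))                     # nested options pretty-printed for the log
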
/codeformer/basicsr/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2 |
3 |
4 | class Registry:
5 | """
6 | The registry that provides name -> object mapping, to support third-party
7 | users' custom modules.
8 |
9 | To create a registry (e.g. a backbone registry):
10 |
11 | .. code-block:: python
12 |
13 | BACKBONE_REGISTRY = Registry('BACKBONE')
14 |
15 | To register an object:
16 |
17 | .. code-block:: python
18 |
19 | @BACKBONE_REGISTRY.register()
20 | class MyBackbone():
21 | ...
22 |
23 | Or:
24 |
25 | .. code-block:: python
26 |
27 | BACKBONE_REGISTRY.register(MyBackbone)
28 | """
29 |
30 | def __init__(self, name):
31 | """
32 | Args:
33 | name (str): the name of this registry
34 | """
35 | self._name = name
36 | self._obj_map = {}
37 |
38 | def _do_register(self, name, obj):
39 | assert name not in self._obj_map, (
40 | f"An object named '{name}' was already registered " f"in '{self._name}' registry!"
41 | )
42 | self._obj_map[name] = obj
43 |
44 | def register(self, obj=None):
45 | """
46 |         Register the given object under the name `obj.__name__`.
47 | Can be used as either a decorator or not.
48 | See docstring of this class for usage.
49 | """
50 | if obj is None:
51 | # used as a decorator
52 | def deco(func_or_class):
53 | name = func_or_class.__name__
54 | self._do_register(name, func_or_class)
55 | return func_or_class
56 |
57 | return deco
58 |
59 | # used as a function call
60 | name = obj.__name__
61 | self._do_register(name, obj)
62 |
63 | def get(self, name):
64 | ret = self._obj_map.get(name)
65 | if ret is None:
66 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
67 | return ret
68 |
69 | def __contains__(self, name):
70 | return name in self._obj_map
71 |
72 | def __iter__(self):
73 | return iter(self._obj_map.items())
74 |
75 | def keys(self):
76 | return self._obj_map.keys()
77 |
78 |
79 | DATASET_REGISTRY = Registry("dataset")
80 | ARCH_REGISTRY = Registry("arch")
81 | MODEL_REGISTRY = Registry("model")
82 | LOSS_REGISTRY = Registry("loss")
83 | METRIC_REGISTRY = Registry("metric")
84 |
--------------------------------------------------------------------------------
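A self-contained sketch of the registry pattern using ARCH_REGISTRY; TinyArch is a toy class invented here purely for illustration:

from torch import nn

from codeformer.basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class TinyArch(nn.Module):
    """Toy architecture registered under its class name."""

    def __init__(self, channels=3):
        super().__init__()
        self.body = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        return self.body(x)


# Look the class up by name and instantiate it, as the framework does from YAML options.
model = ARCH_REGISTRY.get("TinyArch")(channels=3)
print("TinyArch" in ARCH_REGISTRY)
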
/codeformer/basicsr/utils/video_util.py:
--------------------------------------------------------------------------------
1 | """
2 | The code is modified from the Real-ESRGAN:
3 | https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py
4 |
5 | """
6 | import sys
7 |
8 | import cv2
9 | import numpy as np
10 |
11 | try:
12 | import ffmpeg
13 | except ImportError:
14 |     import subprocess
15 | 
16 |     subprocess.check_call([sys.executable, "-m", "pip", "install", "--user", "ffmpeg-python"])
17 | import ffmpeg
18 |
19 |
20 | def get_video_meta_info(video_path):
21 | ret = {}
22 | probe = ffmpeg.probe(video_path)
23 | video_streams = [stream for stream in probe["streams"] if stream["codec_type"] == "video"]
24 | has_audio = any(stream["codec_type"] == "audio" for stream in probe["streams"])
25 | ret["width"] = video_streams[0]["width"]
26 | ret["height"] = video_streams[0]["height"]
27 | ret["fps"] = eval(video_streams[0]["avg_frame_rate"])
28 | ret["audio"] = ffmpeg.input(video_path).audio if has_audio else None
29 | ret["nb_frames"] = int(video_streams[0]["nb_frames"])
30 | return ret
31 |
32 |
33 | class VideoReader:
34 | def __init__(self, video_path):
35 | self.paths = [] # for image&folder type
36 | self.audio = None
37 | try:
38 | self.stream_reader = (
39 | ffmpeg.input(video_path)
40 | .output("pipe:", format="rawvideo", pix_fmt="bgr24", loglevel="error")
41 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg")
42 | )
43 | except FileNotFoundError:
44 | print("Please install ffmpeg (not ffmpeg-python) by running\n", "\t$ conda install -c conda-forge ffmpeg")
45 |             sys.exit(1)
46 |
47 | meta = get_video_meta_info(video_path)
48 | self.width = meta["width"]
49 | self.height = meta["height"]
50 | self.input_fps = meta["fps"]
51 | self.audio = meta["audio"]
52 | self.nb_frames = meta["nb_frames"]
53 |
54 | self.idx = 0
55 |
56 | def get_resolution(self):
57 | return self.height, self.width
58 |
59 | def get_fps(self):
60 | if self.input_fps is not None:
61 | return self.input_fps
62 | return 24
63 |
64 | def get_audio(self):
65 | return self.audio
66 |
67 | def __len__(self):
68 | return self.nb_frames
69 |
70 | def get_frame_from_stream(self):
71 | img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
72 | if not img_bytes:
73 | return None
74 | img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
75 | return img
76 |
77 | def get_frame_from_list(self):
78 | if self.idx >= self.nb_frames:
79 | return None
80 | img = cv2.imread(self.paths[self.idx])
81 | self.idx += 1
82 | return img
83 |
84 | def get_frame(self):
85 | return self.get_frame_from_stream()
86 |
87 | def close(self):
88 | self.stream_reader.stdin.close()
89 | self.stream_reader.wait()
90 |
91 |
92 | class VideoWriter:
93 | def __init__(self, video_save_path, height, width, fps, audio):
94 | if height > 2160:
95 | print(
96 |                 "You are generating a video larger than 4K, which will be very slow due to IO speed.",
97 |                 "We highly recommend decreasing the outscale (i.e., -s).",
98 | )
99 | if audio is not None:
100 | self.stream_writer = (
101 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="bgr24", s=f"{width}x{height}", framerate=fps)
102 | .output(audio, video_save_path, pix_fmt="yuv420p", vcodec="libx264", loglevel="error", acodec="copy")
103 | .overwrite_output()
104 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg")
105 | )
106 | else:
107 | self.stream_writer = (
108 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="bgr24", s=f"{width}x{height}", framerate=fps)
109 | .output(video_save_path, pix_fmt="yuv420p", vcodec="libx264", loglevel="error")
110 | .overwrite_output()
111 | .run_async(pipe_stdin=True, pipe_stdout=True, cmd="ffmpeg")
112 | )
113 |
114 | def write_frame(self, frame):
115 | try:
116 | frame = frame.astype(np.uint8).tobytes()
117 | self.stream_writer.stdin.write(frame)
118 | except BrokenPipeError:
119 | print(
120 | "Please re-install ffmpeg and libx264 by running\n",
121 | "\t$ conda install -c conda-forge ffmpeg\n",
122 | "\t$ conda install -c conda-forge x264",
123 | )
124 |             sys.exit(1)
125 |
126 | def close(self):
127 | self.stream_writer.stdin.close()
128 | self.stream_writer.wait()
129 |
--------------------------------------------------------------------------------
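A round-trip sketch with VideoReader and VideoWriter: decode frames, optionally process them, and re-encode with the original fps and audio. Paths are placeholders, and ffmpeg must be installed and on PATH:

from codeformer.basicsr.utils.video_util import VideoReader, VideoWriter

reader = VideoReader("inputs/demo.mp4")
height, width = reader.get_resolution()
writer = VideoWriter("results/demo_out.mp4", height, width, reader.get_fps(), reader.get_audio())

while True:
    frame = reader.get_frame()
    if frame is None:
        break
    writer.write_frame(frame)            # a restoration model would process `frame` here

reader.close()
writer.close()
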
/codeformer/facelib/__init__.py:
--------------------------------------------------------------------------------
1 | from codeformer.facelib.utils import *
2 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 |
4 | import torch
5 | from codeformer.facelib.detection.yolov5face.models.common import Conv
6 | from codeformer.facelib.utils import download_pretrained_models, load_file_from_url
7 | from torch import nn
8 |
9 | from codeformer.facelib.detection.retinaface.retinaface import RetinaFace
10 | from codeformer.facelib.detection.yolov5face.face_detector import YoloDetector
11 |
12 |
13 | def init_detection_model(model_name, half=False, device="cuda"):
14 | if "retinaface" in model_name:
15 | model = init_retinaface_model(model_name, half, device)
16 | elif "YOLOv5" in model_name:
17 | model = init_yolov5face_model(model_name, device)
18 | else:
19 | raise NotImplementedError(f"{model_name} is not implemented.")
20 |
21 | return model
22 |
23 |
24 | def init_retinaface_model(model_name, half=False, device="cuda"):
25 | if model_name == "retinaface_resnet50":
26 | model = RetinaFace(network_name="resnet50", half=half)
27 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth"
28 | elif model_name == "retinaface_mobile0.25":
29 | model = RetinaFace(network_name="mobile0.25", half=half)
30 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth"
31 | else:
32 | raise NotImplementedError(f"{model_name} is not implemented.")
33 |
34 | model_path = load_file_from_url(url=model_url, model_dir="weights/facelib", progress=True, file_name=None)
35 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
36 | # remove unnecessary 'module.'
37 | for k, v in deepcopy(load_net).items():
38 | if k.startswith("module."):
39 | load_net[k[7:]] = v
40 | load_net.pop(k)
41 | model.load_state_dict(load_net, strict=True)
42 | model.eval()
43 | model = model.to(device)
44 |
45 | return model
46 |
47 |
48 | def init_yolov5face_model(model_name, device="cuda"):
49 | if model_name == "YOLOv5l":
50 | model = YoloDetector(config_name="facelib/detection/yolov5face/models/yolov5l.yaml", device=device)
51 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth"
52 | elif model_name == "YOLOv5n":
53 | model = YoloDetector(config_name="facelib/detection/yolov5face/models/yolov5n.yaml", device=device)
54 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth"
55 | else:
56 | raise NotImplementedError(f"{model_name} is not implemented.")
57 |
58 | model_path = load_file_from_url(url=model_url, model_dir="weights/facelib", progress=True, file_name=None)
59 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
60 | model.detector.load_state_dict(load_net, strict=True)
61 | model.detector.eval()
62 | model.detector = model.detector.to(device).float()
63 |
64 | for m in model.detector.modules():
65 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
66 | m.inplace = True # pytorch 1.7.0 compatibility
67 | elif isinstance(m, Conv):
68 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
69 |
70 | return model
71 |
72 |
73 | # Download from Google Drive
74 | # def init_yolov5face_model(model_name, device='cuda'):
75 | # if model_name == 'YOLOv5l':
76 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
77 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
78 | # elif model_name == 'YOLOv5n':
79 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
80 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
81 | # else:
82 | # raise NotImplementedError(f'{model_name} is not implemented.')
83 |
84 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
85 | # if not os.path.exists(model_path):
86 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
87 |
88 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
89 | # model.detector.load_state_dict(load_net, strict=True)
90 | # model.detector.eval()
91 | # model.detector = model.detector.to(device).float()
92 |
93 | # for m in model.detector.modules():
94 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
95 | # m.inplace = True # pytorch 1.7.0 compatibility
96 | # elif isinstance(m, Conv):
97 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
98 |
99 | # return model
100 |
--------------------------------------------------------------------------------
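A short sketch of init_detection_model; the model-name strings come from the branches above, while the detect_faces call is an assumption about the RetinaFace wrapper defined in retinaface.py (not reproduced here):

import cv2

from codeformer.facelib.detection import init_detection_model

# First call downloads detection_Resnet50_Final.pth into weights/facelib.
detector = init_detection_model("retinaface_resnet50", half=False, device="cpu")

img = cv2.imread("inputs/demo.png")      # illustrative BGR image path
faces = detector.detect_faces(img)       # assumed API: one row per face (bbox, score, landmarks)
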
/codeformer/facelib/detection/align_trans.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from codeformer.facelib.detection.matlab_cp2tform import get_similarity_transform_for_cv2
5 |
6 | # reference facial points, a list of coordinates (x,y)
7 | REFERENCE_FACIAL_POINTS = [
8 | [30.29459953, 51.69630051],
9 | [65.53179932, 51.50139999],
10 | [48.02519989, 71.73660278],
11 | [33.54930115, 92.3655014],
12 | [62.72990036, 92.20410156],
13 | ]
14 |
15 | DEFAULT_CROP_SIZE = (96, 112)
16 |
17 |
18 | class FaceWarpException(Exception):
19 | def __str__(self):
20 |         return "In File {}:{}".format(__file__, super().__str__())
21 |
22 |
23 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
24 | """
25 | Function:
26 | ----------
27 | get reference 5 key points according to crop settings:
28 | 0. Set default crop_size:
29 | if default_square:
30 | crop_size = (112, 112)
31 | else:
32 | crop_size = (96, 112)
33 | 1. Pad the crop_size by inner_padding_factor in each side;
34 | 2. Resize crop_size into (output_size - outer_padding*2),
35 | pad into output_size with outer_padding;
36 | 3. Output reference_5point;
37 | Parameters:
38 | ----------
39 | @output_size: (w, h) or None
40 | size of aligned face image
41 | @inner_padding_factor: (w_factor, h_factor)
42 | padding factor for inner (w, h)
43 | @outer_padding: (w_pad, h_pad)
44 |             padding (in pixels) added on each side of the output
45 | @default_square: True or False
46 | if True:
47 | default crop_size = (112, 112)
48 | else:
49 | default crop_size = (96, 112);
50 | !!! make sure, if output_size is not None:
51 | (output_size - outer_padding)
52 | = some_scale * (default crop_size * (1.0 +
53 | inner_padding_factor))
54 | Returns:
55 | ----------
56 | @reference_5point: 5x2 np.array
57 | each row is a pair of transformed coordinates (x, y)
58 | """
59 |
60 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
61 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
62 |
63 | # 0) make the inner region a square
64 | if default_square:
65 | size_diff = max(tmp_crop_size) - tmp_crop_size
66 | tmp_5pts += size_diff / 2
67 | tmp_crop_size += size_diff
68 |
69 | if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]:
70 |
71 | return tmp_5pts
72 |
73 | if inner_padding_factor == 0 and outer_padding == (0, 0):
74 | if output_size is None:
75 | return tmp_5pts
76 | else:
77 | raise FaceWarpException("No paddings to do, output_size must be None or {}".format(tmp_crop_size))
78 |
79 | # check output size
80 | if not (0 <= inner_padding_factor <= 1.0):
81 | raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)")
82 |
83 | if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None:
84 |         output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
85 | output_size += np.array(outer_padding)
86 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
87 | raise FaceWarpException("Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])")
88 |
89 | # 1) pad the inner region according inner_padding_factor
90 | if inner_padding_factor > 0:
91 | size_diff = tmp_crop_size * inner_padding_factor * 2
92 | tmp_5pts += size_diff / 2
93 | tmp_crop_size += np.round(size_diff).astype(np.int32)
94 |
95 | # 2) resize the padded inner region
96 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
97 |
98 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
99 | raise FaceWarpException(
100 | "Must have (output_size - outer_padding)" "= some_scale * (crop_size * (1.0 + inner_padding_factor)"
101 | )
102 |
103 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
104 | tmp_5pts = tmp_5pts * scale_factor
105 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
106 | # tmp_5pts = tmp_5pts + size_diff / 2
107 | tmp_crop_size = size_bf_outer_pad
108 |
109 | # 3) add outer_padding to make output_size
110 | reference_5point = tmp_5pts + np.array(outer_padding)
111 | tmp_crop_size = output_size
112 |
113 | return reference_5point
114 |
115 |
116 | def get_affine_transform_matrix(src_pts, dst_pts):
117 | """
118 | Function:
119 | ----------
120 | get affine transform matrix 'tfm' from src_pts to dst_pts
121 | Parameters:
122 | ----------
123 | @src_pts: Kx2 np.array
124 | source points matrix, each row is a pair of coordinates (x, y)
125 | @dst_pts: Kx2 np.array
126 | destination points matrix, each row is a pair of coordinates (x, y)
127 | Returns:
128 | ----------
129 | @tfm: 2x3 np.array
130 | transform matrix from src_pts to dst_pts
131 | """
132 |
133 | tfm = np.float32([[1, 0, 0], [0, 1, 0]])
134 | n_pts = src_pts.shape[0]
135 | ones = np.ones((n_pts, 1), src_pts.dtype)
136 | src_pts_ = np.hstack([src_pts, ones])
137 | dst_pts_ = np.hstack([dst_pts, ones])
138 |
139 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
140 |
141 | if rank == 3:
142 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
143 | elif rank == 2:
144 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
145 |
146 | return tfm
147 |
148 |
149 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"):
150 | """
151 | Function:
152 | ----------
153 | apply affine transform 'trans' to uv
154 | Parameters:
155 | ----------
156 |         @src_img: HxWx3 np.array
157 | input image
158 | @facial_pts: could be
159 | 1)a list of K coordinates (x,y)
160 | or
161 | 2) Kx2 or 2xK np.array
162 | each row or col is a pair of coordinates (x, y)
163 | @reference_pts: could be
164 | 1) a list of K coordinates (x,y)
165 | or
166 | 2) Kx2 or 2xK np.array
167 | each row or col is a pair of coordinates (x, y)
168 | or
169 | 3) None
170 | if None, use default reference facial points
171 | @crop_size: (w, h)
172 | output face image size
173 | @align_type: transform type, could be one of
174 | 1) 'similarity': use similarity transform
175 | 2) 'cv2_affine': use the first 3 points to do affine transform,
176 | by calling cv2.getAffineTransform()
177 | 3) 'affine': use all points to do affine transform
178 | Returns:
179 | ----------
180 | @face_img: output face image with size (w, h) = @crop_size
181 | """
182 |
183 | if reference_pts is None:
184 | if crop_size[0] == 96 and crop_size[1] == 112:
185 | reference_pts = REFERENCE_FACIAL_POINTS
186 | else:
187 | default_square = False
188 | inner_padding_factor = 0
189 | outer_padding = (0, 0)
190 | output_size = crop_size
191 |
192 | reference_pts = get_reference_facial_points(
193 | output_size, inner_padding_factor, outer_padding, default_square
194 | )
195 |
196 | ref_pts = np.float32(reference_pts)
197 | ref_pts_shp = ref_pts.shape
198 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
199 | raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2")
200 |
201 | if ref_pts_shp[0] == 2:
202 | ref_pts = ref_pts.T
203 |
204 | src_pts = np.float32(facial_pts)
205 | src_pts_shp = src_pts.shape
206 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
207 | raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2")
208 |
209 | if src_pts_shp[0] == 2:
210 | src_pts = src_pts.T
211 |
212 | if src_pts.shape != ref_pts.shape:
213 | raise FaceWarpException("facial_pts and reference_pts must have the same shape")
214 |
215 | if align_type == "cv2_affine":
216 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
217 | elif align_type == "affine":
218 | tfm = get_affine_transform_matrix(src_pts, ref_pts)
219 | else:
220 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
221 |
222 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
223 |
224 | return face_img
225 |
--------------------------------------------------------------------------------
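A sketch of warp_and_crop_face with hand-picked landmarks; the image path and the five (x, y) points are illustrative. For non-default crop sizes, reference points are generated with get_reference_facial_points using a square template and a non-zero inner padding, which is what the function's internal checks require:

import cv2
import numpy as np

from codeformer.facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face

src_img = cv2.imread("inputs/demo.png")  # illustrative path
facial_pts = np.array(
    [[230.0, 250.0], [310.0, 248.0], [270.0, 300.0], [240.0, 345.0], [305.0, 342.0]]
)

# Default 96x112 crop, using the built-in REFERENCE_FACIAL_POINTS.
face_small = warp_and_crop_face(src_img, facial_pts, crop_size=(96, 112))

# 512x512 crop with explicitly generated reference points.
ref_pts = get_reference_facial_points(
    output_size=(512, 512), inner_padding_factor=0.25, outer_padding=(0, 0), default_square=True
)
face_large = warp_and_crop_face(src_img, facial_pts, reference_pts=ref_pts, crop_size=(512, 512))
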
/codeformer/facelib/detection/retinaface/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/retinaface/__init__.py
--------------------------------------------------------------------------------
/codeformer/facelib/detection/retinaface/retinaface_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def conv_bn(inp, oup, stride=1, leaky=0):
7 | return nn.Sequential(
8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
9 | nn.BatchNorm2d(oup),
10 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
11 | )
12 |
13 |
14 | def conv_bn_no_relu(inp, oup, stride):
15 | return nn.Sequential(
16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
17 | nn.BatchNorm2d(oup),
18 | )
19 |
20 |
21 | def conv_bn1X1(inp, oup, stride, leaky=0):
22 | return nn.Sequential(
23 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
24 | nn.BatchNorm2d(oup),
25 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
26 | )
27 |
28 |
29 | def conv_dw(inp, oup, stride, leaky=0.1):
30 | return nn.Sequential(
31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32 | nn.BatchNorm2d(inp),
33 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
34 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
35 | nn.BatchNorm2d(oup),
36 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
37 | )
38 |
39 |
40 | class SSH(nn.Module):
41 | def __init__(self, in_channel, out_channel):
42 | super(SSH, self).__init__()
43 | assert out_channel % 4 == 0
44 | leaky = 0
45 | if out_channel <= 64:
46 | leaky = 0.1
47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
48 |
49 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
50 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
51 |
52 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
53 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
54 |
55 | def forward(self, input):
56 | conv3X3 = self.conv3X3(input)
57 |
58 | conv5X5_1 = self.conv5X5_1(input)
59 | conv5X5 = self.conv5X5_2(conv5X5_1)
60 |
61 | conv7X7_2 = self.conv7X7_2(conv5X5_1)
62 | conv7X7 = self.conv7x7_3(conv7X7_2)
63 |
64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
65 | out = F.relu(out)
66 | return out
67 |
68 |
69 | class FPN(nn.Module):
70 | def __init__(self, in_channels_list, out_channels):
71 | super(FPN, self).__init__()
72 | leaky = 0
73 | if out_channels <= 64:
74 | leaky = 0.1
75 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
76 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
77 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
78 |
79 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
80 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
81 |
82 | def forward(self, input):
83 | # names = list(input.keys())
84 | # input = list(input.values())
85 |
86 | output1 = self.output1(input[0])
87 | output2 = self.output2(input[1])
88 | output3 = self.output3(input[2])
89 |
90 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
91 | output2 = output2 + up3
92 | output2 = self.merge2(output2)
93 |
94 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
95 | output1 = output1 + up2
96 | output1 = self.merge1(output1)
97 |
98 | out = [output1, output2, output3]
99 | return out
100 |
101 |
102 | class MobileNetV1(nn.Module):
103 | def __init__(self):
104 | super(MobileNetV1, self).__init__()
105 | self.stage1 = nn.Sequential(
106 | conv_bn(3, 8, 2, leaky=0.1), # 3
107 | conv_dw(8, 16, 1), # 7
108 | conv_dw(16, 32, 2), # 11
109 | conv_dw(32, 32, 1), # 19
110 | conv_dw(32, 64, 2), # 27
111 | conv_dw(64, 64, 1), # 43
112 | )
113 | self.stage2 = nn.Sequential(
114 | conv_dw(64, 128, 2), # 43 + 16 = 59
115 | conv_dw(128, 128, 1), # 59 + 32 = 91
116 | conv_dw(128, 128, 1), # 91 + 32 = 123
117 | conv_dw(128, 128, 1), # 123 + 32 = 155
118 | conv_dw(128, 128, 1), # 155 + 32 = 187
119 | conv_dw(128, 128, 1), # 187 + 32 = 219
120 | )
121 | self.stage3 = nn.Sequential(
122 | conv_dw(128, 256, 2), # 219 +3 2 = 241
123 | conv_dw(256, 256, 1), # 241 + 64 = 301
124 | )
125 | self.avg = nn.AdaptiveAvgPool2d((1, 1))
126 | self.fc = nn.Linear(256, 1000)
127 |
128 | def forward(self, x):
129 | x = self.stage1(x)
130 | x = self.stage2(x)
131 | x = self.stage3(x)
132 | x = self.avg(x)
133 | # x = self.model(x)
134 | x = x.view(-1, 256)
135 | x = self.fc(x)
136 | return x
137 |
138 |
139 | class ClassHead(nn.Module):
140 | def __init__(self, inchannels=512, num_anchors=3):
141 | super(ClassHead, self).__init__()
142 | self.num_anchors = num_anchors
143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
144 |
145 | def forward(self, x):
146 | out = self.conv1x1(x)
147 | out = out.permute(0, 2, 3, 1).contiguous()
148 |
149 | return out.view(out.shape[0], -1, 2)
150 |
151 |
152 | class BboxHead(nn.Module):
153 | def __init__(self, inchannels=512, num_anchors=3):
154 | super(BboxHead, self).__init__()
155 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
156 |
157 | def forward(self, x):
158 | out = self.conv1x1(x)
159 | out = out.permute(0, 2, 3, 1).contiguous()
160 |
161 | return out.view(out.shape[0], -1, 4)
162 |
163 |
164 | class LandmarkHead(nn.Module):
165 | def __init__(self, inchannels=512, num_anchors=3):
166 | super(LandmarkHead, self).__init__()
167 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
168 |
169 | def forward(self, x):
170 | out = self.conv1x1(x)
171 | out = out.permute(0, 2, 3, 1).contiguous()
172 |
173 | return out.view(out.shape[0], -1, 10)
174 |
175 |
176 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
177 | classhead = nn.ModuleList()
178 | for i in range(fpn_num):
179 | classhead.append(ClassHead(inchannels, anchor_num))
180 | return classhead
181 |
182 |
183 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
184 | bboxhead = nn.ModuleList()
185 | for i in range(fpn_num):
186 | bboxhead.append(BboxHead(inchannels, anchor_num))
187 | return bboxhead
188 |
189 |
190 | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
191 | landmarkhead = nn.ModuleList()
192 | for i in range(fpn_num):
193 | landmarkhead.append(LandmarkHead(inchannels, anchor_num))
194 | return landmarkhead
195 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/__init__.py
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/face_detector.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import re
3 | from pathlib import Path
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 |
9 | from codeformer.facelib.detection.yolov5face.models.yolo import Model
10 | from codeformer.facelib.detection.yolov5face.utils.datasets import letterbox
11 | from codeformer.facelib.detection.yolov5face.utils.general import (
12 | check_img_size,
13 | non_max_suppression_face,
14 | scale_coords,
15 | scale_coords_landmarks,
16 | )
17 |
18 | # IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
19 | IS_HIGH_VERSION = [
20 | int(m)
21 | for m in list(
22 | re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$", torch.__version__)[0][:3]
23 | )
24 | ] >= [1, 9, 0]
25 |
26 |
27 | def isListempty(inList):
28 | if isinstance(inList, list): # Is a list
29 | return all(map(isListempty, inList))
30 | return False # Not a list
31 |
32 |
33 | class YoloDetector:
34 | def __init__(
35 | self,
36 | config_name,
37 | min_face=10,
38 | target_size=None,
39 | device="cuda",
40 | ):
41 | """
42 | config_name: name of .yaml config with network configuration from models/ folder.
43 | min_face : minimal face size in pixels.
44 | target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
45 | None for original resolution.
46 | """
47 | self._class_path = Path(__file__).parent.absolute()
48 | self.target_size = target_size
49 | self.min_face = min_face
50 | self.detector = Model(cfg=config_name)
51 | self.device = device
52 |
53 | def _preprocess(self, imgs):
54 | """
55 | Preprocessing image before passing through the network. Resize and conversion to torch tensor.
56 | """
57 | pp_imgs = []
58 | for img in imgs:
59 | h0, w0 = img.shape[:2] # orig hw
60 | if self.target_size:
61 | r = self.target_size / min(h0, w0) # resize image to img_size
62 | if r < 1:
63 | img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
64 |
65 | imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
66 | img = letterbox(img, new_shape=imgsz)[0]
67 | pp_imgs.append(img)
68 | pp_imgs = np.array(pp_imgs)
69 | pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
70 | pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
71 | pp_imgs = pp_imgs.float() # uint8 to fp16/32
72 | return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
73 |
74 | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
75 | """
76 | Postprocessing of raw pytorch model output.
77 | Returns:
78 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
79 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
80 | """
81 | bboxes = [[] for _ in range(len(origimgs))]
82 | landmarks = [[] for _ in range(len(origimgs))]
83 |
84 | pred = non_max_suppression_face(pred, conf_thres, iou_thres)
85 |
86 | for image_id, origimg in enumerate(origimgs):
87 | img_shape = origimg.shape
88 | image_height, image_width = img_shape[:2]
89 | gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
90 | gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
91 | det = pred[image_id].cpu()
92 | scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
93 | scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
94 |
95 | for j in range(det.size()[0]):
96 | box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
97 | box = list(
98 | map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
99 | )
100 | if box[3] - box[1] < self.min_face:
101 | continue
102 | lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
103 | lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
104 | lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
105 | bboxes[image_id].append(box)
106 | landmarks[image_id].append(lm)
107 | return bboxes, landmarks
108 |
109 | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
110 | """
111 | Get bbox coordinates and keypoints of faces on original image.
112 | Params:
113 | imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
114 | conf_thres: confidence threshold for each prediction
115 | iou_thres: threshold for NMS (filter of intersecting bboxes)
116 | Returns:
117 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
118 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
119 | """
120 | # Pass input images through face detector
121 | images = imgs if isinstance(imgs, list) else [imgs]
122 | images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
123 | origimgs = copy.deepcopy(images)
124 |
125 | images = self._preprocess(images)
126 |
127 | if IS_HIGH_VERSION:
128 | with torch.inference_mode(): # for pytorch>=1.9
129 | pred = self.detector(images)[0]
130 | else:
131 | with torch.no_grad(): # for pytorch<1.9
132 | pred = self.detector(images)[0]
133 |
134 | bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
135 |
136 | # return bboxes, points
137 | if not isListempty(points):
138 | bboxes = np.array(bboxes).reshape(-1, 4)
139 | points = np.array(points).reshape(-1, 10)
140 | padding = bboxes[:, 0].reshape(-1, 1)
141 | return np.concatenate((bboxes, padding, points), axis=1)
142 | else:
143 | return None
144 |
145 | def __call__(self, *args):
146 |         return self.detect_faces(*args)
147 |
--------------------------------------------------------------------------------
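A sketch of running the YOLOv5 face detector end to end; it assumes the weights are fetched via init_detection_model (see facelib/detection/__init__.py above) and that an image exists at the illustrative path:

import cv2

from codeformer.facelib.detection import init_detection_model

detector = init_detection_model("YOLOv5n", device="cpu")   # builds a YoloDetector and loads yolov5n-face.pth

img = cv2.imread("inputs/demo.png")       # BGR; detect_faces converts to RGB internally
dets = detector.detect_faces(img, conf_thres=0.7, iou_thres=0.5)
if dets is not None:
    # One row per face: x1, y1, x2, y2, a padding column, then five (x, y) landmarks
    # flattened to ten values, as assembled at the end of detect_faces.
    print(dets.shape)
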
/codeformer/facelib/detection/yolov5face/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/models/__init__.py
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/models/experimental.py:
--------------------------------------------------------------------------------
1 | # # This file contains experimental modules
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 | from codeformer.facelib.detection.yolov5face.models.common import Conv
8 |
9 |
10 | class CrossConv(nn.Module):
11 | # Cross Convolution Downsample
12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
14 | super().__init__()
15 | c_ = int(c2 * e) # hidden channels
16 | self.cv1 = Conv(c1, c_, (1, k), (1, s))
17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
18 | self.add = shortcut and c1 == c2
19 |
20 | def forward(self, x):
21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
22 |
23 |
24 | class MixConv2d(nn.Module):
25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
27 | super().__init__()
28 | groups = len(k)
29 | if equal_ch: # equal c_ per group
30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices
31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
32 | else: # equal weight.numel() per group
33 | b = [c2] + [0] * groups
34 | a = np.eye(groups + 1, groups, k=-1)
35 | a -= np.roll(a, 1, axis=1)
36 | a *= np.array(k) ** 2
37 | a[0] = 1
38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
39 |
40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
41 | self.bn = nn.BatchNorm2d(c2)
42 | self.act = nn.LeakyReLU(0.1, inplace=True)
43 |
44 | def forward(self, x):
45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
46 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml:
--------------------------------------------------------------------------------
1 | # parameters
2 | nc: 1 # number of classes
3 | depth_multiple: 1.0 # model depth multiple
4 | width_multiple: 1.0 # layer channel multiple
5 |
6 | # anchors
7 | anchors:
8 | - [4,5, 8,10, 13,16] # P3/8
9 | - [23,29, 43,55, 73,105] # P4/16
10 | - [146,217, 231,300, 335,433] # P5/32
11 |
12 | # YOLOv5 backbone
13 | backbone:
14 | # [from, number, module, args]
15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2
16 | [-1, 3, C3, [128]],
17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8
18 | [-1, 9, C3, [256]],
19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16
20 | [-1, 9, C3, [512]],
21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32
22 | [-1, 1, SPP, [1024, [3,5,7]]],
23 | [-1, 3, C3, [1024, False]], # 8
24 | ]
25 |
26 | # YOLOv5 head
27 | head:
28 | [[-1, 1, Conv, [512, 1, 1]],
29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4
31 | [-1, 3, C3, [512, False]], # 12
32 |
33 | [-1, 1, Conv, [256, 1, 1]],
34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3
36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small)
37 |
38 | [-1, 1, Conv, [256, 3, 2]],
39 | [[-1, 13], 1, Concat, [1]], # cat head P4
40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium)
41 |
42 | [-1, 1, Conv, [512, 3, 2]],
43 | [[-1, 9], 1, Concat, [1]], # cat head P5
44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large)
45 |
46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
47 | ]
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml:
--------------------------------------------------------------------------------
1 | # parameters
2 | nc: 1 # number of classes
3 | depth_multiple: 1.0 # model depth multiple
4 | width_multiple: 1.0 # layer channel multiple
5 |
6 | # anchors
7 | anchors:
8 | - [4,5, 8,10, 13,16] # P3/8
9 | - [23,29, 43,55, 73,105] # P4/16
10 | - [146,217, 231,300, 335,433] # P5/32
11 |
12 | # YOLOv5 backbone
13 | backbone:
14 | # [from, number, module, args]
15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4
16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2
18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4
20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6
22 | ]
23 |
24 | # YOLOv5 head
25 | head:
26 | [[-1, 1, Conv, [128, 1, 1]],
27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4
29 | [-1, 1, C3, [128, False]], # 10
30 |
31 | [-1, 1, Conv, [128, 1, 1]],
32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']],
33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3
34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small)
35 |
36 | [-1, 1, Conv, [128, 3, 2]],
37 | [[-1, 11], 1, Concat, [1]], # cat head P4
38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium)
39 |
40 | [-1, 1, Conv, [128, 3, 2]],
41 | [[-1, 7], 1, Concat, [1]], # cat head P5
42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large)
43 |
44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
45 | ]
46 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/facelib/detection/yolov5face/utils/__init__.py
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/utils/autoanchor.py:
--------------------------------------------------------------------------------
1 | # Auto-anchor utils
2 |
3 |
4 | def check_anchor_order(m):
5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area
7 | da = a[-1] - a[0] # delta a
8 | ds = m.stride[-1] - m.stride[0] # delta s
9 |     if da.sign() != ds.sign():  # anchor order does not match stride order
10 | print("Reversing anchor order")
11 | m.anchors[:] = m.anchors.flip(0)
12 | m.anchor_grid[:] = m.anchor_grid.flip(0)
13 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
6 | # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
7 | shape = img.shape[:2] # current shape [height, width]
8 | if isinstance(new_shape, int):
9 | new_shape = (new_shape, new_shape)
10 |
11 | # Scale ratio (new / old)
12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
13 | if not scaleup: # only scale down, do not scale up (for better test mAP)
14 | r = min(r, 1.0)
15 |
16 | # Compute padding
17 | ratio = r, r # width, height ratios
18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
20 | if auto: # minimum rectangle
21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
22 | elif scale_fill: # stretch
23 | dw, dh = 0.0, 0.0
24 | new_unpad = (new_shape[1], new_shape[0])
25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
26 |
27 | dw /= 2 # divide padding into 2 sides
28 | dh /= 2
29 |
30 | if shape[::-1] != new_unpad: # resize
31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
35 | return img, ratio, (dw, dh)
36 |
--------------------------------------------------------------------------------
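A quick sketch of letterbox on a dummy 720x1280 frame; with auto=True the padding is reduced modulo 64, so the result is a stride-aligned rectangle rather than a full 640x640 square:

import numpy as np

from codeformer.facelib.detection.yolov5face.utils.datasets import letterbox

img = np.zeros((720, 1280, 3), dtype=np.uint8)
padded, ratio, (dw, dh) = letterbox(img, new_shape=640, auto=True)
print(padded.shape, ratio, (dw, dh))      # (384, 640, 3), (0.5, 0.5), (0.0, 12.0)
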
/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 |
5 | sys.path.insert(0, "./facelib/detection/yolov5face")
6 | model = torch.load("facelib/detection/yolov5face/yolov5n-face.pt", map_location="cpu")["model"]
7 | torch.save(model.state_dict(), "weights/facelib/yolov5n-face.pth")
8 |
--------------------------------------------------------------------------------
/codeformer/facelib/detection/yolov5face/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def fuse_conv_and_bn(conv, bn):
6 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
7 | fusedconv = (
8 | nn.Conv2d(
9 | conv.in_channels,
10 | conv.out_channels,
11 | kernel_size=conv.kernel_size,
12 | stride=conv.stride,
13 | padding=conv.padding,
14 | groups=conv.groups,
15 | bias=True,
16 | )
17 | .requires_grad_(False)
18 | .to(conv.weight.device)
19 | )
20 |
21 | # prepare filters
22 | w_conv = conv.weight.clone().view(conv.out_channels, -1)
23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
25 |
26 | # prepare spatial bias
27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
30 |
31 | return fusedconv
32 |
33 |
34 | def copy_attr(a, b, include=(), exclude=()):
35 | # Copy attributes from b to a, options to only include [...] and to exclude [...]
36 | for k, v in b.__dict__.items():
37 | if (include and k not in include) or k.startswith("_") or k in exclude:
38 | continue
39 |
40 | setattr(a, k, v)
41 |
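A minimal sketch (hypothetical, not from the repository): verifying that fuse_conv_and_bn reproduces the conv + batchnorm output once the BN statistics are frozen. The layer sizes are arbitrary.

    import torch
    from torch import nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()          # fusion assumes frozen running statistics
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        fused = fuse_conv_and_bn(conv, bn)
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)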
--------------------------------------------------------------------------------
/codeformer/facelib/parsing/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from codeformer.facelib.parsing.bisenet import BiSeNet
4 | from codeformer.facelib.parsing.parsenet import ParseNet
5 | from codeformer.facelib.utils import load_file_from_url
6 |
7 |
8 | def init_parsing_model(model_name="bisenet", half=False, device="cuda"):
9 | if model_name == "bisenet":
10 | model = BiSeNet(num_class=19)
11 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth"
12 | elif model_name == "parsenet":
13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
14 | model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth"
15 | else:
16 | raise NotImplementedError(f"{model_name} is not implemented.")
17 |
18 | model_path = load_file_from_url(url=model_url, model_dir="weights/facelib", progress=True, file_name=None)
19 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
20 | model.load_state_dict(load_net, strict=True)
21 | model.eval()
22 | model = model.to(device)
23 | return model
24 |
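A minimal usage sketch (hypothetical, not from the repository). The first call downloads the ParseNet weights into the package's weights/facelib folder; the 512x512 dummy input matches the configuration used above.

    import torch
    from codeformer.facelib.parsing import init_parsing_model

    net = init_parsing_model(model_name="parsenet", device="cpu")
    with torch.no_grad():
        mask, img = net(torch.randn(1, 3, 512, 512))   # ParseNet returns (parsing mask, reconstructed image)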
--------------------------------------------------------------------------------
/codeformer/facelib/parsing/bisenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .resnet import ResNet18
6 |
7 |
8 | class ConvBNReLU(nn.Module):
9 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
10 | super(ConvBNReLU, self).__init__()
11 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
12 | self.bn = nn.BatchNorm2d(out_chan)
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | x = F.relu(self.bn(x))
17 | return x
18 |
19 |
20 | class BiSeNetOutput(nn.Module):
21 | def __init__(self, in_chan, mid_chan, num_class):
22 | super(BiSeNetOutput, self).__init__()
23 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
24 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
25 |
26 | def forward(self, x):
27 | feat = self.conv(x)
28 | out = self.conv_out(feat)
29 | return out, feat
30 |
31 |
32 | class AttentionRefinementModule(nn.Module):
33 | def __init__(self, in_chan, out_chan):
34 | super(AttentionRefinementModule, self).__init__()
35 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
36 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
37 | self.bn_atten = nn.BatchNorm2d(out_chan)
38 | self.sigmoid_atten = nn.Sigmoid()
39 |
40 | def forward(self, x):
41 | feat = self.conv(x)
42 | atten = F.avg_pool2d(feat, feat.size()[2:])
43 | atten = self.conv_atten(atten)
44 | atten = self.bn_atten(atten)
45 | atten = self.sigmoid_atten(atten)
46 | out = torch.mul(feat, atten)
47 | return out
48 |
49 |
50 | class ContextPath(nn.Module):
51 | def __init__(self):
52 | super(ContextPath, self).__init__()
53 | self.resnet = ResNet18()
54 | self.arm16 = AttentionRefinementModule(256, 128)
55 | self.arm32 = AttentionRefinementModule(512, 128)
56 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
57 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
58 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
59 |
60 | def forward(self, x):
61 | feat8, feat16, feat32 = self.resnet(x)
62 | h8, w8 = feat8.size()[2:]
63 | h16, w16 = feat16.size()[2:]
64 | h32, w32 = feat32.size()[2:]
65 |
66 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
67 | avg = self.conv_avg(avg)
68 | avg_up = F.interpolate(avg, (h32, w32), mode="nearest")
69 |
70 | feat32_arm = self.arm32(feat32)
71 | feat32_sum = feat32_arm + avg_up
72 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode="nearest")
73 | feat32_up = self.conv_head32(feat32_up)
74 |
75 | feat16_arm = self.arm16(feat16)
76 | feat16_sum = feat16_arm + feat32_up
77 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode="nearest")
78 | feat16_up = self.conv_head16(feat16_up)
79 |
80 | return feat8, feat16_up, feat32_up # x8, x8, x16
81 |
82 |
83 | class FeatureFusionModule(nn.Module):
84 | def __init__(self, in_chan, out_chan):
85 | super(FeatureFusionModule, self).__init__()
86 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
87 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
88 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
89 | self.relu = nn.ReLU(inplace=True)
90 | self.sigmoid = nn.Sigmoid()
91 |
92 | def forward(self, fsp, fcp):
93 | fcat = torch.cat([fsp, fcp], dim=1)
94 | feat = self.convblk(fcat)
95 | atten = F.avg_pool2d(feat, feat.size()[2:])
96 | atten = self.conv1(atten)
97 | atten = self.relu(atten)
98 | atten = self.conv2(atten)
99 | atten = self.sigmoid(atten)
100 | feat_atten = torch.mul(feat, atten)
101 | feat_out = feat_atten + feat
102 | return feat_out
103 |
104 |
105 | class BiSeNet(nn.Module):
106 | def __init__(self, num_class):
107 | super(BiSeNet, self).__init__()
108 | self.cp = ContextPath()
109 | self.ffm = FeatureFusionModule(256, 256)
110 | self.conv_out = BiSeNetOutput(256, 256, num_class)
111 | self.conv_out16 = BiSeNetOutput(128, 64, num_class)
112 | self.conv_out32 = BiSeNetOutput(128, 64, num_class)
113 |
114 | def forward(self, x, return_feat=False):
115 | h, w = x.size()[2:]
116 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
117 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
118 | feat_fuse = self.ffm(feat_sp, feat_cp8)
119 |
120 | out, feat = self.conv_out(feat_fuse)
121 | out16, feat16 = self.conv_out16(feat_cp8)
122 | out32, feat32 = self.conv_out32(feat_cp16)
123 |
124 | out = F.interpolate(out, (h, w), mode="bilinear", align_corners=True)
125 | out16 = F.interpolate(out16, (h, w), mode="bilinear", align_corners=True)
126 | out32 = F.interpolate(out32, (h, w), mode="bilinear", align_corners=True)
127 |
128 | if return_feat:
129 | feat = F.interpolate(feat, (h, w), mode="bilinear", align_corners=True)
130 | feat16 = F.interpolate(feat16, (h, w), mode="bilinear", align_corners=True)
131 | feat32 = F.interpolate(feat32, (h, w), mode="bilinear", align_corners=True)
132 | return out, out16, out32, feat, feat16, feat32
133 | else:
134 | return out, out16, out32
135 |
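A minimal shape-check sketch (hypothetical, not from the repository); the network here is randomly initialised rather than loaded from the released parsing_bisenet.pth.

    import torch

    net = BiSeNet(num_class=19).eval()
    with torch.no_grad():
        out, out16, out32 = net(torch.randn(1, 3, 512, 512))
    print(out.shape)   # torch.Size([1, 19, 512, 512]); out16 and out32 share the same shape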
--------------------------------------------------------------------------------
/codeformer/facelib/parsing/parsenet.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/chaofengc/PSFRGAN
2 | """
3 | import numpy as np
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class NormLayer(nn.Module):
9 | """Normalization Layers.
10 |
11 | Args:
12 | channels: input channels, for batch norm and instance norm.
 13 |         normalize_shape: input shape without batch size, for layer norm.
14 | """
15 |
16 | def __init__(self, channels, normalize_shape=None, norm_type="bn"):
17 | super(NormLayer, self).__init__()
18 | norm_type = norm_type.lower()
19 | self.norm_type = norm_type
20 | if norm_type == "bn":
21 | self.norm = nn.BatchNorm2d(channels, affine=True)
22 | elif norm_type == "in":
23 | self.norm = nn.InstanceNorm2d(channels, affine=False)
24 | elif norm_type == "gn":
25 | self.norm = nn.GroupNorm(32, channels, affine=True)
26 | elif norm_type == "pixel":
27 | self.norm = lambda x: F.normalize(x, p=2, dim=1)
28 | elif norm_type == "layer":
29 | self.norm = nn.LayerNorm(normalize_shape)
30 | elif norm_type == "none":
31 | self.norm = lambda x: x * 1.0
32 | else:
 33 |             raise ValueError(f"Norm type {norm_type} is not supported.")
34 |
35 | def forward(self, x, ref=None):
36 | if self.norm_type == "spade":
37 | return self.norm(x, ref)
38 | else:
39 | return self.norm(x)
40 |
41 |
42 | class ReluLayer(nn.Module):
43 | """Relu Layer.
44 |
45 | Args:
 46 |         relu_type: type of relu layer, candidates are
47 | - ReLU
48 | - LeakyReLU: default relu slope 0.2
49 | - PRelu
50 | - SELU
51 | - none: direct pass
52 | """
53 |
54 | def __init__(self, channels, relu_type="relu"):
55 | super(ReluLayer, self).__init__()
56 | relu_type = relu_type.lower()
57 | if relu_type == "relu":
58 | self.func = nn.ReLU(True)
59 | elif relu_type == "leakyrelu":
60 | self.func = nn.LeakyReLU(0.2, inplace=True)
61 | elif relu_type == "prelu":
62 | self.func = nn.PReLU(channels)
63 | elif relu_type == "selu":
64 | self.func = nn.SELU(True)
65 | elif relu_type == "none":
66 | self.func = lambda x: x * 1.0
67 | else:
 68 |             raise ValueError(f"Relu type {relu_type} is not supported.")
69 |
70 | def forward(self, x):
71 | return self.func(x)
72 |
73 |
74 | class ConvLayer(nn.Module):
75 | def __init__(
76 | self,
77 | in_channels,
78 | out_channels,
79 | kernel_size=3,
80 | scale="none",
81 | norm_type="none",
82 | relu_type="none",
83 | use_pad=True,
84 | bias=True,
85 | ):
86 | super(ConvLayer, self).__init__()
87 | self.use_pad = use_pad
88 | self.norm_type = norm_type
89 | if norm_type in ["bn"]:
90 | bias = False
91 |
92 | stride = 2 if scale == "down" else 1
93 |
94 | self.scale_func = lambda x: x
95 | if scale == "up":
96 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode="nearest")
97 |
98 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2)))
99 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
100 |
101 | self.relu = ReluLayer(out_channels, relu_type)
102 | self.norm = NormLayer(out_channels, norm_type=norm_type)
103 |
104 | def forward(self, x):
105 | out = self.scale_func(x)
106 | if self.use_pad:
107 | out = self.reflection_pad(out)
108 | out = self.conv2d(out)
109 | out = self.norm(out)
110 | out = self.relu(out)
111 | return out
112 |
113 |
114 | class ResidualBlock(nn.Module):
115 | """
116 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
117 | """
118 |
119 | def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"):
120 | super(ResidualBlock, self).__init__()
121 |
122 | if scale == "none" and c_in == c_out:
123 | self.shortcut_func = lambda x: x
124 | else:
125 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
126 |
127 | scale_config_dict = {"down": ["none", "down"], "up": ["up", "none"], "none": ["none", "none"]}
128 | scale_conf = scale_config_dict[scale]
129 |
130 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
131 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none")
132 |
133 | def forward(self, x):
134 | identity = self.shortcut_func(x)
135 |
136 | res = self.conv1(x)
137 | res = self.conv2(res)
138 | return identity + res
139 |
140 |
141 | class ParseNet(nn.Module):
142 | def __init__(
143 | self,
144 | in_size=128,
145 | out_size=128,
146 | min_feat_size=32,
147 | base_ch=64,
148 | parsing_ch=19,
149 | res_depth=10,
150 | relu_type="LeakyReLU",
151 | norm_type="bn",
152 | ch_range=[32, 256],
153 | ):
154 | super().__init__()
155 | self.res_depth = res_depth
156 | act_args = {"norm_type": norm_type, "relu_type": relu_type}
157 | min_ch, max_ch = ch_range
158 |
159 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
160 | min_feat_size = min(in_size, min_feat_size)
161 |
162 | down_steps = int(np.log2(in_size // min_feat_size))
163 | up_steps = int(np.log2(out_size // min_feat_size))
164 |
165 | # =============== define encoder-body-decoder ====================
166 | self.encoder = []
167 | self.encoder.append(ConvLayer(3, base_ch, 3, 1))
168 | head_ch = base_ch
169 | for i in range(down_steps):
170 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
171 | self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args))
172 | head_ch = head_ch * 2
173 |
174 | self.body = []
175 | for i in range(res_depth):
176 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
177 |
178 | self.decoder = []
179 | for i in range(up_steps):
180 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
181 | self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args))
182 | head_ch = head_ch // 2
183 |
184 | self.encoder = nn.Sequential(*self.encoder)
185 | self.body = nn.Sequential(*self.body)
186 | self.decoder = nn.Sequential(*self.decoder)
187 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
188 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
189 |
190 | def forward(self, x):
191 | feat = self.encoder(x)
192 | x = feat + self.body(feat)
193 | x = self.decoder(x)
194 | out_img = self.out_img_conv(x)
195 | out_mask = self.out_mask_conv(x)
196 | return out_mask, out_img
197 |
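A minimal shape-check sketch (hypothetical, not from the repository), mirroring the configuration used by init_parsing_model; the weights here are random, not the released ones.

    import torch

    net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
    with torch.no_grad():
        mask, img = net(torch.randn(1, 3, 512, 512))
    print(mask.shape, img.shape)   # [1, 19, 512, 512] and [1, 3, 512, 512]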
--------------------------------------------------------------------------------
/codeformer/facelib/parsing/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, in_chan, out_chan, stride=1):
12 | super(BasicBlock, self).__init__()
13 | self.conv1 = conv3x3(in_chan, out_chan, stride)
14 | self.bn1 = nn.BatchNorm2d(out_chan)
15 | self.conv2 = conv3x3(out_chan, out_chan)
16 | self.bn2 = nn.BatchNorm2d(out_chan)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.downsample = None
19 | if in_chan != out_chan or stride != 1:
20 | self.downsample = nn.Sequential(
21 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
22 | nn.BatchNorm2d(out_chan),
23 | )
24 |
25 | def forward(self, x):
26 | residual = self.conv1(x)
27 | residual = F.relu(self.bn1(residual))
28 | residual = self.conv2(residual)
29 | residual = self.bn2(residual)
30 |
31 | shortcut = x
32 | if self.downsample is not None:
33 | shortcut = self.downsample(x)
34 |
35 | out = shortcut + residual
36 | out = self.relu(out)
37 | return out
38 |
39 |
40 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
41 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
42 | for i in range(bnum - 1):
43 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
44 | return nn.Sequential(*layers)
45 |
46 |
47 | class ResNet18(nn.Module):
48 | def __init__(self):
49 | super(ResNet18, self).__init__()
50 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
51 | self.bn1 = nn.BatchNorm2d(64)
52 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
53 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
54 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
55 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
56 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
57 |
58 | def forward(self, x):
59 | x = self.conv1(x)
60 | x = F.relu(self.bn1(x))
61 | x = self.maxpool(x)
62 |
63 | x = self.layer1(x)
64 | feat8 = self.layer2(x) # 1/8
65 | feat16 = self.layer3(feat8) # 1/16
66 | feat32 = self.layer4(feat16) # 1/32
67 | return feat8, feat16, feat32
68 |
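A minimal sketch (hypothetical, not from the repository) showing the three feature scales this backbone hands to ContextPath, for an arbitrary 512x512 input.

    import torch

    backbone = ResNet18().eval()
    with torch.no_grad():
        feat8, feat16, feat32 = backbone(torch.randn(1, 3, 512, 512))
    print(feat8.shape, feat16.shape, feat32.shape)
    # [1, 128, 64, 64], [1, 256, 32, 32], [1, 512, 16, 16]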
--------------------------------------------------------------------------------
/codeformer/facelib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
2 | from .misc import download_pretrained_models, img2tensor, load_file_from_url, scandir
3 |
4 | __all__ = [
5 | "align_crop_face_landmarks",
6 | "compute_increased_bbox",
7 | "get_valid_bboxes",
8 | "load_file_from_url",
9 | "download_pretrained_models",
10 | "paste_face_back",
11 | "img2tensor",
12 | "scandir",
13 | ]
14 |
--------------------------------------------------------------------------------
/codeformer/facelib/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from urllib.parse import urlparse
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from torch.hub import download_url_to_file, get_dir
10 |
11 | # from basicsr.utils.download_util import download_file_from_google_drive
12 |
13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14 |
15 |
16 | def download_pretrained_models(file_ids, save_path_root):
17 | import gdown
18 |
19 | os.makedirs(save_path_root, exist_ok=True)
20 |
21 | for file_name, file_id in file_ids.items():
22 | file_url = "https://drive.google.com/uc?id=" + file_id
23 | save_path = osp.abspath(osp.join(save_path_root, file_name))
24 | if osp.exists(save_path):
 25 |             user_response = input(f"{file_name} already exists. Do you want to overwrite it? Y/N\n")
 26 |             if user_response.lower() == "y":
 27 |                 print(f"Overwriting {file_name} at {save_path}")
28 | gdown.download(file_url, save_path, quiet=False)
29 | # download_file_from_google_drive(file_id, save_path)
30 | elif user_response.lower() == "n":
31 | print(f"Skipping {file_name}")
32 | else:
33 | raise ValueError("Wrong input. Only accepts Y/N.")
34 | else:
35 | print(f"Downloading {file_name} to {save_path}")
36 | gdown.download(file_url, save_path, quiet=False)
37 | # download_file_from_google_drive(file_id, save_path)
38 |
39 |
40 | def imwrite(img, file_path, params=None, auto_mkdir=True):
41 | """Write image to file.
42 |
43 | Args:
44 | img (ndarray): Image array to be written.
45 | file_path (str): Image file path.
46 | params (None or list): Same as opencv's :func:`imwrite` interface.
47 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
48 | whether to create it automatically.
49 |
50 | Returns:
51 | bool: Successful or not.
52 | """
53 | if auto_mkdir:
54 | dir_name = os.path.abspath(os.path.dirname(file_path))
55 | os.makedirs(dir_name, exist_ok=True)
56 | return cv2.imwrite(file_path, img, params)
57 |
58 |
59 | def img2tensor(imgs, bgr2rgb=True, float32=True):
60 | """Numpy array to tensor.
61 |
62 | Args:
63 | imgs (list[ndarray] | ndarray): Input images.
64 | bgr2rgb (bool): Whether to change bgr to rgb.
65 | float32 (bool): Whether to change to float32.
66 |
67 | Returns:
68 | list[tensor] | tensor: Tensor images. If returned results only have
69 | one element, just return tensor.
70 | """
71 |
72 | def _totensor(img, bgr2rgb, float32):
73 | if img.shape[2] == 3 and bgr2rgb:
74 | if img.dtype == "float64":
75 | img = img.astype("float32")
76 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
77 | img = torch.from_numpy(img.transpose(2, 0, 1))
78 | if float32:
79 | img = img.float()
80 | return img
81 |
82 | if isinstance(imgs, list):
83 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
84 | else:
85 | return _totensor(imgs, bgr2rgb, float32)
86 |
87 |
88 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
89 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py"""
90 | if model_dir is None:
91 | hub_dir = get_dir()
92 | model_dir = os.path.join(hub_dir, "checkpoints")
93 |
94 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
95 |
96 | parts = urlparse(url)
97 | filename = os.path.basename(parts.path)
98 | if file_name is not None:
99 | filename = file_name
100 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
101 | if not os.path.exists(cached_file):
102 | print(f'Downloading: "{url}" to {cached_file}\n')
103 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
104 | return cached_file
105 |
106 |
107 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
108 | """Scan a directory to find the interested files.
109 | Args:
110 | dir_path (str): Path of the directory.
111 | suffix (str | tuple(str), optional): File suffix that we are
112 | interested in. Default: None.
113 | recursive (bool, optional): If set to True, recursively scan the
114 | directory. Default: False.
115 | full_path (bool, optional): If set to True, include the dir_path.
116 | Default: False.
117 | Returns:
118 | A generator for all the interested files with relative paths.
119 | """
120 |
121 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
122 | raise TypeError('"suffix" must be a string or tuple of strings')
123 |
124 | root = dir_path
125 |
126 | def _scandir(dir_path, suffix, recursive):
127 | for entry in os.scandir(dir_path):
128 | if not entry.name.startswith(".") and entry.is_file():
129 | if full_path:
130 | return_path = entry.path
131 | else:
132 | return_path = osp.relpath(entry.path, root)
133 |
134 | if suffix is None:
135 | yield return_path
136 | elif return_path.endswith(suffix):
137 | yield return_path
138 | else:
139 | if recursive:
140 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
141 | else:
142 | continue
143 |
144 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
145 |
146 |
147 | def is_gray(img, threshold=10):
148 | img = Image.fromarray(img)
149 | if len(img.getbands()) == 1:
150 | return True
151 | img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
152 | img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
153 | img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
154 | diff1 = (img1 - img2).var()
155 | diff2 = (img2 - img3).var()
156 | diff3 = (img3 - img1).var()
157 | diff_sum = (diff1 + diff2 + diff3) / 3.0
158 | if diff_sum <= threshold:
159 | return True
160 | else:
161 | return False
162 |
163 |
164 | def rgb2gray(img, out_channel=3):
165 | r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
166 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
167 | if out_channel == 3:
168 | gray = gray[:, :, np.newaxis].repeat(3, axis=2)
169 | return gray
170 |
171 |
172 | def bgr2gray(img, out_channel=3):
173 | b, g, r = img[:, :, 0], img[:, :, 1], img[:, :, 2]
174 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
175 | if out_channel == 3:
176 | gray = gray[:, :, np.newaxis].repeat(3, axis=2)
177 | return gray
178 |
179 |
180 | def calc_mean_std(feat, eps=1e-5):
181 | """
182 | Args:
 183 |         feat (numpy): 3D array with shape [h, w, c].
184 | """
185 | size = feat.shape
186 | assert len(size) == 3, "The input feature should be 3D tensor."
187 | c = size[2]
188 | feat_var = feat.reshape(-1, c).var(axis=0) + eps
189 | feat_std = np.sqrt(feat_var).reshape(1, 1, c)
190 | feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
191 | return feat_mean, feat_std
192 |
193 |
194 | def adain_npy(content_feat, style_feat):
195 | """Adaptive instance normalization for numpy.
196 |
197 | Args:
198 | content_feat (numpy): The input feature.
199 | style_feat (numpy): The reference feature.
200 | """
201 | size = content_feat.shape
202 | style_mean, style_std = calc_mean_std(style_feat)
203 | content_mean, content_std = calc_mean_std(content_feat)
204 | normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
205 | return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
206 |
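A minimal sketch (hypothetical, not from the repository) exercising a few of the helpers above on synthetic data; the array sizes are arbitrary.

    import numpy as np

    bgr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    tensor = img2tensor(bgr.astype("float32") / 255.0, bgr2rgb=True)    # CHW float tensor
    print(tensor.shape)                                                 # torch.Size([3, 64, 64])
    print(is_gray(bgr))                                                 # False for random noise
    print(is_gray(np.stack([bgr[..., 0]] * 3, axis=-1)))                # True: identical channels
    gray3 = bgr2gray(bgr)                                               # HWC grayscale replicated to 3 channels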
--------------------------------------------------------------------------------
/codeformer/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kadirnar/codeformer-pip/374cd2be99baaf15770995b1d054456631867ea5/codeformer/scripts/__init__.py
--------------------------------------------------------------------------------
/codeformer/scripts/code_format.sh:
--------------------------------------------------------------------------------
1 | black . --config pyproject.toml
2 | isort .
--------------------------------------------------------------------------------
/codeformer/scripts/crop_align_face.py:
--------------------------------------------------------------------------------
1 | """
2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
3 | author: lzhbrian (https://lzhbrian.me)
4 | link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
5 | date: 2020.1.5
6 | note: code is heavily borrowed from
7 | https://github.com/NVlabs/ffhq-dataset
8 | http://dlib.net/face_landmark_detection.py.html
9 | requirements:
10 | conda install Pillow numpy scipy
11 | conda install -c conda-forge dlib
12 | # download face landmark model from:
13 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
14 | """
15 |
16 | import argparse
17 | import glob
18 | import os
19 | import sys
20 |
21 | import cv2
22 | import dlib
23 | import numpy as np
24 | import PIL
25 | import PIL.Image
26 | import scipy
27 | import scipy.ndimage
28 |
29 | # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
30 | predictor = dlib.shape_predictor("weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat")
31 |
32 |
33 | def get_landmark(filepath, only_keep_largest=True):
34 | """get landmark with dlib
35 | :return: np.array shape=(68, 2)
36 | """
37 | detector = dlib.get_frontal_face_detector()
38 |
39 | img = dlib.load_rgb_image(filepath)
40 | dets = detector(img, 1)
41 |
42 | # Shangchen modified
43 | print("Number of faces detected: {}".format(len(dets)))
44 | if only_keep_largest:
 45 |         print("Keeping only the largest detected face.")
46 | face_areas = []
47 | for k, d in enumerate(dets):
48 | face_area = (d.right() - d.left()) * (d.bottom() - d.top())
49 | face_areas.append(face_area)
50 |
51 | largest_idx = face_areas.index(max(face_areas))
52 | d = dets[largest_idx]
53 | shape = predictor(img, d)
54 | print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))
55 | else:
56 | for k, d in enumerate(dets):
57 | print(
58 | "Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
59 | k, d.left(), d.top(), d.right(), d.bottom()
60 | )
61 | )
62 | # Get the landmarks/parts for the face in box d.
63 | shape = predictor(img, d)
64 | print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))
65 |
66 | t = list(shape.parts())
67 | a = []
68 | for tt in t:
69 | a.append([tt.x, tt.y])
70 | lm = np.array(a)
71 | # lm is a shape=(68,2) np.array
72 | return lm
73 |
74 |
75 | def align_face(filepath, out_path):
76 | """
77 | :param filepath: str
 78 |     :return: (aligned PIL.Image, quad width) tuple, or None if no landmarks are found
79 | """
80 | try:
81 | lm = get_landmark(filepath)
 82 |     except Exception:
83 | print("No landmark ...")
84 | return
85 |
86 | lm_chin = lm[0:17] # left-right
87 | lm_eyebrow_left = lm[17:22] # left-right
88 | lm_eyebrow_right = lm[22:27] # left-right
89 | lm_nose = lm[27:31] # top-down
90 | lm_nostrils = lm[31:36] # top-down
91 | lm_eye_left = lm[36:42] # left-clockwise
92 | lm_eye_right = lm[42:48] # left-clockwise
93 | lm_mouth_outer = lm[48:60] # left-clockwise
94 | lm_mouth_inner = lm[60:68] # left-clockwise
95 |
96 | # Calculate auxiliary vectors.
97 | eye_left = np.mean(lm_eye_left, axis=0)
98 | eye_right = np.mean(lm_eye_right, axis=0)
99 | eye_avg = (eye_left + eye_right) * 0.5
100 | eye_to_eye = eye_right - eye_left
101 | mouth_left = lm_mouth_outer[0]
102 | mouth_right = lm_mouth_outer[6]
103 | mouth_avg = (mouth_left + mouth_right) * 0.5
104 | eye_to_mouth = mouth_avg - eye_avg
105 |
106 | # Choose oriented crop rectangle.
107 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
108 | x /= np.hypot(*x)
109 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
110 | y = np.flipud(x) * [-1, 1]
111 | c = eye_avg + eye_to_mouth * 0.1
112 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
113 | qsize = np.hypot(*x) * 2
114 |
115 | # read image
116 | img = PIL.Image.open(filepath)
117 |
118 | output_size = 512
119 | transform_size = 4096
120 | enable_padding = False
121 |
122 | # Shrink.
123 | shrink = int(np.floor(qsize / output_size * 0.5))
124 | if shrink > 1:
125 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
126 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
127 | quad /= shrink
128 | qsize /= shrink
129 |
130 | # Crop.
131 | border = max(int(np.rint(qsize * 0.1)), 3)
132 | crop = (
133 | int(np.floor(min(quad[:, 0]))),
134 | int(np.floor(min(quad[:, 1]))),
135 | int(np.ceil(max(quad[:, 0]))),
136 | int(np.ceil(max(quad[:, 1]))),
137 | )
138 | crop = (
139 | max(crop[0] - border, 0),
140 | max(crop[1] - border, 0),
141 | min(crop[2] + border, img.size[0]),
142 | min(crop[3] + border, img.size[1]),
143 | )
144 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
145 | img = img.crop(crop)
146 | quad -= crop[0:2]
147 |
148 | # Pad.
149 | pad = (
150 | int(np.floor(min(quad[:, 0]))),
151 | int(np.floor(min(quad[:, 1]))),
152 | int(np.ceil(max(quad[:, 0]))),
153 | int(np.ceil(max(quad[:, 1]))),
154 | )
155 | pad = (
156 | max(-pad[0] + border, 0),
157 | max(-pad[1] + border, 0),
158 | max(pad[2] - img.size[0] + border, 0),
159 | max(pad[3] - img.size[1] + border, 0),
160 | )
161 | if enable_padding and max(pad) > border - 4:
162 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
163 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect")
164 | h, w, _ = img.shape
165 | y, x, _ = np.ogrid[:h, :w, :1]
166 | mask = np.maximum(
167 | 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
168 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]),
169 | )
170 | blur = qsize * 0.02
171 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
172 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
173 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), "RGB")
174 | quad += pad[:2]
175 |
176 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
177 |
178 | if output_size < transform_size:
179 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
180 |
181 | # Save aligned image.
 182 |     print("saving: ", out_path)
183 | img.save(out_path)
184 |
185 | return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
186 |
187 |
188 | if __name__ == "__main__":
189 | parser = argparse.ArgumentParser()
190 | parser.add_argument("--in_dir", type=str, default="./inputs/whole_imgs")
191 | parser.add_argument("--out_dir", type=str, default="./inputs/cropped_faces")
192 | args = parser.parse_args()
193 |
 194 |     # collect input images (only .png files are picked up by this pattern)
 195 |     img_list = sorted(glob.glob(f"{args.in_dir}/*.png"))
196 |
197 | for in_path in img_list:
198 | out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
199 | out_path = out_path.replace(".jpg", ".png")
200 | size_ = align_face(in_path, out_path)
201 |
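A minimal usage sketch (hypothetical, not from the repository). Importing this module loads the dlib 68-point predictor from weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat, so that file must exist relative to the working directory, dlib must be installed, and the output directory must already exist; the paths below are made up.

    from codeformer.scripts.crop_align_face import align_face

    result = align_face("inputs/whole_imgs/portrait.png", "inputs/cropped_faces/portrait.png")
    if result is not None:
        aligned_img, quad_width = result   # the aligned image is also written to the output path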
--------------------------------------------------------------------------------
/codeformer/scripts/download_pretrained_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path as osp
4 |
 5 | from codeformer.basicsr.utils.download_util import load_file_from_url
6 |
7 |
8 | def download_pretrained_models(method, file_urls):
9 | save_path_root = f"./weights/{method}"
10 | os.makedirs(save_path_root, exist_ok=True)
11 |
12 | for file_name, file_url in file_urls.items():
13 | save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
14 |
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument(
20 | "method", type=str, help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models.")
21 | )
22 | args = parser.parse_args()
23 |
24 | file_urls = {
25 | "CodeFormer": {
26 | "codeformer.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
27 | },
28 | "facelib": {
29 | # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
30 | "detection_Resnet50_Final.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
31 | "parsing_parsenet.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
32 | },
33 | "dlib": {
34 | "mmod_human_face_detector-4cb19393.dat": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat",
35 | "shape_predictor_5_face_landmarks-c4b1e980.dat": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat",
36 | },
37 | }
38 |
39 | if args.method == "all":
40 | for method in file_urls.keys():
41 | download_pretrained_models(method, file_urls[method])
42 | else:
43 | download_pretrained_models(args.method, file_urls[args.method])
44 |
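A minimal usage sketch (hypothetical, not from the repository): fetching a single model group programmatically instead of via the command line, reusing one of the URLs defined above.

    from codeformer.scripts.download_pretrained_models import download_pretrained_models

    download_pretrained_models(
        "CodeFormer",
        {"codeformer.pth": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"},
    )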
--------------------------------------------------------------------------------
/codeformer/scripts/download_pretrained_models_from_gdrive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path as osp
4 |
5 | # from basicsr.utils.download_util import download_file_from_google_drive
6 | import gdown
7 |
8 |
9 | def download_pretrained_models(method, file_ids):
10 | save_path_root = f"./weights/{method}"
11 | os.makedirs(save_path_root, exist_ok=True)
12 |
13 | for file_name, file_id in file_ids.items():
14 | file_url = "https://drive.google.com/uc?id=" + file_id
15 | save_path = osp.abspath(osp.join(save_path_root, file_name))
16 | if osp.exists(save_path):
17 |             user_response = input(f"{file_name} already exists. Do you want to overwrite it? Y/N\n")
18 |             if user_response.lower() == "y":
19 |                 print(f"Overwriting {file_name} at {save_path}")
20 | gdown.download(file_url, save_path, quiet=False)
21 | # download_file_from_google_drive(file_id, save_path)
22 | elif user_response.lower() == "n":
23 | print(f"Skipping {file_name}")
24 | else:
25 | raise ValueError("Wrong input. Only accepts Y/N.")
26 | else:
27 | print(f"Downloading {file_name} to {save_path}")
28 | gdown.download(file_url, save_path, quiet=False)
29 | # download_file_from_google_drive(file_id, save_path)
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 |
35 | parser.add_argument(
36 | "method", type=str, help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models.")
37 | )
38 | args = parser.parse_args()
39 |
40 | # file name: file id
41 | # 'dlib': {
42 | # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
43 | # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
44 | # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
45 | # }
46 | file_ids = {
47 | "CodeFormer": {"codeformer.pth": "1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB"},
48 | "facelib": {
49 | "yolov5l-face.pth": "131578zMA6B2x8VQHyHfa6GEPtulMCNzV",
50 | "parsing_parsenet.pth": "16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK",
51 | },
52 | }
53 |
54 | if args.method == "all":
55 | for method in file_ids.keys():
56 | download_pretrained_models(method, file_ids[method])
57 | else:
58 | download_pretrained_models(args.method, file_ids[args.method])
59 |
--------------------------------------------------------------------------------
/codeformer/scripts/package.sh:
--------------------------------------------------------------------------------
1 | python setup.py sdist
2 | twine upload dist/*
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 |
4 | [tool.isort]
5 | line_length = 120
6 | profile = "black"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | addict
2 | future
3 | lmdb
4 | numpy
5 | opencv-python
6 | Pillow
7 | pyyaml
8 | requests
9 | scikit-image
10 | scipy
11 | tb-nightly
12 | torch>=1.7.1
13 | torchvision
14 | tqdm
15 | yapf
16 | lpips
17 | gdown # supports downloading large files from Google Drive
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | max-complexity = 18
4 | exclude = .git,__pycache__,docs/source/conf.py,build,dist
5 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,E501,E722,W503,B006
6 | inline-quotes = "
7 | statistics = true
8 | count = true
9 | [mypy]
10 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import re
4 |
5 | import setuptools
6 |
7 |
8 | def get_long_description():
9 | base_dir = os.path.abspath(os.path.dirname(__file__))
10 | with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
11 | return f.read()
12 |
13 |
14 | def get_requirements():
15 | with open("requirements.txt") as f:
16 | return f.read().splitlines()
17 |
18 |
19 | def get_version():
20 | current_dir = os.path.abspath(os.path.dirname(__file__))
21 | version_file = os.path.join(current_dir, "codeformer", "__init__.py")
22 | with io.open(version_file, encoding="utf-8") as f:
23 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)
24 |
25 |
26 | _DEV_REQUIREMENTS = [
27 | "black==21.7b0",
28 | "flake8==3.9.2",
29 | "isort==5.9.2",
30 | "click==8.0.4",
31 | "importlib-metadata>=1.1.0,<4.3;python_version<'3.8'",
32 | ]
33 |
34 |
35 | setuptools.setup(
36 | name="codeformer-pip",
37 | version=get_version(),
38 | author="kadirnar",
39 | license="S-Lab",
40 | description="PyTorch implementation of CodeFormer",
41 | long_description=get_long_description(),
42 | long_description_content_type="text/markdown",
43 | url="https://github.com/kadirnar/codeformer-pip",
44 | packages=setuptools.find_packages(exclude=["tests"]),
45 | include_package_data=True,
46 | python_requires=">=3.6",
47 | install_requires=get_requirements(),
48 | classifiers=[
49 | "Development Status :: 5 - Production/Stable",
50 | "Operating System :: OS Independent",
51 | "Intended Audience :: Developers",
52 | "Intended Audience :: Science/Research",
53 | "Programming Language :: Python :: 3",
54 | "Programming Language :: Python :: 3.7",
55 | "Programming Language :: Python :: 3.8",
56 | "Programming Language :: Python :: 3.9",
57 | "Topic :: Software Development :: Libraries",
58 | "Topic :: Software Development :: Libraries :: Python Modules",
59 | "Topic :: Education",
60 | "Topic :: Scientific/Engineering",
61 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
62 | ],
63 |     keywords="pytorch, codeformer, face, face-restoration, face-restoration-pytorch, blind-face-restoration",
64 | )
65 |
--------------------------------------------------------------------------------