├── basicsr ├── ops │ ├── __init__.py │ ├── upfirdn2d │ │ ├── __init__.py │ │ └── src │ │ │ └── upfirdn2d.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ │ └── fused_act.py │ └── dcn │ │ └── __init__.py ├── metrics │ ├── niqe_pris_params.npz │ ├── __init__.py │ ├── metric_util.py │ ├── README.md │ ├── README_CN.md │ ├── test_metrics │ │ └── test_psnr_ssim.py │ └── fid.py ├── data │ ├── meta_info │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ └── meta_info_REDSval_official_test_GT.txt │ ├── data_sampler.py │ ├── single_image_dataset.py │ ├── ffhq_dataset.py │ ├── prefetch_dataloader.py │ ├── __init__.py │ ├── paired_image_dataset.py │ └── realesrgan_paired_dataset.py ├── __init__.py ├── models │ ├── video_gan_model.py │ ├── __init__.py │ ├── swinir_model.py │ ├── edvr_model.py │ ├── esrgan_model.py │ └── lr_scheduler.py ├── archs │ ├── __init__.py │ ├── edsr_arch.py │ ├── srresnet_arch.py │ ├── srvgg_arch.py │ ├── spynet_arch.py │ ├── rrdbnet_arch.py │ └── rcan_arch.py ├── losses │ └── __init__.py ├── utils │ ├── __init__.py │ ├── registry.py │ ├── plot_util.py │ ├── img_process_util.py │ ├── dist_util.py │ ├── download_util.py │ └── misc.py └── test.py ├── facelib ├── detection │ ├── yolov5face │ │ ├── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── yolov5n.yaml │ │ │ ├── yolov5l.yaml │ │ │ └── experimental.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── extract_ckpt.py │ │ │ ├── autoanchor.py │ │ │ ├── torch_utils.py │ │ │ └── datasets.py │ └── __init__.py ├── utils │ └── __init__.py └── parsing │ ├── __init__.py │ └── resnet.py ├── pic ├── 001.jpg ├── 002.jpg ├── 003.jpg ├── 004.jpg └── 005.jpg ├── assets ├── 0368.png ├── 0729.png ├── 0885.png ├── 0934.png ├── .DS_Store ├── Hepburn.png ├── oldimg_05.png ├── DifFace_Framework.png └── Solvay_conference.png ├── configs ├── .DS_Store ├── sample │ ├── iddpm_ffhq512.yaml │ └── iddpm_ffhq512_swinir.yaml └── training │ ├── diffusion_ffhq512.yaml │ └── swinir_ffhq512.yaml ├── datapipe ├── .DS_Store ├── prepare │ ├── .DS_Store │ └── face │ │ ├── make_testing_data_hq.py │ │ ├── make_lfw.py │ │ ├── big2small_face.py │ │ ├── split_train_val.py │ │ ├── make_testing_data_bicubic.py │ │ ├── degradation_split.py │ │ └── make_testing_data.py ├── __init__.py ├── degradation_bsrgan │ ├── test.png │ ├── __init__.py │ ├── README.md │ ├── utils_logger.py │ └── utils_googledownload.py └── face_degradation_testing.py ├── testdata ├── whole_imgs │ ├── 00.jpg │ ├── 01.jpg │ ├── 02.png │ ├── 03.png │ ├── 04.jpg │ ├── 05.jpg │ └── .DS_Store ├── cropped_faces │ ├── 0143.png │ ├── 0240.png │ ├── 0342.png │ ├── 0345.png │ ├── 0368.png │ ├── 0412.png │ ├── 0444.png │ ├── 0478.png │ ├── 0500.png │ ├── 0599.png │ ├── 0717.png │ ├── 0720.png │ ├── 0729.png │ ├── 0763.png │ ├── 0770.png │ ├── 0777.png │ ├── 0885.png │ ├── 0934.png │ ├── Solvay_conference_1927_0018.png │ └── Solvay_conference_1927_2_16.png └── Solvay_conference_1927.png ├── models ├── __pycache__ │ ├── unet.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── basic_ops.cpython-38.pyc │ ├── fp16_util.cpython-38.pyc │ ├── resample.cpython-38.pyc │ ├── respace.cpython-38.pyc │ ├── solvers.cpython-38.pyc │ ├── respace_ori.cpython-38.pyc │ ├── script_util.cpython-38.pyc │ ├── script_util_ori.cpython-38.pyc │ ├── gaussian_diffusion.cpython-38.pyc │ └── gaussian_diffusion_ori.cpython-38.pyc ├── script_util.py ├── fp16_util.py ├── losses.py ├── srcnn.py ├── basic_ops.py └── respace.py ├── utils ├── 
__init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── util_image.cpython-36.pyc │ ├── util_image.cpython-37.pyc │ ├── util_image.cpython-38.pyc │ ├── util_image.cpython-39.pyc │ ├── util_net.cpython-36.pyc │ ├── util_net.cpython-38.pyc │ ├── util_opts.cpython-36.pyc │ ├── util_opts.cpython-37.pyc │ ├── util_opts.cpython-38.pyc │ ├── util_sisr.cpython-38.pyc │ ├── util_sisr.cpython-39.pyc │ ├── util_common.cpython-36.pyc │ ├── util_common.cpython-37.pyc │ ├── util_common.cpython-38.pyc │ └── util_denoising.cpython-38.pyc ├── util_opts.py ├── util_common.py └── util_net.py ├── ResizeRight ├── __pycache__ │ ├── interp_methods.cpython-38.pyc │ └── resize_right.cpython-38.pyc ├── LICENSE └── interp_methods.py ├── main_sr.py ├── main_diffusion.py ├── LICENSE ├── environment.yaml └── README.md /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pic/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/pic/001.jpg -------------------------------------------------------------------------------- /pic/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/pic/002.jpg -------------------------------------------------------------------------------- /pic/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/pic/003.jpg -------------------------------------------------------------------------------- /pic/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/pic/004.jpg -------------------------------------------------------------------------------- /pic/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/pic/005.jpg -------------------------------------------------------------------------------- /assets/0368.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/0368.png -------------------------------------------------------------------------------- /assets/0729.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/0729.png -------------------------------------------------------------------------------- /assets/0885.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cedro3/DifFace/master/assets/0885.png -------------------------------------------------------------------------------- /assets/0934.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/0934.png -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/.DS_Store -------------------------------------------------------------------------------- /assets/Hepburn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/Hepburn.png -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/configs/.DS_Store -------------------------------------------------------------------------------- /datapipe/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/datapipe/.DS_Store -------------------------------------------------------------------------------- /assets/oldimg_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/oldimg_05.png -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /assets/DifFace_Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/DifFace_Framework.png -------------------------------------------------------------------------------- /assets/Solvay_conference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/assets/Solvay_conference.png -------------------------------------------------------------------------------- /datapipe/prepare/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/datapipe/prepare/.DS_Store -------------------------------------------------------------------------------- /testdata/whole_imgs/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/00.jpg -------------------------------------------------------------------------------- /testdata/whole_imgs/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/01.jpg -------------------------------------------------------------------------------- /testdata/whole_imgs/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/02.png 
-------------------------------------------------------------------------------- /testdata/whole_imgs/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/03.png -------------------------------------------------------------------------------- /testdata/whole_imgs/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/04.jpg -------------------------------------------------------------------------------- /testdata/whole_imgs/05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/05.jpg -------------------------------------------------------------------------------- /testdata/whole_imgs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/whole_imgs/.DS_Store -------------------------------------------------------------------------------- /testdata/cropped_faces/0143.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0143.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0240.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0342.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0342.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0345.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0345.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0368.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0368.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0412.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0412.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0444.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0444.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0478.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0478.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0500.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0500.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0599.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0599.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0717.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0717.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0720.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0720.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0729.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0729.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0763.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0763.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0770.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0770.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0777.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0777.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0885.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0885.png -------------------------------------------------------------------------------- /testdata/cropped_faces/0934.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/0934.png -------------------------------------------------------------------------------- /testdata/Solvay_conference_1927.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/Solvay_conference_1927.png -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /datapipe/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # 
Power by Zongsheng Yue 2022-06-07 17:27:22 4 | 5 | -------------------------------------------------------------------------------- /datapipe/degradation_bsrgan/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/datapipe/degradation_bsrgan/test.png -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-01-18 11:40:23 4 | 5 | 6 | -------------------------------------------------------------------------------- /models/__pycache__/basic_ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/basic_ops.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/fp16_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/fp16_util.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/resample.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/respace.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/respace.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/solvers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/solvers.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_image.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_image.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_image.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_image.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_net.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_net.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_opts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_opts.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_opts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_opts.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_opts.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_opts.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_sisr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_sisr.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_sisr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_sisr.cpython-39.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/respace_ori.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/respace_ori.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/script_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/script_util.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_common.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_common.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util_denoising.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/utils/__pycache__/util_denoising.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /datapipe/degradation_bsrgan/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-10-17 15:48:20 4 | 5 | 6 | -------------------------------------------------------------------------------- /models/__pycache__/script_util_ori.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/script_util_ori.cpython-38.pyc -------------------------------------------------------------------------------- /ResizeRight/__pycache__/interp_methods.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/ResizeRight/__pycache__/interp_methods.cpython-38.pyc -------------------------------------------------------------------------------- 
/ResizeRight/__pycache__/resize_right.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/ResizeRight/__pycache__/resize_right.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /models/__pycache__/gaussian_diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/gaussian_diffusion.cpython-38.pyc -------------------------------------------------------------------------------- /testdata/cropped_faces/Solvay_conference_1927_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/Solvay_conference_1927_0018.png -------------------------------------------------------------------------------- /testdata/cropped_faces/Solvay_conference_1927_2_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/testdata/cropped_faces/Solvay_conference_1927_2_16.png -------------------------------------------------------------------------------- /models/__pycache__/gaussian_diffusion_ori.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/DifFace/master/models/__pycache__/gaussian_diffusion_ori.cpython-38.pyc -------------------------------------------------------------------------------- /datapipe/degradation_bsrgan/README.md: -------------------------------------------------------------------------------- 1 | 2 | # How to use the degradation model: 3 | ```python 4 | from utils import utils_blindsr as blindsr 5 | img_lq, img_hq = blindsr.degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64) 6 | ``` 7 | 8 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/extract_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.insert(0,'./facelib/detection/yolov5face') 4 | model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] 5 | torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from 
.archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | # from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /facelib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir 3 | 4 | __all__ = [ 5 | 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 6 | 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' 7 | ] 8 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | 4 | def check_anchor_order(m): 5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 7 | da = a[-1] - a[0] # delta a 8 | ds = m.stride[-1] - m.stride[0] # delta s 9 | if da.sign() != ds.sign(): # same order 10 | print("Reversing anchor order") 11 | m.anchors[:] = m.anchors.flip(0) 12 | m.anchor_grid[:] = m.anchor_grid.flip(0) 13 | -------------------------------------------------------------------------------- /basicsr/models/video_gan_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import MODEL_REGISTRY 2 | from .srgan_model import SRGANModel 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class VideoGANModel(SRGANModel, VideoBaseModel): 8 | """Video GAN model. 9 | 10 | Use multiple inheritance. 11 | It will first use the functions of :class:`SRGANModel`: 12 | 13 | - :func:`init_training_settings` 14 | - :func:`setup_optimizers` 15 | - :func:`optimize_parameters` 16 | - :func:`save` 17 | 18 | Then find functions in :class:`VideoBaseModel`. 19 | """ 20 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Model type. 
16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | -------------------------------------------------------------------------------- /utils/util_opts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2021-11-24 15:07:43 4 | 5 | def update_args(args_json, args_parser): 6 | for arg in vars(args_parser): 7 | args_json[arg] = getattr(args_parser, arg) 8 | 9 | def str2bool(v): 10 | """ 11 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 12 | """ 13 | if isinstance(v, bool): 14 | return v 15 | if v.lower() in ("yes", "true", "t", "y", "1"): 16 | return True 17 | elif v.lower() in ("no", "false", "f", "n", "0"): 18 | return False 19 | else: 20 | raise argparse.ArgumentTypeError("boolean value expected") 21 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /configs/sample/iddpm_ffhq512.yaml: -------------------------------------------------------------------------------- 1 | gpu_id: "" 2 | seed: 10000 3 | display: True 4 | im_size: 512 5 | 6 | diffusion: 7 | target: models.script_util.create_gaussian_diffusion 8 | params: 9 | steps: 1000 10 | learn_sigma: True 11 | sigma_small: False 12 | 
noise_schedule: linear 13 | use_kl: False 14 | predict_xstart: False 15 | rescale_timesteps: False 16 | rescale_learned_sigmas: True 17 | timestep_respacing: "1000" 18 | 19 | model: 20 | target: models.unet.UNetModel 21 | ckpt_path: pretrained_zoo/iddpm_ffhq512/ema0999_model_500000.pth 22 | params: 23 | image_size: 512 24 | in_channels: 3 25 | model_channels: 32 26 | out_channels: 6 27 | attention_resolutions: [32, 16, 8] 28 | dropout: 0 29 | channel_mult: [1, 2, 4, 8, 8, 16, 16] 30 | num_res_blocks: [1, 2, 2, 2, 2, 3, 4] 31 | conv_resample: True 32 | dims: 2 33 | use_fp16: False 34 | num_head_channels: 64 35 | use_scale_shift_norm: True 36 | resblock_updown: False 37 | use_new_attention_order: False 38 | 39 | model_ir: ~ 40 | -------------------------------------------------------------------------------- /facelib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facelib.utils import load_file_from_url 4 | from .bisenet import BiSeNet 5 | from .parsenet import ParseNet 6 | 7 | 8 | def init_parsing_model(model_name='bisenet', half=False, device='cuda'): 9 | if model_name == 'bisenet': 10 | model = BiSeNet(num_class=19) 11 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' 12 | elif model_name == 'parsenet': 13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 14 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' 15 | else: 16 | raise NotImplementedError(f'{model_name} is not implemented.') 17 | 18 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 19 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 20 | model.load_state_dict(load_net, strict=True) 21 | model.eval() 22 | model = model.to(device) 23 | return model 24 | -------------------------------------------------------------------------------- /datapipe/prepare/face/make_testing_data_hq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-16 12:42:42 4 | 5 | import sys 6 | from pathlib import Path 7 | sys.path.append(str(Path(__file__).resolve().parents[3])) 8 | 9 | import os 10 | import argparse 11 | from utils import util_common 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--files_txt", type=str, default='', help="File names") 15 | parser.add_argument("--num_images", type=int, default=3000, help="Number of trainging iamges") 16 | parser.add_argument("--save_dir", type=str, default='', help="Folder to save the fake iamges") 17 | args = parser.parse_args() 18 | 19 | files_path = util_common.readline_txt(args.files_txt) 20 | print(f'Number of images in txt file: {len(files_path)}') 21 | 22 | assert len(files_path) >= args.num_images 23 | files_path = files_path[:args.num_images] 24 | 25 | if not Path(args.save_dir).exists(): 26 | Path(args.save_dir).mkdir(parents=False) 27 | 28 | for path in files_path: 29 | commond = f'cp {path} {args.save_dir}' 30 | os.system(commond) 31 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from 
basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /ResizeRight/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Assaf Shocher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/models/swinir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class SwinIRModel(SRModel): 10 | 11 | def test(self): 12 | # pad to multiplication of window_size 13 | window_size = self.opt['network_g']['window_size'] 14 | scale = self.opt.get('scale', 1) 15 | mod_pad_h, mod_pad_w = 0, 0 16 | _, _, h, w = self.lq.size() 17 | if h % window_size != 0: 18 | mod_pad_h = window_size - h % window_size 19 | if w % window_size != 0: 20 | mod_pad_w = window_size - w % window_size 21 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 22 | if hasattr(self, 'net_g_ema'): 23 | self.net_g_ema.eval() 24 | with torch.no_grad(): 25 | self.output = self.net_g_ema(img) 26 | else: 27 | self.net_g.eval() 28 | with torch.no_grad(): 29 | self.output = self.net_g(img) 30 | self.net_g.train() 31 | 
32 | _, _, h, w = self.output.size() 33 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 34 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /datapipe/prepare/face/make_lfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-19 12:32:34 4 | 5 | import os 6 | import argparse 7 | from pathlib import Path 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--save_dir", type=str, default="./testdata/LFW-Test", 11 | help="Folder to save the LR images") 12 | parser.add_argument("--data_dir", type=str, default="./testdata/lfw", 13 | help="LFW Testing dataset") 14 | parser.add_argument("--txt_file", type=str, default="./testdata/peopleDevTest.txt", 15 | help="LFW Testing data file paths") 16 | args = parser.parse_args() 17 | 18 | with open(args.txt_file, 'r') as ff: 19 | file_dirs = [x.split('\t')[0] for x in ff.readlines()][1:] 20 | 21 | if not Path(args.save_dir).exists(): 22 | Path(args.save_dir).mkdir(parents=True) 23 | 24 | for current_dir in file_dirs: 25 | current_dir = Path(args.data_dir) / current_dir 26 | file_path = sorted([str(x) for x in current_dir.glob('*.jpg')])[0] 27 | commond = f'cp {file_path} {args.save_dir}' 28 | os.system(commond) 29 | 30 | num_images = len([x for x in Path(args.save_dir).glob('*.jpg')]) 31 | print(f'Number of images: {num_images}') 32 | 33 | -------------------------------------------------------------------------------- /main_sr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-05-19 15:15:17 4 | 5 | import argparse 6 | from trainer import TrainerSR as Trainer 7 | from omegaconf import OmegaConf 8 | 9 | def get_parser(**parser_kwargs): 10 | parser = argparse.ArgumentParser(**parser_kwargs) 11 | parser.add_argument("--save_dir", type=str, default="./save_dir", 12 | help="Folder to save the checkpoints and training log") 13 | 
parser.add_argument("--resume", type=str, const=True, default="", nargs="?", 14 | help="resume from the save_dir or checkpoint") 15 | parser.add_argument("--cfg_path", type=str, default="./configs/inpainting_debug.yaml", 16 | help="Configs of yaml file") 17 | parser.add_argument("--gpu_id", type=str, default='', help="GPU Index, e.g., 025") 18 | parser.add_argument("--seed", type=int, default=10000, help="Random seed") 19 | args = parser.parse_args() 20 | 21 | return args 22 | 23 | if __name__ == "__main__": 24 | args = get_parser() 25 | 26 | configs = OmegaConf.load(args.cfg_path) 27 | 28 | # merge args to config 29 | for key in vars(args): 30 | configs[key] = getattr(args, key) 31 | 32 | trainer = Trainer(configs) 33 | trainer.train() 34 | -------------------------------------------------------------------------------- /main_diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-05-19 15:15:17 4 | 5 | import argparse 6 | from trainer import TrainerDiffusionFace as Trainer 7 | from omegaconf import OmegaConf 8 | 9 | def get_parser(**parser_kwargs): 10 | parser = argparse.ArgumentParser(**parser_kwargs) 11 | parser.add_argument("--save_dir", type=str, default="./save_dir", 12 | help="Folder to save the checkpoints and training log") 13 | parser.add_argument("--resume", type=str, const=True, default="", nargs="?", 14 | help="resume from the save_dir or checkpoint") 15 | parser.add_argument("--cfg_path", type=str, default="./configs/inpainting_debug.yaml", 16 | help="Configs of yaml file") 17 | parser.add_argument("--gpu_id", type=str, default='', help="GPU Index, e.g., 025") 18 | parser.add_argument("--seed", type=int, default=10000, help="Random seed") 19 | args = parser.parse_args() 20 | 21 | return args 22 | 23 | if __name__ == "__main__": 24 | args = get_parser() 25 | 26 | configs = OmegaConf.load(args.cfg_path) 27 | 28 | # merge args to config 29 | for key in vars(args): 30 | configs[key] = getattr(args, key) 31 | 32 | trainer = Trainer(configs) 33 | trainer.train() 34 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 
43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /configs/training/diffusion_ffhq512.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: models.unet.UNetModel 3 | params: 4 | image_size: 512 5 | in_channels: 3 6 | model_channels: 32 7 | out_channels: 6 8 | attention_resolutions: [32, 16, 8] 9 | dropout: 0 10 | channel_mult: [1, 2, 4, 8, 8, 16, 16] 11 | num_res_blocks: [1, 2, 2, 2, 2, 3, 4] 12 | conv_resample: True 13 | dims: 2 14 | use_fp16: False 15 | num_head_channels: 64 16 | use_scale_shift_norm: True 17 | resblock_updown: False 18 | use_new_attention_order: False 19 | 20 | diffusion: 21 | target: models.script_util.create_gaussian_diffusion 22 | params: 23 | steps: 1000 24 | learn_sigma: True 25 | sigma_small: False 26 | noise_schedule: linear 27 | use_kl: False 28 | predict_xstart: False 29 | rescale_timesteps: False 30 | rescale_learned_sigmas: True 31 | timestep_respacing: "" 32 | 33 | train: 34 | lr: 1e-4 35 | batch: [32, 4] # batchsize for training and validation 36 | microbatch: 8 37 | use_fp16: False 38 | num_workers: 16 39 | prefetch_factor: 2 40 | iterations: 800000 41 | weight_decay: 0 42 | scheduler: step # step or cosin 43 | milestones: [10000, 800000] 44 | ema_rates: [0.999] 45 | save_freq: 10000 46 | val_freq: 5000 47 | log_freq: [1000, 2000] 48 | 49 | data: 50 | train: 51 | type: face 52 | params: 53 | ffhq_txt: ./datapipe/files_txt/ffhq512.txt 54 | out_size: 512 55 | transform_type: face 56 | -------------------------------------------------------------------------------- /models/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . 
import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | 7 | def create_gaussian_diffusion( 8 | *, 9 | steps=1000, 10 | learn_sigma=False, 11 | sigma_small=False, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | predict_xstart=False, 15 | rescale_timesteps=False, 16 | rescale_learned_sigmas=False, 17 | timestep_respacing="", 18 | ): 19 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 20 | if use_kl: 21 | loss_type = gd.LossType.RESCALED_KL 22 | elif rescale_learned_sigmas: 23 | loss_type = gd.LossType.RESCALED_MSE 24 | else: 25 | loss_type = gd.LossType.MSE 26 | if not timestep_respacing: 27 | timestep_respacing = [steps] 28 | return SpacedDiffusion( 29 | use_timesteps=space_timesteps(steps, timestep_respacing), 30 | betas=betas, 31 | model_mean_type=( 32 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 33 | ), 34 | model_var_type=( 35 | ( 36 | gd.ModelVarType.FIXED_LARGE 37 | if not sigma_small 38 | else gd.ModelVarType.FIXED_SMALL 39 | ) 40 | if not learn_sigma 41 | else gd.ModelVarType.LEARNED_RANGE 42 | ), 43 | loss_type=loss_type, 44 | rescale_timesteps=rescale_timesteps, 45 | ) 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5n.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2 18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4 20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6 22 | ] 23 | 24 | # YOLOv5 head 25 | head: 26 | [[-1, 1, Conv, [128, 1, 1]], 27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4 29 | [-1, 1, C3, [128, False]], # 10 30 | 31 | [-1, 1, Conv, [128, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3 34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small) 35 | 36 | [-1, 1, Conv, [128, 3, 2]], 37 | [[-1, 11], 1, Concat, [1]], # cat head P4 38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium) 39 | 40 | [-1, 1, Conv, [128, 3, 2]], 41 | [[-1, 7], 1, Concat, [1]], # cat head P5 42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large) 43 | 44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 45 | ] 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/yolov5l.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 16 | [-1, 3, C3, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 18 | [-1, 
9, C3, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 20 | [-1, 9, C3, [512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 22 | [-1, 1, SPP, [1024, [3,5,7]]], 23 | [-1, 3, C3, [1024, False]], # 8 24 | ] 25 | 26 | # YOLOv5 head 27 | head: 28 | [[-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4 31 | [-1, 3, C3, [512, False]], # 12 32 | 33 | [-1, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3 36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small) 37 | 38 | [-1, 1, Conv, [256, 3, 2]], 39 | [[-1, 13], 1, Concat, [1]], # cat head P4 40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium) 41 | 42 | [-1, 1, Conv, [512, 3, 2]], 43 | [[-1, 9], 1, Concat, [1]], # cat head P5 44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large) 45 | 46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 47 | ] -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fuse_conv_and_bn(conv, bn): 6 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 7 | fusedconv = ( 8 | nn.Conv2d( 9 | conv.in_channels, 10 | conv.out_channels, 11 | kernel_size=conv.kernel_size, 12 | stride=conv.stride, 13 | padding=conv.padding, 14 | groups=conv.groups, 15 | bias=True, 16 | ) 17 | .requires_grad_(False) 18 | .to(conv.weight.device) 19 | ) 20 | 21 | # prepare filters 22 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 25 | 26 | # prepare spatial bias 27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 30 | 31 | return fusedconv 32 | 33 | 34 | def copy_attr(a, b, include=(), exclude=()): 35 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 
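    # An illustrative (hypothetical) call for this helper: copy_attr(ema_model, model,
    # include=("nc", "names", "stride")) would copy only those public attributes from
    # `model` onto `ema_model`, skipping private ("_"-prefixed) and excluded names.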
36 | for k, v in b.__dict__.items(): 37 | if (include and k not in include) or k.startswith("_") or k in exclude: 38 | continue 39 | 40 | setattr(a, k, v) 41 | -------------------------------------------------------------------------------- /datapipe/face_degradation_testing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-08-19 10:00:39 4 | 5 | import cv2 6 | import numpy as np 7 | from basicsr.data import degradations as degradations 8 | 9 | from utils import util_common 10 | 11 | def face_degradation(im, sf, sig_x, sig_y, theta, nf, qf): 12 | ''' 13 | Face degradation on testing data 14 | Input: 15 | im: numpy array, h x w x c, [0, 1], bgr 16 | sf: scale factor for super-resolution 17 | sig_x, sig_y, theta: parameters for generating gaussian kernel 18 | nf: noise level 19 | qf: quality factor for jpeg compression 20 | Output: 21 | im_lq: numpy array, h x w x c, [0, 1], bgr 22 | ''' 23 | h, w = im.shape[:2] 24 | 25 | # blur 26 | kernel = degradations.bivariate_Gaussian( 27 | kernel_size=41, 28 | sig_x=sig_x, 29 | sig_y=sig_y, 30 | theta=theta, 31 | isotropic=False, 32 | ) 33 | im_lq = cv2.filter2D(im, -1, kernel) 34 | 35 | # downsample 36 | im_lq = cv2.resize(im_lq, (int(w // sf), int(h // sf)), interpolation=cv2.INTER_LINEAR) 37 | 38 | # noise 39 | im_lq = degradations.add_gaussian_noise(im_lq, sigma=nf, clip=True, rounds=False) 40 | 41 | # jpeg compression 42 | im_lq = degradations.add_jpg_compression(im_lq, quality=qf) 43 | 44 | 45 | # resize to original size 46 | im_lq = cv2.resize(im_lq, (w, h), interpolation=cv2.INTER_LINEAR) 47 | 48 | # round and clip 49 | im_lq = np.clip((im_lq * 255.0).round(), 0, 255) / 255. 50 | 51 | return im_lq 52 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): 6 | # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 7 | shape = img.shape[:2] # current shape [height, width] 8 | if isinstance(new_shape, int): 9 | new_shape = (new_shape, new_shape) 10 | 11 | # Scale ratio (new / old) 12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 13 | if not scaleup: # only scale down, do not scale up (for better test mAP) 14 | r = min(r, 1.0) 15 | 16 | # Compute padding 17 | ratio = r, r # width, height ratios 18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 20 | if auto: # minimum rectangle 21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding 22 | elif scale_fill: # stretch 23 | dw, dh = 0.0, 0.0 24 | new_unpad = (new_shape[1], new_shape[0]) 25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 26 | 27 | dw /= 2 # divide padding into 2 sides 28 | dh /= 2 29 | 30 | if shape[::-1] != new_unpad: # resize 31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 35 | return img, ratio, (dw, dh) 36 | 
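# Illustrative usage sketch for letterbox() above (the input path "face.jpg" is a
# hypothetical placeholder): the call resizes while keeping the aspect ratio and pads
# the result to a stride-multiple rectangle, as expected before face detection.
if __name__ == "__main__":
    im = cv2.imread("face.jpg")  # BGR image of shape (h, w, 3)
    if im is not None:
        im_lb, ratio, (dw, dh) = letterbox(im, new_shape=640, auto=True)
        print(im_lb.shape, ratio, (dw, dh))  # padded image, scale ratio, per-side padding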
-------------------------------------------------------------------------------- /basicsr/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [Conventions](#conventions) 6 | - [PSNR and SSIM](#psnr-and-ssim) 7 | 8 | ## Conventions 9 | 10 | Different input types lead to different results, so we adopt the following conventions for the inputs: 11 | 12 | - Numpy type (usually the result of cv2) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). Usually used as intermediate results 15 | - Tensor type 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | Other conventions: 19 | 20 | - Results ending with `_pt` are PyTorch results 21 | - The PyTorch version supports batch computation 22 | - Color conversion is done in float32; metric calculation is done in float64 23 | 24 | ## PSNR and SSIM 25 | 26 | PSNR and SSIM usually follow the same trend: images with a higher PSNR generally also have a higher SSIM. 27 | The various PSNR implementations are consistent with each other. SSIM has many different implementations; ours is kept consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate)) 28 | 29 | The comparisons among the implementations are listed below. 30 | Summary: the PyTorch implementation matches the MATLAB one, with only slight differences when running on GPU 31 | 32 | - PSNR comparison 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM comparison 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/README_CN.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 
30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /configs/training/swinir_ffhq512.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: models.swinir.SwinIR 3 | params: 4 | img_size: 64 5 | patch_size: 1 6 | in_chans: 3 7 | embed_dim: 180 8 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 9 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 10 | window_size: 8 11 | mlp_ratio: 2 12 | sf: 8 13 | img_range: 1.0 14 | upsampler: "nearest+conv" 15 | resi_connection: "1conv" 16 | unshuffle: True 17 | unshuffle_scale: 8 18 | 19 | train: 20 | lr: 1e-4 21 | lr_min: 5e-6 22 | batch: [16, 4] # batchsize for training and validation 23 | microbatch: 4 24 | num_workers: 8 25 | prefetch_factor: 2 26 | iterations: 800000 27 | weight_decay: 0 28 | save_freq: 20000 29 | val_freq: 20000 30 | log_freq: [100, 2000, 100] 31 | 32 | data: 33 | train: 34 | type: gfpgan 35 | params: 36 | files_txt: ./datapipe/files_txt/ffhq512.txt 37 | io_backend: 38 | type: disk 39 | 40 | use_hflip: true 41 | mean: [0.0, 0.0, 0.0] 42 | std: [1.0, 1.0, 1.0] 43 | out_size: 512 44 | 45 | blur_kernel_size: 41 46 | kernel_list: ['iso', 'aniso'] 47 | kernel_prob: [0.5, 0.5] 48 | blur_sigma: [0.1, 15] 49 | downsample_range: [0.8, 32] 50 | noise_range: [0, 20] 51 | jpeg_range: [30, 100] 52 | 53 | color_jitter_prob: ~ 54 | color_jitter_pt_prob: ~ 55 | gray_prob: 0.01 56 | gt_gray: True 57 | 58 | need_gt_path: False 59 | val: 60 | type: folder 61 | params: 62 | dir_path: /mnt/lustre/zsyue/disk/IRDiff/Face/testing_data/syn_iclr_celeba512/lq 63 | dir_path_gt: /mnt/lustre/zsyue/disk/IRDiff/Face/testing_data/syn_iclr_celeba512/hq 64 | ext: png 65 | need_gt_path: False 66 | length: ~ 67 | mean: 0.0 68 | std: 1.0 69 | 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. 
Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /configs/sample/iddpm_ffhq512_swinir.yaml: -------------------------------------------------------------------------------- 1 | gpu_id: "" 2 | seed: 10000 3 | display: True 4 | im_size: 512 5 | aligned: True 6 | 7 | diffusion: 8 | target: models.script_util.create_gaussian_diffusion 9 | params: 10 | steps: 1000 11 | learn_sigma: True 12 | sigma_small: False 13 | noise_schedule: linear 14 | use_kl: False 15 | predict_xstart: False 16 | rescale_timesteps: False 17 | rescale_learned_sigmas: True 18 | timestep_respacing: "250" 19 | 20 | model: 21 | target: models.unet.UNetModel 22 | ckpt_path: ./weights/diffusion/iddpm_ffhq512_ema500000.pth 23 | params: 24 | image_size: 512 25 | in_channels: 3 26 | model_channels: 32 27 | out_channels: 6 28 | attention_resolutions: [32, 16, 8] 29 | dropout: 0 30 | channel_mult: [1, 2, 4, 8, 8, 16, 16] 31 | num_res_blocks: [1, 2, 2, 2, 2, 3, 4] 32 | conv_resample: True 33 | dims: 2 34 | use_fp16: False 35 | num_head_channels: 64 36 | use_scale_shift_norm: True 37 | resblock_updown: False 38 | use_new_attention_order: False 39 | 40 | model_ir: 41 | target: models.swinir.SwinIR 42 | ckpt_path: ./weights/SwinIR/General_Face_ffhq512.pth 43 | params: 44 | img_size: 64 45 | patch_size: 1 46 | in_chans: 3 47 | embed_dim: 180 48 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 49 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 50 | window_size: 8 51 | mlp_ratio: 2 52 | sf: 8 53 | img_range: 1.0 54 | upsampler: "nearest+conv" 55 | resi_connection: "1conv" 56 | unshuffle: True 57 | unshuffle_scale: 8 58 | 59 | # face detection model for unaligned face 60 | detection: 61 | det_model: "retinaface_resnet50" # large model: 'YOLOv5l', 'retinaface_resnet50'; small model: 'YOLOv5n', 'retinaface_mobile0.25' 62 | upscale: 2 # The final upscaling factor for the whole image 63 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 
8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set ramdom seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | -------------------------------------------------------------------------------- /facelib/detection/yolov5face/models/experimental.py: 
-------------------------------------------------------------------------------- 1 | # # This file contains experimental modules 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from facelib.detection.yolov5face.models.common import Conv 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super().__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class MixConv2d(nn.Module): 25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 27 | super().__init__() 28 | groups = len(k) 29 | if equal_ch: # equal c_ per group 30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices 31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 32 | else: # equal weight.numel() per group 33 | b = [c2] + [0] * groups 34 | a = np.eye(groups + 1, groups, k=-1) 35 | a -= np.roll(a, 1, axis=1) 36 | a *= np.array(k) ** 2 37 | a[0] = 1 38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 39 | 40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 41 | self.bn = nn.BatchNorm2d(c2) 42 | self.act = nn.LeakyReLU(0.1, inplace=True) 43 | 44 | def forward(self, x): 45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 46 | -------------------------------------------------------------------------------- /datapipe/degradation_bsrgan/utils_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | import logging 4 | 5 | 6 | ''' 7 | # -------------------------------------------- 8 | # Kai Zhang (github: https://github.com/cszn) 9 | # 03/Mar/2019 10 | # -------------------------------------------- 11 | # https://github.com/xinntao/BasicSR 12 | # -------------------------------------------- 13 | ''' 14 | 15 | 16 | def log(*args, **kwargs): 17 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 18 | 19 | 20 | ''' 21 | # -------------------------------------------- 22 | # logger 23 | # -------------------------------------------- 24 | ''' 25 | 26 | 27 | def logger_info(logger_name, log_path='default_logger.log'): 28 | ''' set up logger 29 | modified by Kai Zhang (github: https://github.com/cszn) 30 | ''' 31 | log = logging.getLogger(logger_name) 32 | if log.hasHandlers(): 33 | print('LogHandlers exist!') 34 | else: 35 | print('LogHandlers setup!') 36 | level = logging.INFO 37 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 38 | fh = logging.FileHandler(log_path, mode='a') 39 | fh.setFormatter(formatter) 40 | log.setLevel(level) 41 | log.addHandler(fh) 42 | # print(len(log.handlers)) 43 | 44 | sh = logging.StreamHandler() 45 | sh.setFormatter(formatter) 46 | log.addHandler(sh) 47 | 48 | 49 | ''' 50 | # -------------------------------------------- 51 | # print to file and std_out simultaneously 52 | # -------------------------------------------- 53 | ''' 54 | 55 | 56 | class logger_print(object): 57 | def __init__(self, 
log_path="default.log"): 58 | self.terminal = sys.stdout 59 | self.log = open(log_path, 'a') 60 | 61 | def write(self, message): 62 | self.terminal.write(message) 63 | self.log.write(message) # write the message 64 | 65 | def flush(self): 66 | pass 67 | -------------------------------------------------------------------------------- /ResizeRight/interp_methods.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | try: 4 | import torch 5 | except ImportError: 6 | torch = None 7 | 8 | try: 9 | import numpy 10 | except ImportError: 11 | numpy = None 12 | 13 | if numpy is None and torch is None: 14 | raise ImportError("Must have either Numpy or PyTorch but both not found") 15 | 16 | 17 | def set_framework_dependencies(x): 18 | if type(x) is numpy.ndarray: 19 | to_dtype = lambda a: a 20 | fw = numpy 21 | else: 22 | to_dtype = lambda a: a.to(x.dtype) 23 | fw = torch 24 | eps = fw.finfo(fw.float32).eps 25 | return fw, to_dtype, eps 26 | 27 | 28 | def support_sz(sz): 29 | def wrapper(f): 30 | f.support_sz = sz 31 | return f 32 | return wrapper 33 | 34 | @support_sz(4) 35 | def cubic(x): 36 | fw, to_dtype, eps = set_framework_dependencies(x) 37 | absx = fw.abs(x) 38 | absx2 = absx ** 2 39 | absx3 = absx ** 3 40 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + 41 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * 42 | to_dtype((1. < absx) & (absx <= 2.))) 43 | 44 | @support_sz(4) 45 | def lanczos2(x): 46 | fw, to_dtype, eps = set_framework_dependencies(x) 47 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / 48 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) 49 | 50 | @support_sz(6) 51 | def lanczos3(x): 52 | fw, to_dtype, eps = set_framework_dependencies(x) 53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / 54 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) 55 | 56 | @support_sz(2) 57 | def linear(x): 58 | fw, to_dtype, eps = set_framework_dependencies(x) 59 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * 60 | to_dtype((0 <= x) & (x <= 1))) 61 | 62 | @support_sz(1) 63 | def box(x): 64 | fw, to_dtype, eps = set_framework_dependencies(x) 65 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) 66 | -------------------------------------------------------------------------------- /datapipe/prepare/face/big2small_face.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-05-18 07:58:01 4 | 5 | import sys 6 | from pathlib import Path 7 | sys.path.append(str(Path(__file__).resolve().parents[3])) 8 | 9 | import argparse 10 | import multiprocessing 11 | import albumentations as Aug 12 | from utils import util_image 13 | 14 | parser = argparse.ArgumentParser(prog='SISR dataset Generation') 15 | parser.add_argument('--face_dir', default='/home/jupyter/data/FFHQ/images1024x1024', type=str, 16 | metavar='PATH', help="Path to save the HR face images") 17 | parser.add_argument('--save_dir', default='/home/jupyter/data/FFHQ/', type=str, 18 | metavar='PATH', help="Path to save the resized face images") 19 | # FFHQ: png 20 | parser.add_argument('--ext', default='png', type=str, help="Image format of the HR face images") 21 | parser.add_argument('--pch_size', default=512, type=int, metavar='PATH', help="Cropped patch size") 22 | args = parser.parse_args() 23 | 24 | # check the floder to save the cropped patches 25 | pch_size = args.pch_size 26 | pch_dir = 
Path(args.face_dir).parent / f"images{pch_size}x{pch_size}" 27 | if not pch_dir.exists(): pch_dir.mkdir(parents=False) 28 | 29 | transform = Aug.Compose([Aug.SmallestMaxSize(max_size=pch_size),]) 30 | 31 | # HR face path 32 | path_hr_list = [x for x in Path(args.face_dir).glob('*.'+args.ext)] 33 | 34 | def process(im_path): 35 | im = util_image.imread(im_path, chn='rgb', dtype='uint8') 36 | pch = transform(image=im)['image'] 37 | pch_path = pch_dir / (im_path.stem + '.png') 38 | util_image.imwrite(pch, pch_path, chn='rgb') 39 | 40 | num_workers = multiprocessing.cpu_count() 41 | pool = multiprocessing.Pool(num_workers) 42 | pool.imap(func=process, iterable=path_hr_list, chunksize=16) 43 | pool.close() 44 | pool.join() 45 | 46 | num_pch = len([x for x in pch_dir.glob('*.png')]) 47 | print('Totally process {:d} images'.format(num_pch)) 48 | -------------------------------------------------------------------------------- /datapipe/prepare/face/split_train_val.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-05-18 09:18:02 4 | 5 | import random 6 | import argparse 7 | from pathlib import Path 8 | 9 | parser = argparse.ArgumentParser(prog='Face dataset Generation') 10 | parser.add_argument('--face_dir', default='/home/jupyter/data/FFHQ/images512x512', type=str, 11 | metavar='PATH', help="Path to save the face images") 12 | # FFHQ: png, Celeba: png 13 | parser.add_argument('--prefix', default='ffhq', type=str, help="Image format of the HR face images") 14 | parser.add_argument('--num_val', default=500, type=int, help="Ratio for Validation set") 15 | parser.add_argument('--seed', default=1234, type=int, help="Random seed") 16 | parser.add_argument('--im_size', default=512, type=int, help="Random seed") 17 | args = parser.parse_args() 18 | 19 | base_dir = Path(__file__).resolve().parents[2] / 'files_txt' 20 | if not base_dir.exists(): 21 | base_dir.mkdir() 22 | 23 | path_list = sorted([str(x.resolve()) for x in Path(args.face_dir).glob('*.png')]) 24 | 25 | file_path = base_dir / f"{args.prefix}{args.im_size}.txt" 26 | if file_path.exists(): 27 | file_path.unlink() 28 | with open(file_path, mode='w') as ff: 29 | for line in path_list: ff.write(line+'\n') 30 | 31 | random.seed(args.seed) 32 | random.shuffle(path_list) 33 | num_train = int(len(path_list) - args.num_val) 34 | 35 | file_path_train = base_dir / f"{args.prefix}{args.im_size}_train.txt" 36 | if file_path_train.exists(): 37 | file_path_train.unlink() 38 | with open(file_path_train, mode='w') as ff: 39 | for line in path_list[:num_train]: ff.write(line+'\n') 40 | 41 | file_path_val = base_dir / f"{args.prefix}{args.im_size}_val.txt" 42 | if file_path_val.exists(): 43 | file_path_val.unlink() 44 | with open(file_path_val, mode='w') as ff: 45 | for line in path_list[num_train:]: ff.write(line+'\n') 46 | 47 | print('Train / Validation: {:d}/{:d}'.format(num_train, len(path_list)-num_train)) 48 | 49 | -------------------------------------------------------------------------------- /utils/util_common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-02-06 10:34:59 4 | 5 | import importlib 6 | from pathlib import Path 7 | 8 | def mkdir(dir_path, delete=False, parents=True): 9 | import shutil 10 | if not isinstance(dir_path, Path): 11 | dir_path = Path(dir_path) 12 | if delete: 13 | if dir_path.exists(): 14 
| shutil.rmtree(str(dir_path)) 15 | if not dir_path.exists(): 16 | dir_path.mkdir(parents=parents) 17 | 18 | def get_obj_from_str(string, reload=False): 19 | module, cls = string.rsplit(".", 1) 20 | if reload: 21 | module_imp = importlib.import_module(module) 22 | importlib.reload(module_imp) 23 | return getattr(importlib.import_module(module, package=None), cls) 24 | 25 | def instantiate_from_config(config): 26 | if not "target" in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 29 | 30 | def str2bool(v): 31 | if isinstance(v, bool): 32 | return v 33 | if v.lower() in ("yes", "true", "t", "y", "1"): 34 | return True 35 | elif v.lower() in ("no", "false", "f", "n", "0"): 36 | return False 37 | else: 38 | raise argparse.ArgumentTypeError("Boolean value expected.") 39 | 40 | def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True): 41 | ''' 42 | Get the file paths in the given folder. 43 | param exts: list, e.g., ['png',] 44 | return: list 45 | ''' 46 | if not isinstance(dir_path, Path): 47 | dir_path = Path(dir_path) 48 | 49 | file_paths = [] 50 | for current_ext in exts: 51 | if recursive: 52 | file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)]) 53 | else: 54 | file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)]) 55 | 56 | return file_paths 57 | 58 | def readline_txt(txt_file): 59 | if txt_file is None: 60 | out = [] 61 | else: 62 | with open(txt_file, 'r') as ff: 63 | out = [x[:-1] for x in ff.readlines()] 64 | return out 65 | -------------------------------------------------------------------------------- /basicsr/archs/edsr_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class EDSR(nn.Module): 10 | """EDSR network structure. 11 | 12 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 13 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 14 | 15 | Args: 16 | num_in_ch (int): Channel number of inputs. 17 | num_out_ch (int): Channel number of outputs. 18 | num_feat (int): Channel number of intermediate features. 19 | Default: 64. 20 | num_block (int): Block number in the trunk network. Default: 16. 21 | upscale (int): Upsampling factor. Support 2^n and 3. 22 | Default: 4. 23 | res_scale (float): Used to scale the residual in residual block. 24 | Default: 1. 25 | img_range (float): Image range. Default: 255. 26 | rgb_mean (tuple[float]): Image mean in RGB orders. 27 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
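        Example (an illustrative sketch only; the randomly initialized weights give meaningless outputs):
            >>> model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
            >>> out = model(torch.rand(1, 3, 48, 48))  # out.shape == (1, 3, 192, 192)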
28 | """ 29 | 30 | def __init__(self, 31 | num_in_ch, 32 | num_out_ch, 33 | num_feat=64, 34 | num_block=16, 35 | upscale=4, 36 | res_scale=1, 37 | img_range=255., 38 | rgb_mean=(0.4488, 0.4371, 0.4040)): 39 | super(EDSR, self).__init__() 40 | 41 | self.img_range = img_range 42 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 43 | 44 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 45 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 46 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 47 | self.upsample = Upsample(upscale, num_feat) 48 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 49 | 50 | def forward(self, x): 51 | self.mean = self.mean.type_as(x) 52 | 53 | x = (x - self.mean) * self.img_range 54 | x = self.conv_first(x) 55 | res = self.conv_after_body(self.body(x)) 56 | res += x 57 | 58 | x = self.conv_last(self.upsample(res)) 59 | x = x / self.img_range + self.mean 60 | 61 | return x 62 | -------------------------------------------------------------------------------- /datapipe/prepare/face/make_testing_data_bicubic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-16 12:42:42 4 | 5 | import sys 6 | from pathlib import Path 7 | sys.path.append(str(Path(__file__).resolve().parents[3])) 8 | 9 | import os 10 | import math 11 | import torch 12 | import argparse 13 | from einops import rearrange 14 | from datapipe.datasets import DatasetBicubic 15 | 16 | from utils import util_image 17 | from utils import util_common 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--files_txt", 22 | type=str, 23 | default='./datapipe/files_txt/celeba512_val.txt', 24 | help="File names") 25 | parser.add_argument( 26 | "--sf", 27 | type=int, 28 | default=8, 29 | help="Number of trainging iamges", 30 | ) 31 | parser.add_argument( 32 | "--bs", 33 | type=int, 34 | default=8, 35 | help="Batch size", 36 | ) 37 | parser.add_argument( 38 | "--save_dir", 39 | type=str, 40 | default='', 41 | help="Folder to save the fake iamges", 42 | ) 43 | parser.add_argument( 44 | "--num_images", 45 | type=int, 46 | default=100, 47 | help="Number of iamges", 48 | ) 49 | args = parser.parse_args() 50 | 51 | save_dir = Path(args.save_dir) 52 | if not save_dir.stem.endswith(f'x{args.sf}'): 53 | save_dir = save_dir.parent / f"{save_dir.stem}_x{args.sf}" 54 | util_common.mkdir(save_dir, delete=True) 55 | 56 | dataset = DatasetBicubic( 57 | files_txt=args.files_txt, 58 | up_back=True, 59 | need_gt_path=True, 60 | sf=args.sf, 61 | length=args.num_images, 62 | ) 63 | dataloader = torch.utils.data.DataLoader( 64 | dataset, 65 | batch_size=args.bs, 66 | drop_last=False, 67 | num_workers=4, 68 | pin_memory=False, 69 | ) 70 | 71 | for ii, data_batch in enumerate(dataloader): 72 | im_lq_batch = data_batch['lq'] 73 | im_path_batch = data_batch['gt_path'] 74 | print(f"Processing: {ii+1}/{math.ceil(len(dataset) / args.bs)}...") 75 | 76 | for jj in range(im_lq_batch.shape[0]): 77 | im_lq = rearrange( 78 | im_lq_batch[jj].clamp(0.0, 1.0).numpy(), 79 | 'c h w -> h w c', 80 | ) 81 | im_name = Path(im_path_batch[jj]).name 82 | im_path = save_dir / im_name 83 | util_image.imwrite(im_lq, im_path, chn='rgb', dtype_in='float32') 84 | 85 | -------------------------------------------------------------------------------- /models/fp16_util.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 
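    (This is the inverse of the flattening in make_master_params(): the single flat
    master tensor is split back into views shaped like the entries of model_params.)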
67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /facelib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_chan, out_chan, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_chan, out_chan, stride) 15 | self.bn1 = nn.BatchNorm2d(out_chan) 16 | self.conv2 = conv3x3(out_chan, out_chan) 17 | self.bn2 = nn.BatchNorm2d(out_chan) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = None 20 | if in_chan != out_chan or stride != 1: 21 | self.downsample = nn.Sequential( 22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_chan), 24 | ) 25 | 26 | def forward(self, x): 27 | residual = self.conv1(x) 28 | residual = F.relu(self.bn1(residual)) 29 | residual = self.conv2(residual) 30 | residual = self.bn2(residual) 31 | 32 | shortcut = x 33 | if self.downsample is not None: 34 | shortcut = self.downsample(x) 35 | 36 | out = shortcut + residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 43 | for i in range(bnum - 1): 44 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class ResNet18(nn.Module): 49 | 50 | def __init__(self): 51 | super(ResNet18, self).__init__() 52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.relu(self.bn1(x)) 63 | x = self.maxpool(x) 64 | 65 | x = self.layer1(x) 66 | feat8 = self.layer2(x) # 1/8 67 | feat16 = self.layer3(feat8) # 1/16 68 | feat32 = self.layer4(feat16) # 1/32 69 | return feat8, feat16, feat32 70 | -------------------------------------------------------------------------------- /basicsr/models/edvr_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils import get_root_logger 2 | from basicsr.utils.registry import MODEL_REGISTRY 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class EDVRModel(VideoBaseModel): 8 | """EDVR Model. 9 | 10 | Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. 
# noqa: E501 11 | """ 12 | 13 | def __init__(self, opt): 14 | super(EDVRModel, self).__init__(opt) 15 | if self.is_train: 16 | self.train_tsa_iter = opt['train'].get('tsa_iter') 17 | 18 | def setup_optimizers(self): 19 | train_opt = self.opt['train'] 20 | dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) 21 | logger = get_root_logger() 22 | logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') 23 | if dcn_lr_mul == 1: 24 | optim_params = self.net_g.parameters() 25 | else: # separate dcn params and normal params for different lr 26 | normal_params = [] 27 | dcn_params = [] 28 | for name, param in self.net_g.named_parameters(): 29 | if 'dcn' in name: 30 | dcn_params.append(param) 31 | else: 32 | normal_params.append(param) 33 | optim_params = [ 34 | { # add normal params first 35 | 'params': normal_params, 36 | 'lr': train_opt['optim_g']['lr'] 37 | }, 38 | { 39 | 'params': dcn_params, 40 | 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul 41 | }, 42 | ] 43 | 44 | optim_type = train_opt['optim_g'].pop('type') 45 | self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 46 | self.optimizers.append(self.optimizer_g) 47 | 48 | def optimize_parameters(self, current_iter): 49 | if self.train_tsa_iter: 50 | if current_iter == 1: 51 | logger = get_root_logger() 52 | logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') 53 | for name, param in self.net_g.named_parameters(): 54 | if 'fusion' not in name: 55 | param.requires_grad = False 56 | elif current_iter == self.train_tsa_iter: 57 | logger = get_root_logger() 58 | logger.warning('Train all the parameters.') 59 | for param in self.net_g.parameters(): 60 | param.requires_grad = True 61 | 62 | super(EDVRModel, self).optimize_parameters(current_iter) 63 | -------------------------------------------------------------------------------- /datapipe/prepare/face/degradation_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-16 12:11:42 4 | 5 | import sys 6 | from pathlib import Path 7 | sys.path.append(str(Path(__file__).resolve().parents[3])) 8 | 9 | import os 10 | import math 11 | import torch 12 | import random 13 | import argparse 14 | import numpy as np 15 | from einops import rearrange 16 | 17 | from utils import util_image 18 | from utils import util_common 19 | 20 | from datapipe.face_degradation_testing import face_degradation 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--lq_dir", type=str, default='', help="floder for the lq image") 24 | parser.add_argument("--source_txt", type=str, default='', help="ffhq or celeba") 25 | parser.add_argument("--prefix", type=str, default='celeba512', help="Data type") 26 | parser.add_argument("--seed", type=int, default=10000, help="Random seed") 27 | args = parser.parse_args() 28 | 29 | qf_list = [30, 40, 50, 60, 70] # quality factor for jpeg compression 30 | sf_list = [4, 8, 16, 24, 30] # scale factor for upser-resolution 31 | nf_list = [1, 5, 10, 15, 20] # noise level for gaussian noise 32 | sig_list = [2, 4, 6, 8, 10, 12, 14] # sigma for gaussian kernel 33 | theta_list = [x*math.pi for x in [0, 0.25, 0.5, 0.75]] # angle for gaussian kernel 34 | num_val = len(qf_list) * len(sf_list) * len(nf_list) * len(sig_list) * len(theta_list) 35 | 36 | # setting seed 37 | random.seed(args.seed) 38 | np.random.seed(args.seed) 39 | torch.manual_seed(args.seed) 40 | 41 | files_path = 
util_common.readline_txt(args.source_txt) 42 | assert num_val <= len(files_path) 43 | print(f'Number of images in validation: {num_val}') 44 | 45 | save_dir = Path(args.lq_dir).parent / (Path(args.lq_dir).stem+'_split') 46 | if not save_dir.exists(): 47 | save_dir.mkdir() 48 | 49 | for sf_target in sf_list: 50 | num_iters = 0 51 | 52 | num_sf = 0 53 | file_path = save_dir / f"{args.prefix}_val_sf{sf_target}.txt" 54 | if file_path.exists(): 55 | file_path.unlink() 56 | with open(file_path, mode='w') as ff: 57 | for qf in qf_list: 58 | for sf in sf_list: 59 | for nf in nf_list: 60 | for sig_x in sig_list: 61 | for theta in theta_list: 62 | 63 | im_name = Path(files_path[num_iters]).name 64 | im_path = str(Path(args.lq_dir).parent / im_name) 65 | if sf == sf_target: 66 | ff.write(im_path+'\n') 67 | num_sf += 1 68 | 69 | num_iters += 1 70 | 71 | print(f'{num_sf} images for sf: {sf_target}') 72 | 73 | -------------------------------------------------------------------------------- /basicsr/metrics/test_metrics/test_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from basicsr.metrics import calculate_psnr, calculate_ssim 5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt 6 | from basicsr.utils import img2tensor 7 | 8 | 9 | def test(img_path, img_path2, crop_border, test_y_channel=False): 10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) 12 | 13 | # --------------------- Numpy --------------------- 14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 17 | 18 | # --------------------- PyTorch (CPU) --------------------- 19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 21 | 22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 25 | 26 | # --------------------- PyTorch (GPU) --------------------- 27 | img = img.cuda() 28 | img2 = img2.cuda() 29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 32 | 33 | psnr_pth = calculate_psnr_pt( 34 | torch.repeat_interleave(img, 2, dim=0), 35 | torch.repeat_interleave(img2, 2, dim=0), 36 | crop_border=crop_border, 37 | test_y_channel=test_y_channel) 38 | ssim_pth = calculate_ssim_pt( 39 | torch.repeat_interleave(img, 2, dim=0), 40 | torch.repeat_interleave(img2, 2, dim=0), 41 | crop_border=crop_border, 42 | test_y_channel=test_y_channel) 43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' 44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') 45 | 46 | 47 | if __name__ == '__main__': 48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) 49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', 
crop_border=4, test_y_channel=True) 50 | 51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) 52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) 53 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. 
a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 
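        Example (illustrative only; the log line format and the regex are assumptions):
            >>> steps, values = read_data_from_txt_2v(
            ...     'train.log', r'Iter:\s*(\d+).*loss:\s*([\d.eE+-]+)')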
30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 
48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/archs/srresnet_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class MSRResNet(nn.Module): 10 | """Modified SRResNet. 11 | 12 | A compacted version modified from SRResNet in 13 | "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 14 | It uses residual blocks without BN, similar to EDSR. 15 | Currently, it supports x2, x3 and x4 upsampling scale factor. 16 | 17 | Args: 18 | num_in_ch (int): Channel number of inputs. Default: 3. 19 | num_out_ch (int): Channel number of outputs. Default: 3. 20 | num_feat (int): Channel number of intermediate features. Default: 64. 21 | num_block (int): Block number in the body network. Default: 16. 22 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 
23 | """ 24 | 25 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 26 | super(MSRResNet, self).__init__() 27 | self.upscale = upscale 28 | 29 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 30 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) 31 | 32 | # upsampling 33 | if self.upscale in [2, 3]: 34 | self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) 35 | self.pixel_shuffle = nn.PixelShuffle(self.upscale) 36 | elif self.upscale == 4: 37 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 38 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 39 | self.pixel_shuffle = nn.PixelShuffle(2) 40 | 41 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 42 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 43 | 44 | # activation function 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | 47 | # initialization 48 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 49 | if self.upscale == 4: 50 | default_init_weights(self.upconv2, 0.1) 51 | 52 | def forward(self, x): 53 | feat = self.lrelu(self.conv_first(x)) 54 | out = self.body(feat) 55 | 56 | if self.upscale == 4: 57 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 58 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 59 | elif self.upscale in [2, 3]: 60 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 61 | 62 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 63 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 64 | out += base 65 | return out 66 | -------------------------------------------------------------------------------- /models/srcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-12 20:35:28 4 | 5 | import math 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | class SRCNN(nn.Module): 10 | def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4): 11 | super().__init__() 12 | self.sf = sf 13 | out_chns = in_chns if out_chns is None else out_chns 14 | 15 | self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2) 16 | 17 | body = [] 18 | for _ in range(depth-1): 19 | body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2)) 20 | body.append(nn.LeakyReLU(0.2, inplace=True)) 21 | self.body = nn.Sequential(*body) 22 | 23 | tail = [] 24 | for _ in range(int(math.log(sf, 2))): 25 | tail.append(nn.Conv2d(num_chns, num_chns*4, kernel_size=3, padding=1)) 26 | tail.append(nn.LeakyReLU(0.2, inplace=True)) 27 | tail.append(nn.PixelShuffle(2)) 28 | tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2)) 29 | self.tail = nn.Sequential(*tail) 30 | 31 | def forward(self, x): 32 | y = self.head(x) 33 | y = self.body(y) 34 | y = self.tail(y) 35 | return y 36 | 37 | class SRCNNFSR(nn.Module): 38 | def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4): 39 | super().__init__() 40 | self.sf = sf 41 | 42 | head = [] 43 | in_chns_shuffle = in_chns * 4 44 | assert num_chns % 4 == 0 45 | for ii in range(int(math.log(down_scale_factor, 2))): 46 | head.append(nn.PixelUnshuffle(2)) 47 | head.append(nn.Conv2d(in_chns_shuffle, num_chns, kernel_size=3, padding=1)) 48 | if ii + 1 < int(math.log(down_scale_factor, 2)): 49 | head.append(nn.Conv2d(num_chns, num_chns//4, kernel_size=5, padding=2)) 50 | 
head.append(nn.LeakyReLU(0.2, inplace=True)) 51 | in_chns_shuffle = num_chns 52 | self.head = nn.Sequential(*head) 53 | 54 | body = [] 55 | for _ in range(depth-1): 56 | body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2)) 57 | body.append(nn.LeakyReLU(0.2, inplace=True)) 58 | self.body = nn.Sequential(*body) 59 | 60 | tail = [] 61 | for _ in range(int(math.log(down_scale_factor, 2))): 62 | tail.append(nn.Conv2d(num_chns, num_chns, kernel_size=3, padding=1)) 63 | tail.append(nn.LeakyReLU(0.2, inplace=True)) 64 | tail.append(nn.PixelShuffle(2)) 65 | num_chns //= 4 66 | tail.append(nn.Conv2d(num_chns, in_chns, kernel_size=5, padding=2)) 67 | self.tail = nn.Sequential(*tail) 68 | 69 | def forward(self, x): 70 | y = self.head(x) 71 | y = self.body(y) 72 | y = self.tail(y) 73 | return y 74 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 
25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | 7 | @ARCH_REGISTRY.register(suffix='basicsr') 8 | class SRVGGNetCompact(nn.Module): 9 | """A compact VGG-style network structure for super-resolution. 10 | 11 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 12 | conducted on the HR feature space. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_out_ch (int): Channel number of outputs. Default: 3. 17 | num_feat (int): Channel number of intermediate features. Default: 64. 18 | num_conv (int): Number of convolution layers in the body network. Default: 16. 19 | upscale (int): Upsampling factor. Default: 4. 20 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
21 | """ 22 | 23 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 24 | super(SRVGGNetCompact, self).__init__() 25 | self.num_in_ch = num_in_ch 26 | self.num_out_ch = num_out_ch 27 | self.num_feat = num_feat 28 | self.num_conv = num_conv 29 | self.upscale = upscale 30 | self.act_type = act_type 31 | 32 | self.body = nn.ModuleList() 33 | # the first conv 34 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 35 | # the first activation 36 | if act_type == 'relu': 37 | activation = nn.ReLU(inplace=True) 38 | elif act_type == 'prelu': 39 | activation = nn.PReLU(num_parameters=num_feat) 40 | elif act_type == 'leakyrelu': 41 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 42 | self.body.append(activation) 43 | 44 | # the body structure 45 | for _ in range(num_conv): 46 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 47 | # activation 48 | if act_type == 'relu': 49 | activation = nn.ReLU(inplace=True) 50 | elif act_type == 'prelu': 51 | activation = nn.PReLU(num_parameters=num_feat) 52 | elif act_type == 'leakyrelu': 53 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 54 | self.body.append(activation) 55 | 56 | # the last conv 57 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 58 | # upsample 59 | self.upsampler = nn.PixelShuffle(upscale) 60 | 61 | def forward(self, x): 62 | out = x 63 | for i in range(0, len(self.body)): 64 | out = self.body[i](out) 65 | 66 | out = self.upsampler(out) 67 | # add the nearest upsampled image, so that the network learns the residual 68 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 69 | out += base 70 | return out 71 | -------------------------------------------------------------------------------- /datapipe/degradation_bsrgan/utils_googledownload.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | 6 | ''' 7 | borrowed from 8 | https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py 9 | ''' 10 | 11 | 12 | def sizeof_fmt(size, suffix='B'): 13 | """Get human readable file size. 14 | Args: 15 | size (int): File size. 16 | suffix (str): Suffix. Default: 'B'. 17 | Return: 18 | str: Formated file siz. 19 | """ 20 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 21 | if abs(size) < 1024.0: 22 | return f'{size:3.1f} {unit}{suffix}' 23 | size /= 1024.0 24 | return f'{size:3.1f} Y{suffix}' 25 | 26 | 27 | def download_file_from_google_drive(file_id, save_path): 28 | """Download files from google drive. 29 | Ref: 30 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 31 | Args: 32 | file_id (str): File id. 33 | save_path (str): Save path. 
34 | """ 35 | 36 | session = requests.Session() 37 | URL = 'https://docs.google.com/uc?export=download' 38 | params = {'id': file_id} 39 | 40 | response = session.get(URL, params=params, stream=True) 41 | token = get_confirm_token(response) 42 | if token: 43 | params['confirm'] = token 44 | response = session.get(URL, params=params, stream=True) 45 | 46 | # get file size 47 | response_file_size = session.get( 48 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 49 | if 'Content-Range' in response_file_size.headers: 50 | file_size = int( 51 | response_file_size.headers['Content-Range'].split('/')[1]) 52 | else: 53 | file_size = None 54 | 55 | save_response_content(response, save_path, file_size) 56 | 57 | 58 | def get_confirm_token(response): 59 | for key, value in response.cookies.items(): 60 | if key.startswith('download_warning'): 61 | return value 62 | return None 63 | 64 | 65 | def save_response_content(response, 66 | destination, 67 | file_size=None, 68 | chunk_size=32768): 69 | if file_size is not None: 70 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 71 | 72 | readable_file_size = sizeof_fmt(file_size) 73 | else: 74 | pbar = None 75 | 76 | with open(destination, 'wb') as f: 77 | downloaded_size = 0 78 | for chunk in response.iter_content(chunk_size): 79 | downloaded_size += chunk_size 80 | if pbar is not None: 81 | pbar.update(1) 82 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 83 | f'/ {readable_file_size}') 84 | if chunk: # filter out keep-alive new chunks 85 | f.write(chunk) 86 | if pbar is not None: 87 | pbar.close() 88 | 89 | 90 | if __name__ == "__main__": 91 | file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv' 92 | save_path = 'BSRGAN.pth' 93 | download_file_from_google_drive(file_id, save_path) 94 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: DifFace 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.10.11=h06a4308_0 8 | - certifi=2022.9.24=py38h06a4308_0 9 | - cudatoolkit=11.3.1=h2bc3f7f_2 10 | - cudnn=8.2.1=cuda11.3_0 11 | - ld_impl_linux-64=2.38=h1181459_1 12 | - libffi=3.4.2=h6a678d5_6 13 | - libgcc-ng=11.2.0=h1234567_1 14 | - libgomp=11.2.0=h1234567_1 15 | - libstdcxx-ng=11.2.0=h1234567_1 16 | - ncurses=6.3=h5eee18b_3 17 | - openssl=1.1.1s=h7f8727e_0 18 | - pip=22.2.2=py38h06a4308_0 19 | - python=3.8.15=h7a1cb2a_2 20 | - readline=8.2=h5eee18b_0 21 | - setuptools=65.5.0=py38h06a4308_0 22 | - sqlite=3.40.0=h5082296_0 23 | - tk=8.6.12=h1ccaba5_0 24 | - wheel=0.37.1=pyhd3eb1b0_0 25 | - xz=5.2.6=h5eee18b_0 26 | - zlib=1.2.13=h5eee18b_0 27 | - pip: 28 | - absl-py==1.3.0 29 | - antlr4-python3-runtime==4.9.3 30 | - asttokens==2.2.0 31 | - astunparse==1.6.3 32 | - backcall==0.2.0 33 | - beautifulsoup4==4.11.1 34 | - cachetools==5.2.0 35 | - charset-normalizer==2.1.1 36 | - decorator==5.1.1 37 | - einops==0.6.0 38 | - executing==1.2.0 39 | - filelock==3.8.2 40 | - flatbuffers==22.11.23 41 | - gast==0.4.0 42 | - gdown==4.6.0 43 | - google-auth==2.14.1 44 | - google-auth-oauthlib==0.4.6 45 | - google-pasta==0.2.0 
46 | - grpcio==1.50.0 47 | - h5py==3.7.0 48 | - huggingface-hub==0.11.1 49 | - idna==3.4 50 | - imageio==2.22.4 51 | - importlib-metadata==5.1.0 52 | - ipdb==0.13.9 53 | - ipython==8.7.0 54 | - jedi==0.18.2 55 | - keras==2.10.0 56 | - keras-preprocessing==1.1.2 57 | - libclang==14.0.6 58 | - markdown==3.4.1 59 | - markupsafe==2.1.1 60 | - matplotlib-inline==0.1.6 61 | - networkx==2.8.8 62 | - numpy==1.23.5 63 | - oauthlib==3.2.2 64 | - omegaconf==2.3.0 65 | - opencv-python==4.5.3.56 66 | - opt-einsum==3.3.0 67 | - packaging==21.3 68 | - parso==0.8.3 69 | - pexpect==4.8.0 70 | - pickleshare==0.7.5 71 | - pillow==9.3.0 72 | - prompt-toolkit==3.0.33 73 | - protobuf==3.19.6 74 | - ptyprocess==0.7.0 75 | - pure-eval==0.2.2 76 | - pyasn1==0.4.8 77 | - pyasn1-modules==0.2.8 78 | - pygments==2.13.0 79 | - pyparsing==3.0.9 80 | - pysocks==1.7.1 81 | - pywavelets==1.4.1 82 | - pyyaml==6.0 83 | - requests==2.28.1 84 | - requests-oauthlib==1.3.1 85 | - rsa==4.9 86 | - scikit-image==0.19.3 87 | - scipy==1.9.3 88 | - six==1.16.0 89 | - soupsieve==2.3.2.post1 90 | - stack-data==0.6.2 91 | - tensorboard==2.10.1 92 | - tensorboard-data-server==0.6.1 93 | - tensorboard-plugin-wit==1.8.1 94 | - tensorflow-estimator==2.10.0 95 | - tensorflow-gpu==2.10.0 96 | - tensorflow-io-gcs-filesystem==0.28.0 97 | - termcolor==2.1.1 98 | - tifffile==2022.10.10 99 | - timm==0.6.12 100 | - toml==0.10.2 101 | - torch==1.12.0 102 | - torchaudio==0.12.0 103 | - torchvision==0.13.0 104 | - tqdm==4.64.1 105 | - traitlets==5.5.0 106 | - typing-extensions==4.4.0 107 | - urllib3==1.26.13 108 | - wcwidth==0.2.5 109 | - werkzeug==2.2.2 110 | - wrapt==1.14.1 111 | - zipp==3.11.0 112 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /utils/util_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2021-11-24 20:29:36 4 | 5 | import math 6 | import torch 7 | from pathlib import Path 8 | from collections import OrderedDict 9 | import torch.nn.functional as F 10 | 11 | def calculate_parameters(net): 12 | out = 0 13 | for param in net.parameters(): 14 | out += param.numel() 15 | return out 16 | 17 | def pad_input(x, mod): 18 | h, w = x.shape[-2:] 19 | bottom = int(math.ceil(h/mod)*mod -h) 20 | right = int(math.ceil(w/mod)*mod - w) 21 | x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect') 22 | return x_pad 23 | 24 | def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000): 25 | n_GPUs = 1 26 | b, c, h, w = x.size() 27 | h_half, w_half = h // 2, w // 2 28 | h_size, w_size = h_half + shave, w_half + shave 29 | lr_list = [ 30 | x[:, :, 0:h_size, 0:w_size], 31 | x[:, :, 0:h_size, (w - w_size):w], 32 | x[:, :, (h - h_size):h, 0:w_size], 33 | x[:, :, (h - h_size):h, (w - w_size):w]] 34 | 35 | if w_size * h_size < min_size: 36 | sr_list = [] 37 | for i in range(0, 4, n_GPUs): 38 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 39 | if net_kwargs is None: 40 | sr_batch = net(lr_batch) 41 | else: 42 | sr_batch = net(lr_batch, **net_kwargs) 43 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 44 | else: 45 | sr_list = [ 46 | forward_chop(patch, shave=shave, min_size=min_size) \ 47 | for patch in lr_list 48 | ] 49 | 50 | h, w = scale * h, scale * w 51 | h_half, w_half = scale * h_half, scale * w_half 52 | h_size, w_size = scale * h_size, scale * w_size 53 | shave *= scale 54 | 55 | output = x.new(b, c, h, w) 56 | output[:, :, 0:h_half, 0:w_half] \ 57 | = sr_list[0][:, :, 0:h_half, 0:w_half] 58 | output[:, :, 0:h_half, w_half:w] \ 59 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 60 | output[:, :, h_half:h, 0:w_half] \ 61 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 62 | output[:, :, h_half:h, w_half:w] \ 63 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 64 | 
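    # editor note (added comment, not in the upstream file): at this point the four
    # overlapping quadrants computed above have been stitched back into `output`,
    # which holds the full (scale * h, scale * w) result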
65 | return output 66 | 67 | def measure_time(net, inputs, num_forward=100): 68 | ''' 69 | Measuring the average runing time (seconds) for pytorch. 70 | out = net(*inputs) 71 | ''' 72 | start = torch.cuda.Event(enable_timing=True) 73 | end = torch.cuda.Event(enable_timing=True) 74 | 75 | start.record() 76 | with torch.set_grad_enabled(False): 77 | for _ in range(num_forward): 78 | out = net(*inputs) 79 | end.record() 80 | 81 | torch.cuda.synchronize() 82 | 83 | return start.elapsed_time(end) / 1000 84 | 85 | def reload_model(model, ckpt): 86 | if list(model.state_dict().keys())[0].startswith('module.'): 87 | if list(ckpt.keys())[0].startswith('module.'): 88 | ckpt = ckpt 89 | else: 90 | ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()}) 91 | else: 92 | if list(ckpt.keys())[0].startswith('module.'): 93 | ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()}) 94 | else: 95 | ckpt = ckpt 96 | model.load_state_dict(ckpt) 97 | -------------------------------------------------------------------------------- /basicsr/models/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .srgan_model import SRGANModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ESRGANModel(SRGANModel): 10 | """ESRGAN model for single image super-resolution.""" 11 | 12 | def optimize_parameters(self, current_iter): 13 | # optimize net_g 14 | for p in self.net_d.parameters(): 15 | p.requires_grad = False 16 | 17 | self.optimizer_g.zero_grad() 18 | self.output = self.net_g(self.lq) 19 | 20 | l_g_total = 0 21 | loss_dict = OrderedDict() 22 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 23 | # pixel loss 24 | if self.cri_pix: 25 | l_g_pix = self.cri_pix(self.output, self.gt) 26 | l_g_total += l_g_pix 27 | loss_dict['l_g_pix'] = l_g_pix 28 | # perceptual loss 29 | if self.cri_perceptual: 30 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 31 | if l_g_percep is not None: 32 | l_g_total += l_g_percep 33 | loss_dict['l_g_percep'] = l_g_percep 34 | if l_g_style is not None: 35 | l_g_total += l_g_style 36 | loss_dict['l_g_style'] = l_g_style 37 | # gan loss (relativistic gan) 38 | real_d_pred = self.net_d(self.gt).detach() 39 | fake_g_pred = self.net_d(self.output) 40 | l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) 41 | l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) 42 | l_g_gan = (l_g_real + l_g_fake) / 2 43 | 44 | l_g_total += l_g_gan 45 | loss_dict['l_g_gan'] = l_g_gan 46 | 47 | l_g_total.backward() 48 | self.optimizer_g.step() 49 | 50 | # optimize net_d 51 | for p in self.net_d.parameters(): 52 | p.requires_grad = True 53 | 54 | self.optimizer_d.zero_grad() 55 | # gan loss (relativistic gan) 56 | 57 | # In order to avoid the error in distributed training: 58 | # "Error detected in CudnnBatchNormBackward: RuntimeError: one of 59 | # the variables needed for gradient computation has been modified by 60 | # an inplace operation", 61 | # we separate the backwards for real and fake, and also detach the 62 | # tensor for calculating mean. 
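        # added comment: the fake branch below additionally detaches self.output,
        # so the discriminator update does not backpropagate into net_g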
63 | 64 | # real 65 | fake_d_pred = self.net_d(self.output).detach() 66 | real_d_pred = self.net_d(self.gt) 67 | l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 68 | l_d_real.backward() 69 | # fake 70 | fake_d_pred = self.net_d(self.output.detach()) 71 | l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 72 | l_d_fake.backward() 73 | self.optimizer_d.step() 74 | 75 | loss_dict['l_d_real'] = l_d_real 76 | loss_dict['l_d_fake'] = l_d_fake 77 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 78 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 79 | 80 | self.log_dict = self.reduce_loss_dict(loss_dict) 81 | 82 | if self.ema_decay > 0: 83 | self.model_ema(decay=self.ema_decay) 84 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: 54 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 55 | Stable version by Dougal J. Sutherland. 56 | 57 | Args: 58 | mu1 (np.array): The sample mean over activations. 59 | sigma1 (np.array): The covariance matrix over activations for generated samples. 60 | mu2 (np.array): The sample mean over activations, precalculated on an representative data set. 61 | sigma2 (np.array): The covariance matrix over activations, precalculated on an representative data set. 62 | 63 | Returns: 64 | float: The Frechet Distance. 
65 | """ 66 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 67 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 68 | 69 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 70 | 71 | # Product might be almost singular 72 | if not np.isfinite(cov_sqrt).all(): 73 | print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') 74 | offset = np.eye(sigma1.shape[0]) * eps 75 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 76 | 77 | # Numerical error might give slight imaginary component 78 | if np.iscomplexobj(cov_sqrt): 79 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 80 | m = np.max(np.abs(cov_sqrt.imag)) 81 | raise ValueError(f'Imaginary component {m}') 82 | cov_sqrt = cov_sqrt.real 83 | 84 | mean_diff = mu1 - mu2 85 | mean_norm = mean_diff @ mean_diff 86 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 87 | fid = mean_norm + trace 88 | 89 | return fid 90 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 
71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /models/basic_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | class SiLU(nn.Module): 11 | def forward(self, x): 12 | return x * th.sigmoid(x) 13 | 14 | 15 | class GroupNorm32(nn.GroupNorm): 16 | def forward(self, x): 17 | return super().forward(x.float()).type(x.dtype) 18 | 19 | 20 | def conv_nd(dims, *args, **kwargs): 21 | """ 22 | Create a 1D, 2D, or 3D convolution module. 23 | """ 24 | if dims == 1: 25 | return nn.Conv1d(*args, **kwargs) 26 | elif dims == 2: 27 | return nn.Conv2d(*args, **kwargs) 28 | elif dims == 3: 29 | return nn.Conv3d(*args, **kwargs) 30 | raise ValueError(f"unsupported dimensions: {dims}") 31 | 32 | def linear(*args, **kwargs): 33 | """ 34 | Create a linear module. 35 | """ 36 | return nn.Linear(*args, **kwargs) 37 | 38 | def avg_pool_nd(dims, *args, **kwargs): 39 | """ 40 | Create a 1D, 2D, or 3D average pooling module. 41 | """ 42 | if dims == 1: 43 | return nn.AvgPool1d(*args, **kwargs) 44 | elif dims == 2: 45 | return nn.AvgPool2d(*args, **kwargs) 46 | elif dims == 3: 47 | return nn.AvgPool3d(*args, **kwargs) 48 | raise ValueError(f"unsupported dimensions: {dims}") 49 | 50 | 51 | def update_ema(target_params, source_params, rate=0.99): 52 | """ 53 | Update target parameters to be closer to those of source parameters using 54 | an exponential moving average. 55 | 56 | :param target_params: the target parameter sequence. 57 | :param source_params: the source parameter sequence. 58 | :param rate: the EMA rate (closer to 1 means slower). 59 | """ 60 | for targ, src in zip(target_params, source_params): 61 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 62 | 63 | 64 | def zero_module(module): 65 | """ 66 | Zero out the parameters of a module and return it. 67 | """ 68 | for p in module.parameters(): 69 | p.detach().zero_() 70 | return module 71 | 72 | 73 | def scale_module(module, scale): 74 | """ 75 | Scale the parameters of a module and return it. 76 | """ 77 | for p in module.parameters(): 78 | p.detach().mul_(scale) 79 | return module 80 | 81 | 82 | def mean_flat(tensor): 83 | """ 84 | Take the mean over all non-batch dimensions. 
85 | """ 86 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 87 | 88 | 89 | def normalization(channels): 90 | """ 91 | Make a standard normalization layer. 92 | 93 | :param channels: number of input channels. 94 | :return: an nn.Module for normalization. 95 | """ 96 | return GroupNorm32(32, channels) 97 | 98 | 99 | def timestep_embedding(timesteps, dim, max_period=10000): 100 | """ 101 | Create sinusoidal timestep embeddings. 102 | 103 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 104 | These may be fractional. 105 | :param dim: the dimension of the output. 106 | :param max_period: controls the minimum frequency of the embeddings. 107 | :return: an [N x dim] Tensor of positional embeddings. 108 | """ 109 | half = dim // 2 110 | freqs = th.exp( 111 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 112 | ).to(device=timesteps.device) 113 | args = timesteps[:, None].float() * freqs[None] # B x half 114 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 115 | if dim % 2: 116 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 117 | return embedding 118 | 119 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 
88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /basicsr/archs/spynet_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | from .arch_util import flow_warp 8 | 9 | 10 | class BasicModule(nn.Module): 11 | """Basic Module for SpyNet. 12 | """ 13 | 14 | def __init__(self): 15 | super(BasicModule, self).__init__() 16 | 17 | self.basic_module = nn.Sequential( 18 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 20 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 21 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 22 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 23 | 24 | def forward(self, tensor_input): 25 | return self.basic_module(tensor_input) 26 | 27 | 28 | @ARCH_REGISTRY.register() 29 | class SpyNet(nn.Module): 30 | """SpyNet architecture. 31 | 32 | Args: 33 | load_path (str): path for pretrained SpyNet. Default: None. 
34 | """ 35 | 36 | def __init__(self, load_path=None): 37 | super(SpyNet, self).__init__() 38 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) 39 | if load_path: 40 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 41 | 42 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 43 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 44 | 45 | def preprocess(self, tensor_input): 46 | tensor_output = (tensor_input - self.mean) / self.std 47 | return tensor_output 48 | 49 | def process(self, ref, supp): 50 | flow = [] 51 | 52 | ref = [self.preprocess(ref)] 53 | supp = [self.preprocess(supp)] 54 | 55 | for level in range(5): 56 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 57 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 58 | 59 | flow = ref[0].new_zeros( 60 | [ref[0].size(0), 2, 61 | int(math.floor(ref[0].size(2) / 2.0)), 62 | int(math.floor(ref[0].size(3) / 2.0))]) 63 | 64 | for level in range(len(ref)): 65 | upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 66 | 67 | if upsampled_flow.size(2) != ref[level].size(2): 68 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') 69 | if upsampled_flow.size(3) != ref[level].size(3): 70 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') 71 | 72 | flow = self.basic_module[level](torch.cat([ 73 | ref[level], 74 | flow_warp( 75 | supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), 76 | upsampled_flow 77 | ], 1)) + upsampled_flow 78 | 79 | return flow 80 | 81 | def forward(self, ref, supp): 82 | assert ref.size() == supp.size() 83 | 84 | h, w = ref.size(2), ref.size(3) 85 | w_floor = math.floor(math.ceil(w / 32.0) * 32.0) 86 | h_floor = math.floor(math.ceil(h / 32.0) * 32.0) 87 | 88 | ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 89 | supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 90 | 91 | flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) 92 | 93 | flow[:, 0, :, :] *= float(w) / float(w_floor) 94 | flow[:, 1, :, :] *= float(h) / float(h_floor) 95 | 96 | return flow 97 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 
17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each with 10 iterations. At the 10th, 20th and 30th iterations, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine annealing cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.'
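# cumulative end points of the cosine cycles, e.g. periods = [10, 10, 10, 10] -> cumulative_period = [10, 20, 30, 40]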
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /datapipe/prepare/face/make_testing_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2022-07-16 12:11:42 4 | 5 | import sys 6 | import pickle 7 | from pathlib import Path 8 | sys.path.append(str(Path(__file__).resolve().parents[3])) 9 | 10 | import os 11 | import math 12 | import torch 13 | import random 14 | import argparse 15 | import numpy as np 16 | from einops import rearrange 17 | 18 | from utils import util_image 19 | from utils import util_common 20 | 21 | from datapipe.face_degradation_testing import face_degradation 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--save_dir", type=str, default='', help="Folder to save the testing data") 25 | parser.add_argument("--files_txt", type=str, default='', help="ffhq or celeba") 26 | parser.add_argument("--seed", type=int, default=10000, help="Random seed") 27 | args = parser.parse_args() 28 | 29 | ############################ ICLR #################################################### 30 | # qf_list = [30, 40, 50, 60, 70] # quality factor for jpeg compression 31 | # sf_list = [4, 8, 16, 24, 30] # scale factor for super-resolution 32 | # nf_list = [1, 5, 10, 15, 20] # noise level for gaussian noise 33 | # sig_list = [2, 4, 6, 8, 10, 12, 14] # sigma for gaussian kernel 34 | # theta_list = [x*math.pi for x in [0, 0.25, 0.5, 0.75]] # angle for gaussian kernel 35 | ###################################################################################### 36 | 37 | ############################ Journal ################################################# 38 | qf_list = [30, 40, 50, 60, 70] # quality factor for jpeg compression 39 | nf_list = [1, 5, 10, 15, 20] # noise level for gaussian noise 40 | sig_list = [4, 8, 12, 16] # sigma for gaussian kernel 41 | theta_list = [x*math.pi for x in [0, 0.25, 0.5, 0.75]] # angle for gaussian kernel 42 | sf_list = [4, 8, 12, 16, 20, 24, 28, 32, 36, 40] # scale factor for super-resolution 43 | ############################ ICLR #################################################### 44 | 45 | num_val = len(qf_list) * len(sf_list) * len(nf_list) * len(sig_list) * len(theta_list) 46 | 47 | # setting seed 48 | random.seed(args.seed) 49 | np.random.seed(args.seed) 50 | torch.manual_seed(args.seed) 51 | 52 | # checking save_dir 53 | lq_dir = Path(args.save_dir) / "lq" 54 | hq_dir = Path(args.save_dir) / "hq" 55 | info_dir = Path(args.save_dir) / "split_infos" 56 | 57 | util_common.mkdir(lq_dir, delete=True) 58 | util_common.mkdir(hq_dir, delete=True) 59 | util_common.mkdir(info_dir, delete=True) 60 | 61 | files_path = util_common.readline_txt(args.files_txt) 62 | assert num_val <= len(files_path) 63 | print(f'Number of images in validation: {num_val}') 64 
| 65 | sf_split = {} 66 | for sf in sf_list: 67 | sf_split[f"sf{sf}"] = [] 68 | 69 | num_iters = 0 70 | for qf in qf_list: 71 | for sf in sf_list: 72 | for nf in nf_list: 73 | for sig_x in sig_list: 74 | for theta in theta_list: 75 | if (num_iters+1) % 100 == 0: 76 | print(f'Processing: {num_iters+1}/{num_val}') 77 | im_gt_path = files_path[num_iters] 78 | im_gt = util_image.imread(im_gt_path, chn='bgr', dtype='float32') 79 | 80 | sig_y = random.choice(sig_list) 81 | im_lq = face_degradation( 82 | im_gt, 83 | sf=sf, 84 | sig_x=sig_x, 85 | sig_y=sig_y, 86 | theta=theta, 87 | qf=qf, 88 | nf=nf, 89 | ) 90 | 91 | im_name = Path(im_gt_path).name 92 | 93 | sf_split[f"sf{sf}"].append(im_name) 94 | 95 | im_save_path = lq_dir / im_name 96 | util_image.imwrite(im_lq, im_save_path, chn="bgr", dtype_in='float32') 97 | 98 | im_save_path = hq_dir / im_name 99 | util_image.imwrite(im_gt, im_save_path, chn="bgr", dtype_in='float32') 100 | 101 | num_iters += 1 102 | 103 | info_path = info_dir / 'sf_split.pkl' 104 | with open(str(info_path), mode='wb') as ff: 105 | pickle.dump(sf_split, ff) 106 | 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DifFace: Blind Face Restoration with Diffused Error Contraction 2 | 3 | [Zongsheng Yue](https://zsyoaoa.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 4 | 5 | [Paper](https://arxiv.org/abs/2212.06512) 6 | 7 | google colab logo [![Hugging Face](https://img.shields.io/badge/Demo-Hugging%20Face-blue)](https://huggingface.co/spaces/OAOA/DifFace) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/DifFace) 8 | 9 | 10 | 11 | :star: If DifFace is helpful to your images or projects, please help star this repo. Thanks! :hugs: 12 | 13 | ## Update 14 | - **2022.12.19**: Add Colab demo google colab logo. 15 | - **2022.12.17**: Add the [![Hugging Face](https://img.shields.io/badge/Demo-Hugging%20Face-blue)](https://huggingface.co/spaces/OAOA/DifFace). 16 | - **2022.12.13**: Create this repo. 17 | 18 | ## Applications 19 | ### :point_right: Old Photo Enhancement 20 | [](https://imgsli.com/MTM5NTgw) 21 | 22 | [](https://imgsli.com/MTM5NTc5) [](https://imgsli.com/MTM5NTgy) 23 | 24 | ### :point_right: Face Restoration 25 | 26 | 27 | 28 | ## Requirements 29 | A suitable [conda](https://conda.io/) environment named `DifFace` can be created and activated with: 30 | 31 | ``` 32 | conda env create -f environment.yaml 33 | conda activate DifFace 34 | ``` 35 | 36 | ## Inference 37 | #### :boy: Face image restoration (cropped and aligned) 38 | ``` 39 | python inference_difface.py --aligned --in_path [image folder/image path] --out_path [result folder] --gpu_id [gpu index] 40 | ``` 41 | #### :couple: Whole image enhancement 42 | ``` 43 | python inference_difface.py --in_path [image folder/image path] --out_path [result folder] --gpu_id [gpu index] 44 | ``` 45 | 46 | ## Training 47 | #### :turtle: Prepare data 48 | 1. Download the [FFHQ](https://github.com/NVlabs/ffhq-dataset) dataset, and resize the images to 512x512. 49 | ``` 50 | python datapipe/prepare/face/big2small_face.py --face_dir [Face folder(1024x1024)] --save_dir [Saving folder] --pch_size 512 51 | ``` 52 | 2. Extract the image paths into 'datapipe/files_txt/ffhq512.txt' 53 | ``` 54 | python datapipe/prepare/face/split_train_val.py --face_dir [Face folder(512x512)] --save_dir [Saving folder] 55 | ``` 56 | 3.
Make the testing dataset 57 | ``` 58 | python datapipe/prepare/face/make_testing_data.py --files_txt datapipe/files_txt/ffhq512.txt --save_dir [Saving folder] 59 | ``` 60 | #### :dolphin: Train diffusion model 61 | ``` 62 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main_diffusion.py --cfg_path configs/training/diffusion_ffhq512.yaml --save_dir [Logging Folder] 63 | ``` 64 | #### :whale: Train diffused estimator (SwinIR) 65 | ``` 66 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 --nnodes=1 main_sr.py --cfg_path configs/training/swinir_ffhq512.yaml --save_dir [Logging Folder] 67 | ``` 68 | 69 | ## License 70 | 71 | This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license. 72 | 73 | ## Acknowledgement 74 | 75 | This project is based on [Improved Diffusion Model](https://github.com/openai/improved-diffusion). Some code is borrowed from [BasicSR](https://github.com/XPixelGroup/BasicSR), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome work. 76 | 77 | ### Contact 78 | If you have any questions, please feel free to contact me via `zsyzam@gmail.com`. 79 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1.
51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /facelib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from copy import deepcopy 5 | 6 | from facelib.utils import load_file_from_url 7 | from facelib.utils import download_pretrained_models 8 | from facelib.detection.yolov5face.models.common import Conv 9 | 10 | from .retinaface.retinaface import RetinaFace 11 | from .yolov5face.face_detector import YoloDetector 12 | 13 | 14 | def init_detection_model(model_name, half=False, device='cuda'): 15 | if 'retinaface' in model_name: 16 | model = init_retinaface_model(model_name, half, device) 17 | elif 'YOLOv5' in model_name: 18 | model = init_yolov5face_model(model_name, device) 19 | else: 20 | raise NotImplementedError(f'{model_name} is not implemented.') 21 | 22 | return model 23 | 24 | 25 | def init_retinaface_model(model_name, half=False, device='cuda'): 26 | if model_name == 'retinaface_resnet50': 27 | model = RetinaFace(network_name='resnet50', half=half) 28 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' 29 | 
elif model_name == 'retinaface_mobile0.25': 30 | model = RetinaFace(network_name='mobile0.25', half=half) 31 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' 32 | else: 33 | raise NotImplementedError(f'{model_name} is not implemented.') 34 | 35 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 36 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 37 | # remove unnecessary 'module.' 38 | for k, v in deepcopy(load_net).items(): 39 | if k.startswith('module.'): 40 | load_net[k[7:]] = v 41 | load_net.pop(k) 42 | model.load_state_dict(load_net, strict=True) 43 | model.eval() 44 | model = model.to(device) 45 | 46 | return model 47 | 48 | 49 | def init_yolov5face_model(model_name, device='cuda'): 50 | if model_name == 'YOLOv5l': 51 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 52 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' 53 | elif model_name == 'YOLOv5n': 54 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 55 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' 56 | else: 57 | raise NotImplementedError(f'{model_name} is not implemented.') 58 | 59 | model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) 60 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 61 | model.detector.load_state_dict(load_net, strict=True) 62 | model.detector.eval() 63 | model.detector = model.detector.to(device).float() 64 | 65 | for m in model.detector.modules(): 66 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 67 | m.inplace = True # pytorch 1.7.0 compatibility 68 | elif isinstance(m, Conv): 69 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 70 | 71 | return model 72 | 73 | 74 | # Download from Google Drive 75 | # def init_yolov5face_model(model_name, device='cuda'): 76 | # if model_name == 'YOLOv5l': 77 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 78 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} 79 | # elif model_name == 'YOLOv5n': 80 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 81 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} 82 | # else: 83 | # raise NotImplementedError(f'{model_name} is not implemented.') 84 | 85 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) 86 | # if not os.path.exists(model_path): 87 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') 88 | 89 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 90 | # model.detector.load_state_dict(load_net, strict=True) 91 | # model.detector.eval() 92 | # model.detector = model.detector.to(device).float() 93 | 94 | # for m in model.detector.modules(): 95 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 96 | # m.inplace = True # pytorch 1.7.0 compatibility 97 | # elif isinstance(m, Conv): 98 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 99 | 100 | # return model -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 
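Example (a minimal shape sketch; values are illustrative):
    >>> net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23)
    >>> x = torch.rand(1, 3, 64, 64)
    >>> y = net(x)    # (1, 3, 256, 256): two nearest-neighbour x2 upsampling stages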
85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out 120 | -------------------------------------------------------------------------------- /basicsr/archs/rcan_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import Upsample, make_layer 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | """Channel attention used in RCAN. 10 | 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 20 | nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | y = self.attention(x) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | """Residual Channel Attention Block (RCAB) used in RCAN. 29 | 30 | Args: 31 | num_feat (int): Channel number of intermediate features. 32 | squeeze_factor (int): Channel squeeze factor. Default: 16. 33 | res_scale (float): Scale the residual. Default: 1. 34 | """ 35 | 36 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 37 | super(RCAB, self).__init__() 38 | self.res_scale = res_scale 39 | 40 | self.rcab = nn.Sequential( 41 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), 42 | ChannelAttention(num_feat, squeeze_factor)) 43 | 44 | def forward(self, x): 45 | res = self.rcab(x) * self.res_scale 46 | return res + x 47 | 48 | 49 | class ResidualGroup(nn.Module): 50 | """Residual Group of RCAB. 51 | 52 | Args: 53 | num_feat (int): Channel number of intermediate features. 54 | num_block (int): Block number in the body network. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 
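Example (a minimal shape sketch; the group is shape-preserving):
    >>> rg = ResidualGroup(num_feat=64, num_block=16)
    >>> rg(torch.rand(1, 64, 32, 32)).shape
    torch.Size([1, 64, 32, 32])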
57 | """ 58 | 59 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 60 | super(ResidualGroup, self).__init__() 61 | 62 | self.residual_group = make_layer( 63 | RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) 64 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 65 | 66 | def forward(self, x): 67 | res = self.conv(self.residual_group(x)) 68 | return res + x 69 | 70 | 71 | @ARCH_REGISTRY.register() 72 | class RCAN(nn.Module): 73 | """Residual Channel Attention Networks. 74 | 75 | ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks`` 76 | 77 | Reference: https://github.com/yulunzhang/RCAN 78 | 79 | Args: 80 | num_in_ch (int): Channel number of inputs. 81 | num_out_ch (int): Channel number of outputs. 82 | num_feat (int): Channel number of intermediate features. 83 | Default: 64. 84 | num_group (int): Number of ResidualGroup. Default: 10. 85 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 86 | squeeze_factor (int): Channel squeeze factor. Default: 16. 87 | upscale (int): Upsampling factor. Support 2^n and 3. 88 | Default: 4. 89 | res_scale (float): Used to scale the residual in residual block. 90 | Default: 1. 91 | img_range (float): Image range. Default: 255. 92 | rgb_mean (tuple[float]): Image mean in RGB orders. 93 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 94 | """ 95 | 96 | def __init__(self, 97 | num_in_ch, 98 | num_out_ch, 99 | num_feat=64, 100 | num_group=10, 101 | num_block=16, 102 | squeeze_factor=16, 103 | upscale=4, 104 | res_scale=1, 105 | img_range=255., 106 | rgb_mean=(0.4488, 0.4371, 0.4040)): 107 | super(RCAN, self).__init__() 108 | 109 | self.img_range = img_range 110 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 111 | 112 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 113 | self.body = make_layer( 114 | ResidualGroup, 115 | num_group, 116 | num_feat=num_feat, 117 | num_block=num_block, 118 | squeeze_factor=squeeze_factor, 119 | res_scale=res_scale) 120 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 121 | self.upsample = Upsample(upscale, num_feat) 122 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 123 | 124 | def forward(self, x): 125 | self.mean = self.mean.type_as(x) 126 | 127 | x = (x - self.mean) * self.img_range 128 | x = self.conv_first(x) 129 | res = self.conv_after_body(self.body(x)) 130 | res += x 131 | 132 | x = self.conv_last(self.upsample(res)) 133 | x = x / self.img_range + self.mean 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. 
Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 
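Example:
    >>> sizeof_fmt(1048576)
    '1.0 MB'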
136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /models/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim"):]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] #[250,] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | class SpacedDiffusion(GaussianDiffusion): 63 | """ 64 | A diffusion process which can skip steps in a base diffusion process. 65 | 66 | :param use_timesteps: a collection (sequence or set) of timesteps from the 67 | original diffusion process to retain. 68 | :param kwargs: the kwargs to create the base diffusion process. 
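Example (a minimal sketch; `betas` and the remaining GaussianDiffusion kwargs are assumed to come from the training config and are elided here):
    use_timesteps = space_timesteps(1000, "ddim100")   # 100 evenly strided steps
    diffusion = SpacedDiffusion(use_timesteps, betas=betas, ...)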
69 | """ 70 | 71 | def __init__(self, use_timesteps, **kwargs): 72 | self.use_timesteps = set(use_timesteps) 73 | self.timestep_map = [] 74 | self.original_num_steps = len(kwargs["betas"]) 75 | 76 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 77 | last_alpha_cumprod = 1.0 78 | new_betas = [] 79 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 80 | if i in self.use_timesteps: 81 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 82 | last_alpha_cumprod = alpha_cumprod 83 | self.timestep_map.append(i) 84 | kwargs["betas"] = np.array(new_betas) 85 | super().__init__(**kwargs) 86 | 87 | def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs 88 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 89 | 90 | def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs 91 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 92 | 93 | def _wrap_model(self, model): 94 | if isinstance(model, _WrappedModel): 95 | return model 96 | return _WrappedModel( 97 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 98 | ) 99 | 100 | def _scale_timesteps(self, t): 101 | # Scaling is done by the wrapped model. 102 | return t 103 | 104 | class _WrappedModel: 105 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 106 | self.model = model 107 | self.timestep_map = timestep_map 108 | self.rescale_timesteps = rescale_timesteps 109 | self.original_num_steps = original_num_steps 110 | 111 | def __call__(self, x, ts, **kwargs): 112 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 113 | new_ts = map_tensor[ts] 114 | if self.rescale_timesteps: 115 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 116 | return self.model(x, new_ts, **kwargs) 117 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 18 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 19 | 2. **meta_info_file**: Use meta information file to generate paths. \ 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. **folder**: Scan folders to generate paths. The rest. 22 | 23 | Args: 24 | opt (dict): Config for train datasets. It contains the following keys: 25 | dataroot_gt (str): Data root path for gt. 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 29 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 30 | Default: '{}'. 
31 | gt_size (int): Cropped patched size for gt patches. 32 | use_hflip (bool): Use horizontal flips. 33 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 34 | scale (bool): Scale, which will be added automatically. 35 | phase (str): 'train' or 'val'. 36 | """ 37 | 38 | def __init__(self, opt): 39 | super(PairedImageDataset, self).__init__() 40 | self.opt = opt 41 | # file client (io backend) 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | self.mean = opt['mean'] if 'mean' in opt else None 45 | self.std = opt['std'] if 'std' in opt else None 46 | 47 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 48 | if 'filename_tmpl' in opt: 49 | self.filename_tmpl = opt['filename_tmpl'] 50 | else: 51 | self.filename_tmpl = '{}' 52 | 53 | if self.io_backend_opt['type'] == 'lmdb': 54 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 55 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 56 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 57 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 58 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 59 | self.opt['meta_info_file'], self.filename_tmpl) 60 | else: 61 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 62 | 63 | def __getitem__(self, index): 64 | if self.file_client is None: 65 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 66 | 67 | scale = self.opt['scale'] 68 | 69 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 70 | # image range: [0, 1], float32. 71 | gt_path = self.paths[index]['gt_path'] 72 | img_bytes = self.file_client.get(gt_path, 'gt') 73 | img_gt = imfrombytes(img_bytes, float32=True) 74 | lq_path = self.paths[index]['lq_path'] 75 | img_bytes = self.file_client.get(lq_path, 'lq') 76 | img_lq = imfrombytes(img_bytes, float32=True) 77 | 78 | # augmentation for training 79 | if self.opt['phase'] == 'train': 80 | gt_size = self.opt['gt_size'] 81 | # random crop 82 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 83 | # flip, rotation 84 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 85 | 86 | # color space transform 87 | if 'color' in self.opt and self.opt['color'] == 'y': 88 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 89 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 90 | 91 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 92 | # TODO: It is better to update the datasets, rather than force to crop 93 | if self.opt['phase'] != 'train': 94 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /basicsr/data/realesrgan_paired_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 20 | 2. **meta_info_file**: Use meta information file to generate paths. \ 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. **folder**: Scan folders to generate paths. The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | scale (bool): Scale, which will be added automatically. 36 | phase (str): 'train' or 'val'. 37 | """ 38 | 39 | def __init__(self, opt): 40 | super(RealESRGANPairedDataset, self).__init__() 41 | self.opt = opt 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | # mean and std for normalizing the input images 45 | self.mean = opt['mean'] if 'mean' in opt else None 46 | self.std = opt['std'] if 'std' in opt else None 47 | 48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 49 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 50 | 51 | # file client (lmdb io backend) 52 | if self.io_backend_opt['type'] == 'lmdb': 53 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 54 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 55 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 56 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 57 | # disk backend with meta_info 58 | # Each line in the meta_info describes the relative path to an image 59 | with open(self.opt['meta_info']) as fin: 60 | paths = [line.strip() for line in fin] 61 | self.paths = [] 62 | for path in paths: 63 | gt_path, lq_path = path.split(', ') 64 | gt_path = os.path.join(self.gt_folder, gt_path) 65 | lq_path = os.path.join(self.lq_folder, lq_path) 66 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 67 | else: 68 | # disk backend 69 | # it will scan the whole folder to get meta info 70 | # it will be time-consuming for folders with too many files. 
It is recommended using an extra meta txt file 71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 72 | 73 | def __getitem__(self, index): 74 | if self.file_client is None: 75 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 76 | 77 | scale = self.opt['scale'] 78 | 79 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 80 | # image range: [0, 1], float32. 81 | gt_path = self.paths[index]['gt_path'] 82 | img_bytes = self.file_client.get(gt_path, 'gt') 83 | img_gt = imfrombytes(img_bytes, float32=True) 84 | lq_path = self.paths[index]['lq_path'] 85 | img_bytes = self.file_client.get(lq_path, 'lq') 86 | img_lq = imfrombytes(img_bytes, float32=True) 87 | 88 | # augmentation for training 89 | if self.opt['phase'] == 'train': 90 | gt_size = self.opt['gt_size'] 91 | # random crop 92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 93 | # flip, rotation 94 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | --------------------------------------------------------------------------------