├── .gitignore ├── Experiments ├── notes.md └── optimizer_1.ipynb ├── LICENSE ├── Notebooks └── lip-wise.ipynb ├── README.md ├── assets ├── Lip-Wise-logos-1.zip ├── Lip-Wise-logos-1 │ ├── Lip-Wise-logos.jpeg │ └── logo_info.txt ├── Lip-Wise-logos-2.zip ├── Lip-Wise-logos-3.zip ├── Lip-Wise-logos-4.zip ├── Lip-Wise-logos-5.zip ├── Lip-Wise-logos-6.zip ├── Lip-Wise-logos.zip ├── Screenshot_25-2-2024_12512_127.0.0.1.jpeg └── UI ScreenShots │ ├── process image.jpeg │ └── process video.jpeg ├── basicsr ├── __init__.py ├── archs │ ├── __init__.py │ ├── arch_util.py │ ├── basicvsr_arch.py │ ├── basicvsrpp_arch.py │ ├── codeformer_arch.py │ ├── dfdnet_arch.py │ ├── dfdnet_util.py │ ├── discriminator_arch.py │ ├── duf_arch.py │ ├── ecbsr_arch.py │ ├── edsr_arch.py │ ├── edvr_arch.py │ ├── hifacegan_arch.py │ ├── hifacegan_util.py │ ├── inception.py │ ├── rcan_arch.py │ ├── ridnet_arch.py │ ├── rrdbnet_arch.py │ ├── spynet_arch.py │ ├── srresnet_arch.py │ ├── srvgg_arch.py │ ├── stylegan2_arch.py │ ├── swinir_arch.py │ ├── tof_arch.py │ ├── vgg_arch.py │ └── vqgan_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── ffhq_dataset.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── realesrgan_dataset.py │ ├── realesrgan_paired_dataset.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ └── vimeo90k_dataset.py ├── losses │ ├── __init__.py │ ├── basic_loss.py │ ├── gan_loss.py │ └── loss_util.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── edvr_model.py │ ├── esrgan_model.py │ ├── hifacegan_model.py │ ├── lr_scheduler.py │ ├── realesrgan_model.py │ ├── realesrnet_model.py │ ├── sr_model.py │ ├── srgan_model.py │ ├── stylegan2_model.py │ ├── swinir_model.py │ ├── video_base_model.py │ ├── video_gan_model.py │ ├── video_recurrent_gan_model.py │ └── video_recurrent_model.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── train.py ├── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py └── version.py ├── helpers ├── __init__.py ├── audio.py ├── batch_processors.py ├── file_check.py ├── hparams.py ├── model_loaders.py ├── preprocess_mp.py └── vars.py ├── infer.py ├── launch.bat ├── launch.py ├── launch.sh ├── models ├── __init__.py ├── conv.py ├── syncnet.py └── wav2lip.py ├── opt.ipynb ├── requirements.txt ├── setup-colab.sh ├── setup.bat ├── setup.sh ├── styles ├── style-black.css ├── style.css ├── style2.css └── style3-cool.css └── todo.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.jpg 3 | *.mp4 4 | *.pth 5 | *.pyc 6 | *.pth 7 | __pycache__ 8 | *.h5 9 | *.avi 10 | *.wav 11 | filelists/*.txt 12 | evaluation/test_filelists/lr*.txt 13 | *.pyc 14 | *.mkv 15 | *.gif 16 | *.webm 17 | *.mp3 18 | gan 19 | *lip-wise 20 | *test* 21 | *_rgh.py 22 | *frame*.png 23 | dlib* 24 | *.onnx 25 | .vscode 26 | .exe 27 | *.png 28 | *_OLD.md 29 | *temp 30 | *weights 31 | *optimizer.ipynb 32 | infer_sl.py 33 | launch_sl.py 34 | optimizer_1.ipynb 35 | assets/gradient-ui-ux-elements.zip 36 | Experiments/models/yunet.py 37 | Experiments/models/wav2lip.py 38 | Experiments/models/syncnet.py 39 | Experiments/models/conv.py 40 | Experiments/models/__init__.py 41 | Experiments/helpers/preprocess_mp.py 42 | Experiments/helpers/model_loaders.py 43 | Experiments/helpers/hparams.py 44 | Experiments/helpers/file_check.py 45 | Experiments/helpers/batch_processors.py 46 | Experiments/helpers/audio.py 47 | Experiments/helpers/__init__.py 48 | Experiments/infer.py 49 | Experiments/__init__.py 50 | Result_Analytics/ 51 | .idx/dev.nix 52 | -------------------------------------------------------------------------------- /Experiments/notes.md: -------------------------------------------------------------------------------- 1 | # :memo: **TO-DO** List: 2 | 3 | ### PREPROCESS 4 | - [x] Add directory check in inference in the beginning. 5 | - [x] Make preprocessing optimal. 6 | - [x] Clear ram after no_face_filter. 7 | - [x] Make face coordinates reusable: 8 | - [x] Saving facial coordinates as .npy file. 9 | - [x] Alter code to also include eye coordinates. 10 | 11 | ### IMPROVING GAN UPSCALING 12 | - [x] Merge Data Pipeline with preprocessor: 13 | - [x] Remove need to recrop, realign and rewarp the image. 14 | 15 | ### IMPROVING WAV2LIP 16 | - [x] Merge all data Pipeline: 17 | - [x] Remove the need to recrop, realign, renormalizing etc. 18 | - [x] Devise a way to keep frames without face in the video. 19 | - [x] Understand Mels and working of wav2lip model. 20 | 21 | ### OPTIONAL 22 | - [ ] Gradio UI 23 | - [ ] A tab for configuration variables. 24 | - [x] A tab for Video, Audio and Output. 25 | - [x] A tab for Image, Audio and output. 26 | 27 | ### URGENT REQUIREMENTS 28 | - [ ] setup.py 29 | - [ ] create venv 30 | - [ ] install requirements inside venv 31 | - [ ] codeformer arch initialization 32 | 33 | - [ ] Documentation 34 | 35 | ### FURTHER IMPROVEMENTS 36 | - [ ] Inference without restorer 37 | - [ ] Model Improvement 38 | - [ ] Implement no_face_filter too 39 | 40 | ### FUTURE PLANS 41 | - [ ] Face and Audio wise Lipsync using face recognition. 42 | - [ ] A separate tab for TTS. 43 | 44 | ### COLAB NOTEBOOK 45 | - [ ] Optimize Inference. 46 | - [ ] Implement Checks. -------------------------------------------------------------------------------- /Notebooks/lip-wise.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/Notebooks/lip-wise.ipynb -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-1.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-1.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-1/Lip-Wise-logos.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-1/Lip-Wise-logos.jpeg -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-1/logo_info.txt: -------------------------------------------------------------------------------- 1 | 2 | Fonts used: Raleway-Heavy 3 | 4 | Colors used: F67571,020202 5 | 6 | Icon url: https://thenounproject.com/term/lip/1557625 7 | -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-2.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-3.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-3.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-4.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-4.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-5.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-5.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos-6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos-6.zip -------------------------------------------------------------------------------- /assets/Lip-Wise-logos.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Lip-Wise-logos.zip -------------------------------------------------------------------------------- /assets/Screenshot_25-2-2024_12512_127.0.0.1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/Screenshot_25-2-2024_12512_127.0.0.1.jpeg -------------------------------------------------------------------------------- /assets/UI ScreenShots/process image.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/UI ScreenShots/process image.jpeg -------------------------------------------------------------------------------- /assets/UI ScreenShots/process video.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/assets/UI ScreenShots/process video.jpeg -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/archs/dfdnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from torch.nn.utils.spectral_norm import spectral_norm 6 | 7 | 8 | class BlurFunctionBackward(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, grad_output, kernel, kernel_flip): 12 | ctx.save_for_backward(kernel, kernel_flip) 13 | grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) 14 | return grad_input 15 | 16 | @staticmethod 17 | def backward(ctx, gradgrad_output): 18 | kernel, _ = ctx.saved_tensors 19 | grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) 20 | return grad_input, None, None 21 | 22 | 23 | class BlurFunction(Function): 24 | 25 | @staticmethod 26 | def forward(ctx, x, kernel, kernel_flip): 27 | ctx.save_for_backward(kernel, kernel_flip) 28 | output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) 29 | return output 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | kernel, kernel_flip = ctx.saved_tensors 34 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) 35 | return grad_input, None, None 36 | 37 | 38 | blur = BlurFunction.apply 39 | 40 | 41 | class Blur(nn.Module): 42 | 43 | def __init__(self, channel): 44 | super().__init__() 45 | kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) 46 | kernel = kernel.view(1, 1, 3, 3) 47 | kernel = kernel / kernel.sum() 48 | kernel_flip = torch.flip(kernel, [2, 3]) 49 | 50 | self.kernel = kernel.repeat(channel, 1, 1, 1) 51 | self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) 52 | 53 | def forward(self, x): 54 | return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) 55 | 56 | 57 | def calc_mean_std(feat, eps=1e-5): 58 | """Calculate mean and std for adaptive_instance_normalization. 59 | 60 | Args: 61 | feat (Tensor): 4D tensor. 62 | eps (float): A small value added to the variance to avoid 63 | divide-by-zero. Default: 1e-5. 64 | """ 65 | size = feat.size() 66 | assert len(size) == 4, 'The input feature should be 4D tensor.' 67 | n, c = size[:2] 68 | feat_var = feat.view(n, c, -1).var(dim=2) + eps 69 | feat_std = feat_var.sqrt().view(n, c, 1, 1) 70 | feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) 71 | return feat_mean, feat_std 72 | 73 | 74 | def adaptive_instance_normalization(content_feat, style_feat): 75 | """Adaptive instance normalization. 76 | 77 | Adjust the reference features to have the similar color and illuminations 78 | as those in the degradate features. 79 | 80 | Args: 81 | content_feat (Tensor): The reference feature. 82 | style_feat (Tensor): The degradate features. 83 | """ 84 | size = content_feat.size() 85 | style_mean, style_std = calc_mean_std(style_feat) 86 | content_mean, content_std = calc_mean_std(content_feat) 87 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 88 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 89 | 90 | 91 | def AttentionBlock(in_channel): 92 | return nn.Sequential( 93 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), 94 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) 95 | 96 | 97 | def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 98 | """Conv block used in MSDilationBlock.""" 99 | 100 | return nn.Sequential( 101 | spectral_norm( 102 | nn.Conv2d( 103 | in_channels, 104 | out_channels, 105 | kernel_size=kernel_size, 106 | stride=stride, 107 | dilation=dilation, 108 | padding=((kernel_size - 1) // 2) * dilation, 109 | bias=bias)), 110 | nn.LeakyReLU(0.2), 111 | spectral_norm( 112 | nn.Conv2d( 113 | out_channels, 114 | out_channels, 115 | kernel_size=kernel_size, 116 | stride=stride, 117 | dilation=dilation, 118 | padding=((kernel_size - 1) // 2) * dilation, 119 | bias=bias)), 120 | ) 121 | 122 | 123 | class MSDilationBlock(nn.Module): 124 | """Multi-scale dilation block.""" 125 | 126 | def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): 127 | super(MSDilationBlock, self).__init__() 128 | 129 | self.conv_blocks = nn.ModuleList() 130 | for i in range(4): 131 | self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) 132 | self.conv_fusion = spectral_norm( 133 | nn.Conv2d( 134 | in_channels * 4, 135 | in_channels, 136 | kernel_size=kernel_size, 137 | stride=1, 138 | padding=(kernel_size - 1) // 2, 139 | bias=bias)) 140 | 141 | def forward(self, x): 142 | out = [] 143 | for i in range(4): 144 | out.append(self.conv_blocks[i](x)) 145 | out = torch.cat(out, 1) 146 | out = self.conv_fusion(out) + x 147 | return out 148 | 149 | 150 | class UpResBlock(nn.Module): 151 | 152 | def __init__(self, in_channel): 153 | super(UpResBlock, self).__init__() 154 | self.body = nn.Sequential( 155 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 156 | nn.LeakyReLU(0.2, True), 157 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 158 | ) 159 | 160 | def forward(self, x): 161 | out = x + self.body(x) 162 | return out 163 | -------------------------------------------------------------------------------- /basicsr/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | from torch.nn.utils import spectral_norm 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class VGGStyleDiscriminator(nn.Module): 10 | """VGG style discriminator with input size 128 x 128 or 256 x 256. 11 | 12 | It is used to train SRGAN, ESRGAN, and VideoGAN. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_feat (int): Channel number of base intermediate features.Default: 64. 17 | """ 18 | 19 | def __init__(self, num_in_ch, num_feat, input_size=128): 20 | super(VGGStyleDiscriminator, self).__init__() 21 | self.input_size = input_size 22 | assert self.input_size == 128 or self.input_size == 256, ( 23 | f'input size must be 128 or 256, but received {input_size}') 24 | 25 | self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) 26 | self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False) 27 | self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True) 28 | 29 | self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False) 30 | self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True) 31 | self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False) 32 | self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True) 33 | 34 | self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False) 35 | self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True) 36 | self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False) 37 | self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True) 38 | 39 | self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False) 40 | self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 41 | self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 42 | self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 43 | 44 | self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) 45 | self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 46 | self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 47 | self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 48 | 49 | if self.input_size == 256: 50 | self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) 51 | self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 52 | self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 53 | self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 54 | 55 | self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) 56 | self.linear2 = nn.Linear(100, 1) 57 | 58 | # activation function 59 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 60 | 61 | def forward(self, x): 62 | assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.') 63 | 64 | feat = self.lrelu(self.conv0_0(x)) 65 | feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2 66 | 67 | feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) 68 | feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4 69 | 70 | feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) 71 | feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8 72 | 73 | feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) 74 | feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16 75 | 76 | feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) 77 | feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32 78 | 79 | if self.input_size == 256: 80 | feat = self.lrelu(self.bn5_0(self.conv5_0(feat))) 81 | feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64 82 | 83 | # spatial size: (4, 4) 84 | feat = feat.view(feat.size(0), -1) 85 | feat = self.lrelu(self.linear1(feat)) 86 | out = self.linear2(feat) 87 | return out 88 | 89 | 90 | @ARCH_REGISTRY.register(suffix='basicsr') 91 | class UNetDiscriminatorSN(nn.Module): 92 | """Defines a U-Net discriminator with spectral normalization (SN) 93 | 94 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 95 | 96 | Arg: 97 | num_in_ch (int): Channel number of inputs. Default: 3. 98 | num_feat (int): Channel number of base intermediate features. Default: 64. 99 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 100 | """ 101 | 102 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 103 | super(UNetDiscriminatorSN, self).__init__() 104 | self.skip_connection = skip_connection 105 | norm = spectral_norm 106 | # the first convolution 107 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 108 | # downsample 109 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 110 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 111 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 112 | # upsample 113 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 114 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 115 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 116 | # extra convolutions 117 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 118 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 119 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 120 | 121 | def forward(self, x): 122 | # downsample 123 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 124 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 125 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 126 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 127 | 128 | # upsample 129 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 130 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 131 | 132 | if self.skip_connection: 133 | x4 = x4 + x2 134 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 135 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 136 | 137 | if self.skip_connection: 138 | x5 = x5 + x1 139 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 140 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 141 | 142 | if self.skip_connection: 143 | x6 = x6 + x0 144 | 145 | # extra convolutions 146 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 147 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 148 | out = self.conv9(out) 149 | 150 | return out 151 | -------------------------------------------------------------------------------- /basicsr/archs/edsr_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class EDSR(nn.Module): 10 | """EDSR network structure. 11 | 12 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 13 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 14 | 15 | Args: 16 | num_in_ch (int): Channel number of inputs. 17 | num_out_ch (int): Channel number of outputs. 18 | num_feat (int): Channel number of intermediate features. 19 | Default: 64. 20 | num_block (int): Block number in the trunk network. Default: 16. 21 | upscale (int): Upsampling factor. Support 2^n and 3. 22 | Default: 4. 23 | res_scale (float): Used to scale the residual in residual block. 24 | Default: 1. 25 | img_range (float): Image range. Default: 255. 26 | rgb_mean (tuple[float]): Image mean in RGB orders. 27 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 28 | """ 29 | 30 | def __init__(self, 31 | num_in_ch, 32 | num_out_ch, 33 | num_feat=64, 34 | num_block=16, 35 | upscale=4, 36 | res_scale=1, 37 | img_range=255., 38 | rgb_mean=(0.4488, 0.4371, 0.4040)): 39 | super(EDSR, self).__init__() 40 | 41 | self.img_range = img_range 42 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 43 | 44 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 45 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 46 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 47 | self.upsample = Upsample(upscale, num_feat) 48 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 49 | 50 | def forward(self, x): 51 | self.mean = self.mean.type_as(x) 52 | 53 | x = (x - self.mean) * self.img_range 54 | x = self.conv_first(x) 55 | res = self.conv_after_body(self.body(x)) 56 | res += x 57 | 58 | x = self.conv_last(self.upsample(res)) 59 | x = x / self.img_range + self.mean 60 | 61 | return x 62 | -------------------------------------------------------------------------------- /basicsr/archs/rcan_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import Upsample, make_layer 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | """Channel attention used in RCAN. 10 | 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 20 | nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | y = self.attention(x) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | """Residual Channel Attention Block (RCAB) used in RCAN. 29 | 30 | Args: 31 | num_feat (int): Channel number of intermediate features. 32 | squeeze_factor (int): Channel squeeze factor. Default: 16. 33 | res_scale (float): Scale the residual. Default: 1. 34 | """ 35 | 36 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 37 | super(RCAB, self).__init__() 38 | self.res_scale = res_scale 39 | 40 | self.rcab = nn.Sequential( 41 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), 42 | ChannelAttention(num_feat, squeeze_factor)) 43 | 44 | def forward(self, x): 45 | res = self.rcab(x) * self.res_scale 46 | return res + x 47 | 48 | 49 | class ResidualGroup(nn.Module): 50 | """Residual Group of RCAB. 51 | 52 | Args: 53 | num_feat (int): Channel number of intermediate features. 54 | num_block (int): Block number in the body network. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 57 | """ 58 | 59 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 60 | super(ResidualGroup, self).__init__() 61 | 62 | self.residual_group = make_layer( 63 | RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) 64 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 65 | 66 | def forward(self, x): 67 | res = self.conv(self.residual_group(x)) 68 | return res + x 69 | 70 | 71 | @ARCH_REGISTRY.register() 72 | class RCAN(nn.Module): 73 | """Residual Channel Attention Networks. 74 | 75 | Paper: Image Super-Resolution Using Very Deep Residual Channel Attention 76 | Networks 77 | Ref git repo: https://github.com/yulunzhang/RCAN. 78 | 79 | Args: 80 | num_in_ch (int): Channel number of inputs. 81 | num_out_ch (int): Channel number of outputs. 82 | num_feat (int): Channel number of intermediate features. 83 | Default: 64. 84 | num_group (int): Number of ResidualGroup. Default: 10. 85 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 86 | squeeze_factor (int): Channel squeeze factor. Default: 16. 87 | upscale (int): Upsampling factor. Support 2^n and 3. 88 | Default: 4. 89 | res_scale (float): Used to scale the residual in residual block. 90 | Default: 1. 91 | img_range (float): Image range. Default: 255. 92 | rgb_mean (tuple[float]): Image mean in RGB orders. 93 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 94 | """ 95 | 96 | def __init__(self, 97 | num_in_ch, 98 | num_out_ch, 99 | num_feat=64, 100 | num_group=10, 101 | num_block=16, 102 | squeeze_factor=16, 103 | upscale=4, 104 | res_scale=1, 105 | img_range=255., 106 | rgb_mean=(0.4488, 0.4371, 0.4040)): 107 | super(RCAN, self).__init__() 108 | 109 | self.img_range = img_range 110 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 111 | 112 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 113 | self.body = make_layer( 114 | ResidualGroup, 115 | num_group, 116 | num_feat=num_feat, 117 | num_block=num_block, 118 | squeeze_factor=squeeze_factor, 119 | res_scale=res_scale) 120 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 121 | self.upsample = Upsample(upscale, num_feat) 122 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 123 | 124 | def forward(self, x): 125 | self.mean = self.mean.type_as(x) 126 | 127 | x = (x - self.mean) * self.img_range 128 | x = self.conv_first(x) 129 | res = self.conv_after_body(self.body(x)) 130 | res += x 131 | 132 | x = self.conv_last(self.upsample(res)) 133 | x = x / self.img_range + self.mean 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /basicsr/archs/ridnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, make_layer 6 | 7 | 8 | class MeanShift(nn.Conv2d): 9 | """ Data normalization with mean and std. 10 | 11 | Args: 12 | rgb_range (int): Maximum value of RGB. 13 | rgb_mean (list[float]): Mean for RGB channels. 14 | rgb_std (list[float]): Std for RGB channels. 15 | sign (int): For subtraction, sign is -1, for addition, sign is 1. 16 | Default: -1. 17 | requires_grad (bool): Whether to update the self.weight and self.bias. 18 | Default: True. 19 | """ 20 | 21 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True): 22 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 23 | std = torch.Tensor(rgb_std) 24 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 25 | self.weight.data.div_(std.view(3, 1, 1, 1)) 26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 27 | self.bias.data.div_(std) 28 | self.requires_grad = requires_grad 29 | 30 | 31 | class EResidualBlockNoBN(nn.Module): 32 | """Enhanced Residual block without BN. 33 | 34 | There are three convolution layers in residual branch. 35 | 36 | It has a style of: 37 | ---Conv-ReLU-Conv-ReLU-Conv-+-ReLU- 38 | |__________________________| 39 | """ 40 | 41 | def __init__(self, in_channels, out_channels): 42 | super(EResidualBlockNoBN, self).__init__() 43 | 44 | self.body = nn.Sequential( 45 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 50 | ) 51 | self.relu = nn.ReLU(inplace=True) 52 | 53 | def forward(self, x): 54 | out = self.body(x) 55 | out = self.relu(out + x) 56 | return out 57 | 58 | 59 | class MergeRun(nn.Module): 60 | """ Merge-and-run unit. 61 | 62 | This unit contains two branches with different dilated convolutions, 63 | followed by a convolution to process the concatenated features. 64 | 65 | Paper: Real Image Denoising with Feature Attention 66 | Ref git repo: https://github.com/saeed-anwar/RIDNet 67 | """ 68 | 69 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 70 | super(MergeRun, self).__init__() 71 | 72 | self.dilation1 = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True), 74 | nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True)) 75 | self.dilation2 = nn.Sequential( 76 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True), 77 | nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True)) 78 | 79 | self.aggregation = nn.Sequential( 80 | nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True)) 81 | 82 | def forward(self, x): 83 | dilation1 = self.dilation1(x) 84 | dilation2 = self.dilation2(x) 85 | out = torch.cat([dilation1, dilation2], dim=1) 86 | out = self.aggregation(out) 87 | out = out + x 88 | return out 89 | 90 | 91 | class ChannelAttention(nn.Module): 92 | """Channel attention. 93 | 94 | Args: 95 | num_feat (int): Channel number of intermediate features. 96 | squeeze_factor (int): Channel squeeze factor. Default: 97 | """ 98 | 99 | def __init__(self, mid_channels, squeeze_factor=16): 100 | super(ChannelAttention, self).__init__() 101 | self.attention = nn.Sequential( 102 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0), 103 | nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid()) 104 | 105 | def forward(self, x): 106 | y = self.attention(x) 107 | return x * y 108 | 109 | 110 | class EAM(nn.Module): 111 | """Enhancement attention modules (EAM) in RIDNet. 112 | 113 | This module contains a merge-and-run unit, a residual block, 114 | an enhanced residual block and a feature attention unit. 115 | 116 | Attributes: 117 | merge: The merge-and-run unit. 118 | block1: The residual block. 119 | block2: The enhanced residual block. 120 | ca: The feature/channel attention unit. 121 | """ 122 | 123 | def __init__(self, in_channels, mid_channels, out_channels): 124 | super(EAM, self).__init__() 125 | 126 | self.merge = MergeRun(in_channels, mid_channels) 127 | self.block1 = ResidualBlockNoBN(mid_channels) 128 | self.block2 = EResidualBlockNoBN(mid_channels, out_channels) 129 | self.ca = ChannelAttention(out_channels) 130 | # The residual block in the paper contains a relu after addition. 131 | self.relu = nn.ReLU(inplace=True) 132 | 133 | def forward(self, x): 134 | out = self.merge(x) 135 | out = self.relu(self.block1(out)) 136 | out = self.block2(out) 137 | out = self.ca(out) 138 | return out 139 | 140 | 141 | @ARCH_REGISTRY.register() 142 | class RIDNet(nn.Module): 143 | """RIDNet: Real Image Denoising with Feature Attention. 144 | 145 | Ref git repo: https://github.com/saeed-anwar/RIDNet 146 | 147 | Args: 148 | in_channels (int): Channel number of inputs. 149 | mid_channels (int): Channel number of EAM modules. 150 | Default: 64. 151 | out_channels (int): Channel number of outputs. 152 | num_block (int): Number of EAM. Default: 4. 153 | img_range (float): Image range. Default: 255. 154 | rgb_mean (tuple[float]): Image mean in RGB orders. 155 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 156 | """ 157 | 158 | def __init__(self, 159 | in_channels, 160 | mid_channels, 161 | out_channels, 162 | num_block=4, 163 | img_range=255., 164 | rgb_mean=(0.4488, 0.4371, 0.4040), 165 | rgb_std=(1.0, 1.0, 1.0)): 166 | super(RIDNet, self).__init__() 167 | 168 | self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std) 169 | self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1) 170 | 171 | self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) 172 | self.body = make_layer( 173 | EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels) 174 | self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1) 175 | 176 | self.relu = nn.ReLU(inplace=True) 177 | 178 | def forward(self, x): 179 | res = self.sub_mean(x) 180 | res = self.tail(self.body(self.relu(self.head(res)))) 181 | res = self.add_mean(res) 182 | 183 | out = x + res 184 | return out 185 | -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Emperically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Emperically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out -------------------------------------------------------------------------------- /basicsr/archs/spynet_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | from .arch_util import flow_warp 8 | 9 | 10 | class BasicModule(nn.Module): 11 | """Basic Module for SpyNet. 12 | """ 13 | 14 | def __init__(self): 15 | super(BasicModule, self).__init__() 16 | 17 | self.basic_module = nn.Sequential( 18 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 20 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 21 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 22 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 23 | 24 | def forward(self, tensor_input): 25 | return self.basic_module(tensor_input) 26 | 27 | 28 | @ARCH_REGISTRY.register() 29 | class SpyNet(nn.Module): 30 | """SpyNet architecture. 31 | 32 | Args: 33 | load_path (str): path for pretrained SpyNet. Default: None. 34 | """ 35 | 36 | def __init__(self, load_path=None): 37 | super(SpyNet, self).__init__() 38 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) 39 | if load_path: 40 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 41 | 42 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 43 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 44 | 45 | def preprocess(self, tensor_input): 46 | tensor_output = (tensor_input - self.mean) / self.std 47 | return tensor_output 48 | 49 | def process(self, ref, supp): 50 | flow = [] 51 | 52 | ref = [self.preprocess(ref)] 53 | supp = [self.preprocess(supp)] 54 | 55 | for level in range(5): 56 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 57 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 58 | 59 | flow = ref[0].new_zeros( 60 | [ref[0].size(0), 2, 61 | int(math.floor(ref[0].size(2) / 2.0)), 62 | int(math.floor(ref[0].size(3) / 2.0))]) 63 | 64 | for level in range(len(ref)): 65 | upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 66 | 67 | if upsampled_flow.size(2) != ref[level].size(2): 68 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') 69 | if upsampled_flow.size(3) != ref[level].size(3): 70 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') 71 | 72 | flow = self.basic_module[level](torch.cat([ 73 | ref[level], 74 | flow_warp( 75 | supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), 76 | upsampled_flow 77 | ], 1)) + upsampled_flow 78 | 79 | return flow 80 | 81 | def forward(self, ref, supp): 82 | assert ref.size() == supp.size() 83 | 84 | h, w = ref.size(2), ref.size(3) 85 | w_floor = math.floor(math.ceil(w / 32.0) * 32.0) 86 | h_floor = math.floor(math.ceil(h / 32.0) * 32.0) 87 | 88 | ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 89 | supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 90 | 91 | flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) 92 | 93 | flow[:, 0, :, :] *= float(w) / float(w_floor) 94 | flow[:, 1, :, :] *= float(h) / float(h_floor) 95 | 96 | return flow 97 | -------------------------------------------------------------------------------- /basicsr/archs/srresnet_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class MSRResNet(nn.Module): 10 | """Modified SRResNet. 11 | 12 | A compacted version modified from SRResNet in 13 | "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 14 | It uses residual blocks without BN, similar to EDSR. 15 | Currently, it supports x2, x3 and x4 upsampling scale factor. 16 | 17 | Args: 18 | num_in_ch (int): Channel number of inputs. Default: 3. 19 | num_out_ch (int): Channel number of outputs. Default: 3. 20 | num_feat (int): Channel number of intermediate features. Default: 64. 21 | num_block (int): Block number in the body network. Default: 16. 22 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 23 | """ 24 | 25 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 26 | super(MSRResNet, self).__init__() 27 | self.upscale = upscale 28 | 29 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 30 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) 31 | 32 | # upsampling 33 | if self.upscale in [2, 3]: 34 | self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) 35 | self.pixel_shuffle = nn.PixelShuffle(self.upscale) 36 | elif self.upscale == 4: 37 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 38 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 39 | self.pixel_shuffle = nn.PixelShuffle(2) 40 | 41 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 42 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 43 | 44 | # activation function 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | 47 | # initialization 48 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 49 | if self.upscale == 4: 50 | default_init_weights(self.upconv2, 0.1) 51 | 52 | def forward(self, x): 53 | feat = self.lrelu(self.conv_first(x)) 54 | out = self.body(feat) 55 | 56 | if self.upscale == 4: 57 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 58 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 59 | elif self.upscale in [2, 3]: 60 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 61 | 62 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 63 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 64 | out += base 65 | return out 66 | -------------------------------------------------------------------------------- /basicsr/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | 7 | @ARCH_REGISTRY.register(suffix='basicsr') 8 | class SRVGGNetCompact(nn.Module): 9 | """A compact VGG-style network structure for super-resolution. 10 | 11 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 12 | conducted on the HR feature space. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_out_ch (int): Channel number of outputs. Default: 3. 17 | num_feat (int): Channel number of intermediate features. Default: 64. 18 | num_conv (int): Number of convolution layers in the body network. Default: 16. 19 | upscale (int): Upsampling factor. Default: 4. 20 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 21 | """ 22 | 23 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 24 | super(SRVGGNetCompact, self).__init__() 25 | self.num_in_ch = num_in_ch 26 | self.num_out_ch = num_out_ch 27 | self.num_feat = num_feat 28 | self.num_conv = num_conv 29 | self.upscale = upscale 30 | self.act_type = act_type 31 | 32 | self.body = nn.ModuleList() 33 | # the first conv 34 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 35 | # the first activation 36 | if act_type == 'relu': 37 | activation = nn.ReLU(inplace=True) 38 | elif act_type == 'prelu': 39 | activation = nn.PReLU(num_parameters=num_feat) 40 | elif act_type == 'leakyrelu': 41 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 42 | self.body.append(activation) 43 | 44 | # the body structure 45 | for _ in range(num_conv): 46 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 47 | # activation 48 | if act_type == 'relu': 49 | activation = nn.ReLU(inplace=True) 50 | elif act_type == 'prelu': 51 | activation = nn.PReLU(num_parameters=num_feat) 52 | elif act_type == 'leakyrelu': 53 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 54 | self.body.append(activation) 55 | 56 | # the last conv 57 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 58 | # upsample 59 | self.upsampler = nn.PixelShuffle(upscale) 60 | 61 | def forward(self, x): 62 | out = x 63 | for i in range(0, len(self.body)): 64 | out = self.body[i](out) 65 | 66 | out = self.upsampler(out) 67 | # add the nearest upsampled image, so that the network learns the residual 68 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 69 | out += base 70 | return out 71 | -------------------------------------------------------------------------------- /basicsr/archs/tof_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import flow_warp 7 | 8 | 9 | class BasicModule(nn.Module): 10 | """Basic module of SPyNet. 11 | 12 | Note that unlike the architecture in spynet_arch.py, the basic module 13 | here contains batch normalization. 14 | """ 15 | 16 | def __init__(self): 17 | super(BasicModule, self).__init__() 18 | self.basic_module = nn.Sequential( 19 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), 20 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 21 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False), 22 | nn.BatchNorm2d(64), nn.ReLU(inplace=True), 23 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False), 24 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 25 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False), 26 | nn.BatchNorm2d(16), nn.ReLU(inplace=True), 27 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 28 | 29 | def forward(self, tensor_input): 30 | """ 31 | Args: 32 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w). 33 | 8 channels contain: 34 | [reference image (3), neighbor image (3), initial flow (2)]. 35 | 36 | Returns: 37 | Tensor: Estimated flow with shape (b, 2, h, w) 38 | """ 39 | return self.basic_module(tensor_input) 40 | 41 | 42 | class SPyNetTOF(nn.Module): 43 | """SPyNet architecture for TOF. 44 | 45 | Note that this implementation is specifically for TOFlow. Please use 46 | spynet_arch.py for general use. They differ in the following aspects: 47 | 1. The basic modules here contain BatchNorm. 48 | 2. Normalization and denormalization are not done here, as 49 | they are done in TOFlow. 50 | Paper: 51 | Optical Flow Estimation using a Spatial Pyramid Network 52 | Code reference: 53 | https://github.com/Coldog2333/pytoflow 54 | 55 | Args: 56 | load_path (str): Path for pretrained SPyNet. Default: None. 57 | """ 58 | 59 | def __init__(self, load_path=None): 60 | super(SPyNetTOF, self).__init__() 61 | 62 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) 63 | if load_path: 64 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 65 | 66 | def forward(self, ref, supp): 67 | """ 68 | Args: 69 | ref (Tensor): Reference image with shape of (b, 3, h, w). 70 | supp: The supporting image to be warped: (b, 3, h, w). 71 | 72 | Returns: 73 | Tensor: Estimated optical flow: (b, 2, h, w). 74 | """ 75 | num_batches, _, h, w = ref.size() 76 | ref = [ref] 77 | supp = [supp] 78 | 79 | # generate downsampled frames 80 | for _ in range(3): 81 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 82 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 83 | 84 | # flow computation 85 | flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16) 86 | for i in range(4): 87 | flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 88 | flow = flow_up + self.basic_module[i]( 89 | torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) 90 | return flow 91 | 92 | 93 | @ARCH_REGISTRY.register() 94 | class TOFlow(nn.Module): 95 | """PyTorch implementation of TOFlow. 96 | 97 | In TOFlow, the LR frames are pre-upsampled and have the same size with 98 | the GT frames. 99 | Paper: 100 | Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 101 | Code reference: 102 | 1. https://github.com/anchen1011/toflow 103 | 2. https://github.com/Coldog2333/pytoflow 104 | 105 | Args: 106 | adapt_official_weights (bool): Whether to adapt the weights translated 107 | from the official implementation. Set to false if you want to 108 | train from scratch. Default: False 109 | """ 110 | 111 | def __init__(self, adapt_official_weights=False): 112 | super(TOFlow, self).__init__() 113 | self.adapt_official_weights = adapt_official_weights 114 | self.ref_idx = 0 if adapt_official_weights else 3 115 | 116 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 117 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 118 | 119 | # flow estimation module 120 | self.spynet = SPyNetTOF() 121 | 122 | # reconstruction module 123 | self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4) 124 | self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4) 125 | self.conv_3 = nn.Conv2d(64, 64, 1) 126 | self.conv_4 = nn.Conv2d(64, 3, 1) 127 | 128 | # activation function 129 | self.relu = nn.ReLU(inplace=True) 130 | 131 | def normalize(self, img): 132 | return (img - self.mean) / self.std 133 | 134 | def denormalize(self, img): 135 | return img * self.std + self.mean 136 | 137 | def forward(self, lrs): 138 | """ 139 | Args: 140 | lrs: Input lr frames: (b, 7, 3, h, w). 141 | 142 | Returns: 143 | Tensor: SR frame: (b, 3, h, w). 144 | """ 145 | # In the official implementation, the 0-th frame is the reference frame 146 | if self.adapt_official_weights: 147 | lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] 148 | 149 | num_batches, num_lrs, _, h, w = lrs.size() 150 | 151 | lrs = self.normalize(lrs.view(-1, 3, h, w)) 152 | lrs = lrs.view(num_batches, num_lrs, 3, h, w) 153 | 154 | lr_ref = lrs[:, self.ref_idx, :, :, :] 155 | lr_aligned = [] 156 | for i in range(7): # 7 frames 157 | if i == self.ref_idx: 158 | lr_aligned.append(lr_ref) 159 | else: 160 | lr_supp = lrs[:, i, :, :, :] 161 | flow = self.spynet(lr_ref, lr_supp) 162 | lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1))) 163 | 164 | # reconstruction 165 | hr = torch.stack(lr_aligned, dim=1) 166 | hr = hr.view(num_batches, -1, h, w) 167 | hr = self.relu(self.conv_1(hr)) 168 | hr = self.relu(self.conv_2(hr)) 169 | hr = self.relu(self.conv_3(hr)) 170 | hr = self.conv_4(hr) + lr_ref 171 | 172 | return self.denormalize(hr) 173 | -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether use normalization 59 | in the input feature and the type of vgg network. Note that the pretrained 60 | path must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | Tensor: Forward results. 149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | output = {} 155 | 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 1. 'lmdb': Use lmdb files. 18 | If opt['io_backend'] == lmdb. 19 | 2. 'meta_info_file': Use meta information file to generate paths. 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. 'folder': Scan folders to generate paths. 22 | The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info_file (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | 36 | scale (bool): Scale, which will be added automatically. 37 | phase (str): 'train' or 'val'. 38 | """ 39 | 40 | def __init__(self, opt): 41 | super(PairedImageDataset, self).__init__() 42 | self.opt = opt 43 | # file client (io backend) 44 | self.file_client = None 45 | self.io_backend_opt = opt['io_backend'] 46 | self.mean = opt['mean'] if 'mean' in opt else None 47 | self.std = opt['std'] if 'std' in opt else None 48 | 49 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 50 | if 'filename_tmpl' in opt: 51 | self.filename_tmpl = opt['filename_tmpl'] 52 | else: 53 | self.filename_tmpl = '{}' 54 | 55 | if self.io_backend_opt['type'] == 'lmdb': 56 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 57 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 58 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 59 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 60 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 61 | self.opt['meta_info_file'], self.filename_tmpl) 62 | else: 63 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 64 | 65 | def __getitem__(self, index): 66 | if self.file_client is None: 67 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 68 | 69 | scale = self.opt['scale'] 70 | 71 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 72 | # image range: [0, 1], float32. 73 | gt_path = self.paths[index]['gt_path'] 74 | img_bytes = self.file_client.get(gt_path, 'gt') 75 | img_gt = imfrombytes(img_bytes, float32=True) 76 | lq_path = self.paths[index]['lq_path'] 77 | img_bytes = self.file_client.get(lq_path, 'lq') 78 | img_lq = imfrombytes(img_bytes, float32=True) 79 | 80 | # augmentation for training 81 | if self.opt['phase'] == 'train': 82 | gt_size = self.opt['gt_size'] 83 | # random crop 84 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 85 | # flip, rotation 86 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 87 | 88 | # color space transform 89 | if 'color' in self.opt and self.opt['color'] == 'y': 90 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 91 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 92 | 93 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 94 | # TODO: It is better to update the datasets, rather than force to crop 95 | if self.opt['phase'] != 'train': 96 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 97 | 98 | # BGR to RGB, HWC to CHW, numpy to tensor 99 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 100 | # normalize 101 | if self.mean is not None or self.std is not None: 102 | normalize(img_lq, self.mean, self.std, inplace=True) 103 | normalize(img_gt, self.mean, self.std, inplace=True) 104 | 105 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 1. 'lmdb': Use lmdb files. 19 | If opt['io_backend'] == lmdb. 20 | 2. 'meta_info': Use meta information file to generate paths. 21 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None. 22 | 3. 'folder': Scan folders to generate paths. 23 | The rest. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | dataroot_lq (str): Data root path for lq. 29 | meta_info (str): Path for meta information file. 30 | io_backend (dict): IO backend type and other kwarg. 31 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 32 | Default: '{}'. 33 | gt_size (int): Cropped patched size for gt patches. 34 | use_hflip (bool): Use horizontal flips. 35 | use_rot (bool): Use rotation (use vertical flip and transposing h 36 | and w for implementation). 37 | 38 | scale (bool): Scale, which will be added automatically. 39 | phase (str): 'train' or 'val'. 40 | """ 41 | 42 | def __init__(self, opt): 43 | super(RealESRGANPairedDataset, self).__init__() 44 | self.opt = opt 45 | self.file_client = None 46 | self.io_backend_opt = opt['io_backend'] 47 | # mean and std for normalizing the input images 48 | self.mean = opt['mean'] if 'mean' in opt else None 49 | self.std = opt['std'] if 'std' in opt else None 50 | 51 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 52 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 53 | 54 | # file client (lmdb io backend) 55 | if self.io_backend_opt['type'] == 'lmdb': 56 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 57 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 58 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 59 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 60 | # disk backend with meta_info 61 | # Each line in the meta_info describes the relative path to an image 62 | with open(self.opt['meta_info']) as fin: 63 | paths = [line.strip() for line in fin] 64 | self.paths = [] 65 | for path in paths: 66 | gt_path, lq_path = path.split(', ') 67 | gt_path = os.path.join(self.gt_folder, gt_path) 68 | lq_path = os.path.join(self.lq_folder, lq_path) 69 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 70 | else: 71 | # disk backend 72 | # it will scan the whole folder to get meta info 73 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 74 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 75 | 76 | def __getitem__(self, index): 77 | if self.file_client is None: 78 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 79 | 80 | scale = self.opt['scale'] 81 | 82 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 83 | # image range: [0, 1], float32. 84 | gt_path = self.paths[index]['gt_path'] 85 | img_bytes = self.file_client.get(gt_path, 'gt') 86 | img_gt = imfrombytes(img_bytes, float32=True) 87 | lq_path = self.paths[index]['lq_path'] 88 | img_bytes = self.file_client.get(lq_path, 'lq') 89 | img_lq = imfrombytes(img_bytes, float32=True) 90 | 91 | # augmentation for training 92 | if self.opt['phase'] == 'train': 93 | gt_size = self.opt['gt_size'] 94 | # random crop 95 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 96 | # flip, rotation 97 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 98 | 99 | # BGR to RGB, HWC to CHW, numpy to tensor 100 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 101 | # normalize 102 | if self.mean is not None or self.std is not None: 103 | normalize(img_lq, self.mean, self.std, inplace=True) 104 | normalize(img_gt, self.mean, self.std, inplace=True) 105 | 106 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 107 | 108 | def __len__(self): 109 | return len(self.paths) 110 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[:h - h_remainder, :w - w_remainder, ...] 21 | else: 22 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 27 | """Paired random crop. Support Numpy array and Tensor inputs. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. Default: None. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | # determine input type: Numpy array or Tensor 53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 54 | 55 | if input_type == 'Tensor': 56 | h_lq, w_lq = img_lqs[0].size()[-2:] 57 | h_gt, w_gt = img_gts[0].size()[-2:] 58 | else: 59 | h_lq, w_lq = img_lqs[0].shape[0:2] 60 | h_gt, w_gt = img_gts[0].shape[0:2] 61 | lq_patch_size = gt_patch_size // scale 62 | 63 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 65 | f'multiplication of LQ ({h_lq}, {w_lq}).') 66 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 68 | f'({lq_patch_size}, {lq_patch_size}). ' 69 | f'Please remove {gt_path}.') 70 | 71 | # randomly choose top and left coordinates for lq patch 72 | top = random.randint(0, h_lq - lq_patch_size) 73 | left = random.randint(0, w_lq - lq_patch_size) 74 | 75 | # crop lq patch 76 | if input_type == 'Tensor': 77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 78 | else: 79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 80 | 81 | # crop corresponding gt patch 82 | top_gt, left_gt = int(top * scale), int(left * scale) 83 | if input_type == 'Tensor': 84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 85 | else: 86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 87 | if len(img_gts) == 1: 88 | img_gts = img_gts[0] 89 | if len(img_lqs) == 1: 90 | img_lqs = img_lqs[0] 91 | return img_gts, img_lqs 92 | 93 | 94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 96 | 97 | We use vertical flip and transpose for rotation implementation. 98 | All the images in the list use the same augmentation. 99 | 100 | Args: 101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 102 | is an ndarray, it will be transformed to a list. 103 | hflip (bool): Horizontal flip. Default: True. 104 | rotation (bool): Ratotation. Default: True. 105 | flows (list[ndarray]: Flows to be augmented. If the input is an 106 | ndarray, it will be transformed to a list. 107 | Dimension is (h, w, 2). Default: None. 108 | return_status (bool): Return the status of flip and rotation. 109 | Default: False. 110 | 111 | Returns: 112 | list[ndarray] | ndarray: Augmented images and flows. If returned 113 | results only have one element, just return ndarray. 114 | 115 | """ 116 | hflip = hflip and random.random() < 0.5 117 | vflip = rotation and random.random() < 0.5 118 | rot90 = rotation and random.random() < 0.5 119 | 120 | def _augment(img): 121 | if hflip: # horizontal 122 | cv2.flip(img, 1, img) 123 | if vflip: # vertical 124 | cv2.flip(img, 0, img) 125 | if rot90: 126 | img = img.transpose(1, 0, 2) 127 | return img 128 | 129 | def _augment_flow(flow): 130 | if hflip: # horizontal 131 | cv2.flip(flow, 1, flow) 132 | flow[:, :, 0] *= -1 133 | if vflip: # vertical 134 | cv2.flip(flow, 0, flow) 135 | flow[:, :, 1] *= -1 136 | if rot90: 137 | flow = flow.transpose(1, 0, 2) 138 | flow = flow[:, :, [1, 0]] 139 | return flow 140 | 141 | if not isinstance(imgs, list): 142 | imgs = [imgs] 143 | imgs = [_augment(img) for img in imgs] 144 | if len(imgs) == 1: 145 | imgs = imgs[0] 146 | 147 | if flows is not None: 148 | if not isinstance(flows, list): 149 | flows = [flows] 150 | flows = [_augment_flow(flow) for flow in flows] 151 | if len(flows) == 1: 152 | flows = flows[0] 153 | return imgs, flows 154 | else: 155 | if return_status: 156 | return imgs, (hflip, vflip, rot90) 157 | else: 158 | return imgs 159 | 160 | 161 | def img_rotate(img, angle, center=None, scale=1.0): 162 | """Rotate image. 163 | 164 | Args: 165 | img (ndarray): Image to be rotated. 166 | angle (float): Rotation angle in degrees. Positive values mean 167 | counter-clockwise rotation. 168 | center (tuple[int]): Rotation center. If the center is None, 169 | initialize it as the center of the image. Default: None. 170 | scale (float): Isotropic scale factor. Default: 1.0. 171 | """ 172 | (h, w) = img.shape[:2] 173 | 174 | if center is None: 175 | center = (w // 2, h // 2) 176 | 177 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 178 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 179 | return rotated_img 180 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 130 | 131 | Returns: 132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 133 | (calculated based on both local and global observations). 134 | """ 135 | 136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) 137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 138 | 139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) 140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) 141 | overall_weight = patch_level_weight * pixel_level_weight 142 | 143 | overall_weight[residual_sr < residual_ema] = 0 144 | 145 | return overall_weight 146 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Model type. 16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 54 | and X_2 ~ N(mu_2, C_2) is 55 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 56 | Stable version by Dougal J. Sutherland. 57 | 58 | Args: 59 | mu1 (np.array): The sample mean over activations. 60 | sigma1 (np.array): The covariance matrix over activations for 61 | generated samples. 62 | mu2 (np.array): The sample mean over activations, precalculated on an 63 | representative data set. 64 | sigma2 (np.array): The covariance matrix over activations, 65 | precalculated on an representative data set. 66 | 67 | Returns: 68 | float: The Frechet Distance. 69 | """ 70 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 71 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 72 | 73 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 74 | 75 | # Product might be almost singular 76 | if not np.isfinite(cov_sqrt).all(): 77 | print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') 78 | offset = np.eye(sigma1.shape[0]) * eps 79 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 80 | 81 | # Numerical error might give slight imaginary component 82 | if np.iscomplexobj(cov_sqrt): 83 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 84 | m = np.max(np.abs(cov_sqrt.imag)) 85 | raise ValueError(f'Imaginary component {m}') 86 | cov_sqrt = cov_sqrt.real 87 | 88 | mean_diff = mu1 - mu2 89 | mean_norm = mean_diff @ mean_diff 90 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 91 | fid = mean_norm + trace 92 | 93 | return fid 94 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /basicsr/models/edvr_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils import get_root_logger 2 | from basicsr.utils.registry import MODEL_REGISTRY 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class EDVRModel(VideoBaseModel): 8 | """EDVR Model. 9 | 10 | Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501 11 | """ 12 | 13 | def __init__(self, opt): 14 | super(EDVRModel, self).__init__(opt) 15 | if self.is_train: 16 | self.train_tsa_iter = opt['train'].get('tsa_iter') 17 | 18 | def setup_optimizers(self): 19 | train_opt = self.opt['train'] 20 | dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) 21 | logger = get_root_logger() 22 | logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') 23 | if dcn_lr_mul == 1: 24 | optim_params = self.net_g.parameters() 25 | else: # separate dcn params and normal params for different lr 26 | normal_params = [] 27 | dcn_params = [] 28 | for name, param in self.net_g.named_parameters(): 29 | if 'dcn' in name: 30 | dcn_params.append(param) 31 | else: 32 | normal_params.append(param) 33 | optim_params = [ 34 | { # add normal params first 35 | 'params': normal_params, 36 | 'lr': train_opt['optim_g']['lr'] 37 | }, 38 | { 39 | 'params': dcn_params, 40 | 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul 41 | }, 42 | ] 43 | 44 | optim_type = train_opt['optim_g'].pop('type') 45 | self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 46 | self.optimizers.append(self.optimizer_g) 47 | 48 | def optimize_parameters(self, current_iter): 49 | if self.train_tsa_iter: 50 | if current_iter == 1: 51 | logger = get_root_logger() 52 | logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') 53 | for name, param in self.net_g.named_parameters(): 54 | if 'fusion' not in name: 55 | param.requires_grad = False 56 | elif current_iter == self.train_tsa_iter: 57 | logger = get_root_logger() 58 | logger.warning('Train all the parameters.') 59 | for param in self.net_g.parameters(): 60 | param.requires_grad = True 61 | 62 | super(EDVRModel, self).optimize_parameters(current_iter) 63 | -------------------------------------------------------------------------------- /basicsr/models/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .srgan_model import SRGANModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ESRGANModel(SRGANModel): 10 | """ESRGAN model for single image super-resolution.""" 11 | 12 | def optimize_parameters(self, current_iter): 13 | # optimize net_g 14 | for p in self.net_d.parameters(): 15 | p.requires_grad = False 16 | 17 | self.optimizer_g.zero_grad() 18 | self.output = self.net_g(self.lq) 19 | 20 | l_g_total = 0 21 | loss_dict = OrderedDict() 22 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 23 | # pixel loss 24 | if self.cri_pix: 25 | l_g_pix = self.cri_pix(self.output, self.gt) 26 | l_g_total += l_g_pix 27 | loss_dict['l_g_pix'] = l_g_pix 28 | # perceptual loss 29 | if self.cri_perceptual: 30 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 31 | if l_g_percep is not None: 32 | l_g_total += l_g_percep 33 | loss_dict['l_g_percep'] = l_g_percep 34 | if l_g_style is not None: 35 | l_g_total += l_g_style 36 | loss_dict['l_g_style'] = l_g_style 37 | # gan loss (relativistic gan) 38 | real_d_pred = self.net_d(self.gt).detach() 39 | fake_g_pred = self.net_d(self.output) 40 | l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) 41 | l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) 42 | l_g_gan = (l_g_real + l_g_fake) / 2 43 | 44 | l_g_total += l_g_gan 45 | loss_dict['l_g_gan'] = l_g_gan 46 | 47 | l_g_total.backward() 48 | self.optimizer_g.step() 49 | 50 | # optimize net_d 51 | for p in self.net_d.parameters(): 52 | p.requires_grad = True 53 | 54 | self.optimizer_d.zero_grad() 55 | # gan loss (relativistic gan) 56 | 57 | # In order to avoid the error in distributed training: 58 | # "Error detected in CudnnBatchNormBackward: RuntimeError: one of 59 | # the variables needed for gradient computation has been modified by 60 | # an inplace operation", 61 | # we separate the backwards for real and fake, and also detach the 62 | # tensor for calculating mean. 63 | 64 | # real 65 | fake_d_pred = self.net_d(self.output).detach() 66 | real_d_pred = self.net_d(self.gt) 67 | l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 68 | l_d_real.backward() 69 | # fake 70 | fake_d_pred = self.net_d(self.output.detach()) 71 | l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 72 | l_d_fake.backward() 73 | self.optimizer_d.step() 74 | 75 | loss_dict['l_d_real'] = l_d_real 76 | loss_dict['l_d_fake'] = l_d_fake 77 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 78 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 79 | 80 | self.log_dict = self.reduce_loss_dict(loss_dict) 81 | 82 | if self.ema_decay > 0: 83 | self.model_ema(decay=self.ema_decay) 84 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/models/srgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.archs import build_network 5 | from basicsr.losses import build_loss 6 | from basicsr.utils import get_root_logger 7 | from basicsr.utils.registry import MODEL_REGISTRY 8 | from .sr_model import SRModel 9 | 10 | 11 | @MODEL_REGISTRY.register() 12 | class SRGANModel(SRModel): 13 | """SRGAN model for single image super-resolution.""" 14 | 15 | def init_training_settings(self): 16 | train_opt = self.opt['train'] 17 | 18 | self.ema_decay = train_opt.get('ema_decay', 0) 19 | if self.ema_decay > 0: 20 | logger = get_root_logger() 21 | logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') 22 | # define network net_g with Exponential Moving Average (EMA) 23 | # net_g_ema is used only for testing on one GPU and saving 24 | # There is no need to wrap with DistributedDataParallel 25 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device) 26 | # load pretrained model 27 | load_path = self.opt['path'].get('pretrain_network_g', None) 28 | if load_path is not None: 29 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 30 | else: 31 | self.model_ema(0) # copy net_g weight 32 | self.net_g_ema.eval() 33 | 34 | # define network net_d 35 | self.net_d = build_network(self.opt['network_d']) 36 | self.net_d = self.model_to_device(self.net_d) 37 | self.print_network(self.net_d) 38 | 39 | # load pretrained models 40 | load_path = self.opt['path'].get('pretrain_network_d', None) 41 | if load_path is not None: 42 | param_key = self.opt['path'].get('param_key_d', 'params') 43 | self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key) 44 | 45 | self.net_g.train() 46 | self.net_d.train() 47 | 48 | # define losses 49 | if train_opt.get('pixel_opt'): 50 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) 51 | else: 52 | self.cri_pix = None 53 | 54 | if train_opt.get('ldl_opt'): 55 | self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) 56 | else: 57 | self.cri_ldl = None 58 | 59 | if train_opt.get('perceptual_opt'): 60 | self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) 61 | else: 62 | self.cri_perceptual = None 63 | 64 | if train_opt.get('gan_opt'): 65 | self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) 66 | 67 | self.net_d_iters = train_opt.get('net_d_iters', 1) 68 | self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) 69 | 70 | # set up optimizers and schedulers 71 | self.setup_optimizers() 72 | self.setup_schedulers() 73 | 74 | def setup_optimizers(self): 75 | train_opt = self.opt['train'] 76 | # optimizer g 77 | optim_type = train_opt['optim_g'].pop('type') 78 | self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g']) 79 | self.optimizers.append(self.optimizer_g) 80 | # optimizer d 81 | optim_type = train_opt['optim_d'].pop('type') 82 | self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) 83 | self.optimizers.append(self.optimizer_d) 84 | 85 | def optimize_parameters(self, current_iter): 86 | # optimize net_g 87 | for p in self.net_d.parameters(): 88 | p.requires_grad = False 89 | 90 | self.optimizer_g.zero_grad() 91 | self.output = self.net_g(self.lq) 92 | 93 | l_g_total = 0 94 | loss_dict = OrderedDict() 95 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 96 | # pixel loss 97 | if self.cri_pix: 98 | l_g_pix = self.cri_pix(self.output, self.gt) 99 | l_g_total += l_g_pix 100 | loss_dict['l_g_pix'] = l_g_pix 101 | # perceptual loss 102 | if self.cri_perceptual: 103 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 104 | if l_g_percep is not None: 105 | l_g_total += l_g_percep 106 | loss_dict['l_g_percep'] = l_g_percep 107 | if l_g_style is not None: 108 | l_g_total += l_g_style 109 | loss_dict['l_g_style'] = l_g_style 110 | # gan loss 111 | fake_g_pred = self.net_d(self.output) 112 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 113 | l_g_total += l_g_gan 114 | loss_dict['l_g_gan'] = l_g_gan 115 | 116 | l_g_total.backward() 117 | self.optimizer_g.step() 118 | 119 | # optimize net_d 120 | for p in self.net_d.parameters(): 121 | p.requires_grad = True 122 | 123 | self.optimizer_d.zero_grad() 124 | # real 125 | real_d_pred = self.net_d(self.gt) 126 | l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) 127 | loss_dict['l_d_real'] = l_d_real 128 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 129 | l_d_real.backward() 130 | # fake 131 | fake_d_pred = self.net_d(self.output.detach()) 132 | l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) 133 | loss_dict['l_d_fake'] = l_d_fake 134 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 135 | l_d_fake.backward() 136 | self.optimizer_d.step() 137 | 138 | self.log_dict = self.reduce_loss_dict(loss_dict) 139 | 140 | if self.ema_decay > 0: 141 | self.model_ema(decay=self.ema_decay) 142 | 143 | def save(self, epoch, current_iter): 144 | if hasattr(self, 'net_g_ema'): 145 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) 146 | else: 147 | self.save_network(self.net_g, 'net_g', current_iter) 148 | self.save_network(self.net_d, 'net_d', current_iter) 149 | self.save_training_state(epoch, current_iter) 150 | -------------------------------------------------------------------------------- /basicsr/models/swinir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class SwinIRModel(SRModel): 10 | 11 | def test(self): 12 | # pad to multiplication of window_size 13 | window_size = self.opt['network_g']['window_size'] 14 | scale = self.opt.get('scale', 1) 15 | mod_pad_h, mod_pad_w = 0, 0 16 | _, _, h, w = self.lq.size() 17 | if h % window_size != 0: 18 | mod_pad_h = window_size - h % window_size 19 | if w % window_size != 0: 20 | mod_pad_w = window_size - w % window_size 21 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 22 | if hasattr(self, 'net_g_ema'): 23 | self.net_g_ema.eval() 24 | with torch.no_grad(): 25 | self.output = self.net_g_ema(img) 26 | else: 27 | self.net_g.eval() 28 | with torch.no_grad(): 29 | self.output = self.net_g(img) 30 | self.net_g.train() 31 | 32 | _, _, h, w = self.output.size() 33 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 34 | -------------------------------------------------------------------------------- /basicsr/models/video_gan_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import MODEL_REGISTRY 2 | from .srgan_model import SRGANModel 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class VideoGANModel(SRGANModel, VideoBaseModel): 8 | """Video GAN model. 9 | 10 | Use multiple inheritance. 11 | It will first use the functions of SRGANModel: 12 | init_training_settings 13 | setup_optimizers 14 | optimize_parameters 15 | save 16 | Then find functions in VideoBaseModel. 17 | """ 18 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_ext = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import upfirdn2d_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. set BASICSR_JIT=True during running') 28 | 29 | 30 | class UpFirDn2dBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 34 | 35 | up_x, up_y = up 36 | down_x, down_y = down 37 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 38 | 39 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 40 | 41 | grad_input = upfirdn2d_ext.upfirdn2d( 42 | grad_output, 43 | grad_kernel, 44 | down_x, 45 | down_y, 46 | up_x, 47 | up_y, 48 | g_pad_x0, 49 | g_pad_x1, 50 | g_pad_y0, 51 | g_pad_y1, 52 | ) 53 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 54 | 55 | ctx.save_for_backward(kernel) 56 | 57 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 58 | 59 | ctx.up_x = up_x 60 | ctx.up_y = up_y 61 | ctx.down_x = down_x 62 | ctx.down_y = down_y 63 | ctx.pad_x0 = pad_x0 64 | ctx.pad_x1 = pad_x1 65 | ctx.pad_y0 = pad_y0 66 | ctx.pad_y1 = pad_y1 67 | ctx.in_size = in_size 68 | ctx.out_size = out_size 69 | 70 | return grad_input 71 | 72 | @staticmethod 73 | def backward(ctx, gradgrad_input): 74 | kernel, = ctx.saved_tensors 75 | 76 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 77 | 78 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 79 | gradgrad_input, 80 | kernel, 81 | ctx.up_x, 82 | ctx.up_y, 83 | ctx.down_x, 84 | ctx.down_y, 85 | ctx.pad_x0, 86 | ctx.pad_x1, 87 | ctx.pad_y0, 88 | ctx.pad_y1, 89 | ) 90 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 91 | # ctx.out_size[1], ctx.in_size[3]) 92 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 93 | 94 | return gradgrad_out, None, None, None, None, None, None, None, None 95 | 96 | 97 | class UpFirDn2d(Function): 98 | 99 | @staticmethod 100 | def forward(ctx, input, kernel, up, down, pad): 101 | up_x, up_y = up 102 | down_x, down_y = down 103 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 104 | 105 | kernel_h, kernel_w = kernel.shape 106 | _, channel, in_h, in_w = input.shape 107 | ctx.in_size = input.shape 108 | 109 | input = input.reshape(-1, in_h, in_w, 1) 110 | 111 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 112 | 113 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 114 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 115 | ctx.out_size = (out_h, out_w) 116 | 117 | ctx.up = (up_x, up_y) 118 | ctx.down = (down_x, down_y) 119 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 120 | 121 | g_pad_x0 = kernel_w - pad_x0 - 1 122 | g_pad_y0 = kernel_h - pad_y0 - 1 123 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 124 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 125 | 126 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 127 | 128 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 129 | # out = out.view(major, out_h, out_w, minor) 130 | out = out.view(-1, channel, out_h, out_w) 131 | 132 | return out 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | kernel, grad_kernel = ctx.saved_tensors 137 | 138 | grad_input = UpFirDn2dBackward.apply( 139 | grad_output, 140 | kernel, 141 | grad_kernel, 142 | ctx.up, 143 | ctx.down, 144 | ctx.pad, 145 | ctx.g_pad, 146 | ctx.in_size, 147 | ctx.out_size, 148 | ) 149 | 150 | return grad_input, None, None, None, None 151 | 152 | 153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 154 | if input.device.type == 'cpu': 155 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 156 | else: 157 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 158 | 159 | return out 160 | 161 | 162 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 174 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 175 | 176 | out = out.permute(0, 3, 1, 2) 177 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 178 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 179 | out = F.conv2d(out, w) 180 | out = out.reshape( 181 | -1, 182 | minor, 183 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 184 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 185 | ) 186 | out = out.permute(0, 2, 3, 1) 187 | out = out[:, ::down_y, ::down_x, :] 188 | 189 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 190 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 191 | 192 | return out.view(-1, channel, out_h, out_w) 193 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | 9 | __all__ = [ 10 | # color_util.py 11 | 'bgr2ycbcr', 12 | 'rgb2ycbcr', 13 | 'rgb2ycbcr_pt', 14 | 'ycbcr2bgr', 15 | 'ycbcr2rgb', 16 | # file_client.py 17 | 'FileClient', 18 | # img_util.py 19 | 'img2tensor', 20 | 'tensor2img', 21 | 'imfrombytes', 22 | 'imwrite', 23 | 'crop_border', 24 | # logger.py 25 | 'MessageLogger', 26 | 'AvgTimer', 27 | 'init_tb_logger', 28 | 'init_wandb_logger', 29 | 'get_root_logger', 30 | 'get_env_info', 31 | # misc.py 32 | 'set_random_seed', 33 | 'get_time_str', 34 | 'mkdir_and_rename', 35 | 'make_exp_dirs', 36 | 'scandir', 37 | 'check_resume', 38 | 'sizeof_fmt', 39 | # diffjpeg 40 | 'DiffJPEG', 41 | # img_process_util 42 | 'USMSharp', 43 | 'usm_sharp' 44 | ] 45 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Ref: 15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 16 | 17 | Args: 18 | file_id (str): File id. 19 | save_path (str): Save path. 20 | """ 21 | 22 | session = requests.Session() 23 | URL = 'https://docs.google.com/uc?export=download' 24 | params = {'id': file_id} 25 | 26 | response = session.get(URL, params=params, stream=True) 27 | token = get_confirm_token(response) 28 | if token: 29 | params['confirm'] = token 30 | response = session.get(URL, params=params, stream=True) 31 | 32 | # get file size 33 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 34 | if 'Content-Range' in response_file_size.headers: 35 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 36 | else: 37 | file_size = None 38 | 39 | save_response_content(response, save_path, file_size) 40 | 41 | 42 | def get_confirm_token(response): 43 | for key, value in response.cookies.items(): 44 | if key.startswith('download_warning'): 45 | return value 46 | return None 47 | 48 | 49 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 50 | if file_size is not None: 51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 52 | 53 | readable_file_size = sizeof_fmt(file_size) 54 | else: 55 | pbar = None 56 | 57 | with open(destination, 'wb') as f: 58 | downloaded_size = 0 59 | for chunk in response.iter_content(chunk_size): 60 | downloaded_size += chunk_size 61 | if pbar is not None: 62 | pbar.update(1) 63 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 64 | if chunk: # filter out keep-alive new chunks 65 | f.write(chunk) 66 | if pbar is not None: 67 | pbar.close() 68 | 69 | 70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 71 | """Load file form http url, will download models if necessary. 72 | 73 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 74 | 75 | Args: 76 | url (str): URL to be downloaded. 77 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 78 | Default: None. 79 | progress (bool): Whether to show the download progress. Default: True. 80 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 81 | 82 | Returns: 83 | str: The path to the downloaded file. 84 | """ 85 | if model_dir is None: # use the pytorch hub_dir 86 | hub_dir = get_dir() 87 | model_dir = os.path.join(hub_dir, 'checkpoints') 88 | 89 | os.makedirs(model_dir, exist_ok=True) 90 | 91 | parts = urlparse(url) 92 | filename = os.path.basename(parts.path) 93 | if file_name is not None: 94 | filename = file_name 95 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 96 | if not os.path.exists(cached_file): 97 | print(f'Downloading: "{url}" to {cached_file}\n') 98 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 99 | return cached_file 100 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | import yaml 5 | from collections import OrderedDict 6 | from os import path as osp 7 | 8 | from basicsr.utils import set_random_seed 9 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def dict2str(opt, indent_level=1): 38 | """dict to string for printing options. 39 | 40 | Args: 41 | opt (dict): Option dict. 42 | indent_level (int): Indent level. Default: 1. 43 | 44 | Return: 45 | (str): Option string for printing. 46 | """ 47 | msg = '\n' 48 | for k, v in opt.items(): 49 | if isinstance(v, dict): 50 | msg += ' ' * (indent_level * 2) + k + ':[' 51 | msg += dict2str(v, indent_level + 1) 52 | msg += ' ' * (indent_level * 2) + ']\n' 53 | else: 54 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 55 | return msg 56 | 57 | 58 | def _postprocess_yml_value(value): 59 | # None 60 | if value == '~' or value.lower() == 'none': 61 | return None 62 | # bool 63 | if value.lower() == 'true': 64 | return True 65 | elif value.lower() == 'false': 66 | return False 67 | # !!float number 68 | if value.startswith('!!float'): 69 | return float(value.replace('!!float', '')) 70 | # number 71 | if value.isdigit(): 72 | return int(value) 73 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 74 | return float(value) 75 | # list 76 | if value.startswith('['): 77 | return eval(value) 78 | # str 79 | return value 80 | 81 | 82 | def parse_options(root_path, is_train=True): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 85 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 86 | parser.add_argument('--auto_resume', action='store_true') 87 | parser.add_argument('--debug', action='store_true') 88 | parser.add_argument('--local_rank', type=int, default=0) 89 | parser.add_argument( 90 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') 91 | args = parser.parse_args() 92 | 93 | # parse yml to dict 94 | with open(args.opt, mode='r') as f: 95 | opt = yaml.load(f, Loader=ordered_yaml()[0]) 96 | 97 | # distributed settings 98 | if args.launcher == 'none': 99 | opt['dist'] = False 100 | print('Disable distributed.', flush=True) 101 | else: 102 | opt['dist'] = True 103 | if args.launcher == 'slurm' and 'dist_params' in opt: 104 | init_dist(args.launcher, **opt['dist_params']) 105 | else: 106 | init_dist(args.launcher) 107 | opt['rank'], opt['world_size'] = get_dist_info() 108 | 109 | # random seed 110 | seed = opt.get('manual_seed') 111 | if seed is None: 112 | seed = random.randint(1, 10000) 113 | opt['manual_seed'] = seed 114 | set_random_seed(seed + opt['rank']) 115 | 116 | # force to update yml options 117 | if args.force_yml is not None: 118 | for entry in args.force_yml: 119 | # now do not support creating new keys 120 | keys, value = entry.split('=') 121 | keys, value = keys.strip(), value.strip() 122 | value = _postprocess_yml_value(value) 123 | eval_str = 'opt' 124 | for key in keys.split(':'): 125 | eval_str += f'["{key}"]' 126 | eval_str += '=value' 127 | # using exec function 128 | exec(eval_str) 129 | 130 | opt['auto_resume'] = args.auto_resume 131 | opt['is_train'] = is_train 132 | 133 | # debug setting 134 | if args.debug and not opt['name'].startswith('debug'): 135 | opt['name'] = 'debug_' + opt['name'] 136 | 137 | if opt['num_gpu'] == 'auto': 138 | opt['num_gpu'] = torch.cuda.device_count() 139 | 140 | # datasets 141 | for phase, dataset in opt['datasets'].items(): 142 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 143 | phase = phase.split('_')[0] 144 | dataset['phase'] = phase 145 | if 'scale' in opt: 146 | dataset['scale'] = opt['scale'] 147 | if dataset.get('dataroot_gt') is not None: 148 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 149 | if dataset.get('dataroot_lq') is not None: 150 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 151 | 152 | # paths 153 | for key, val in opt['path'].items(): 154 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 155 | opt['path'][key] = osp.expanduser(val) 156 | 157 | if is_train: 158 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 159 | opt['path']['experiments_root'] = experiments_root 160 | opt['path']['models'] = osp.join(experiments_root, 'models') 161 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 162 | opt['path']['log'] = experiments_root 163 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 164 | 165 | # change some options for debug mode 166 | if 'debug' in opt['name']: 167 | if 'val' in opt: 168 | opt['val']['val_freq'] = 8 169 | opt['logger']['print_freq'] = 1 170 | opt['logger']['save_checkpoint_freq'] = 8 171 | else: # test 172 | results_root = osp.join(root_path, 'results', opt['name']) 173 | opt['path']['results_root'] = results_root 174 | opt['path']['log'] = results_root 175 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 176 | 177 | return opt, args 178 | 179 | 180 | @master_only 181 | def copy_opt_file(opt_file, experiments_root): 182 | # copy the yml file to the experiment root 183 | import sys 184 | import time 185 | from shutil import copyfile 186 | cmd = ' '.join(sys.argv) 187 | filename = osp.join(experiments_root, osp.basename(opt_file)) 188 | copyfile(opt_file, filename) 189 | 190 | with open(filename, 'r+') as f: 191 | lines = f.readlines() 192 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 193 | f.seek(0) 194 | f.writelines(lines) 195 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Ref: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/\ 72 | tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 73 | 74 | Args: 75 | values (list): A list of values to be smoothed. 76 | smooth_weight (float): Smooth weight. 77 | """ 78 | values_sm = [] 79 | last_sm_value = values[0] 80 | for value in values: 81 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 82 | values_sm.append(value_sm) 83 | last_sm_value = value_sm 84 | return values_sm 85 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sat Jun 29 22:33:34 2024 3 | __version__ = '1.4.2' 4 | __gitsha__ = 'unknown' 5 | version_info = (1, 4, 2) 6 | -------------------------------------------------------------------------------- /helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/helpers/__init__.py -------------------------------------------------------------------------------- /helpers/audio.py: -------------------------------------------------------------------------------- 1 | # This Module is used to process Audio. And is a part of https://github.com/Rudrabha/Wav2Lip repository. 2 | 3 | import librosa 4 | import librosa.filters 5 | import numpy as np 6 | # import tensorflow as tf 7 | from scipy import signal 8 | from scipy.io import wavfile 9 | from helpers.hparams import hparams as hp 10 | 11 | def load_wav(path, sr): 12 | return librosa.core.load(path, sr=sr)[0] 13 | 14 | def save_wav(wav, path, sr): 15 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 16 | #proposed by @dsmiller 17 | wavfile.write(path, sr, wav.astype(np.int16)) 18 | 19 | def save_wavenet_wav(wav, path, sr): 20 | librosa.output.write_wav(path, wav, sr=sr) 21 | 22 | def preemphasis(wav, k, preemphasize=True): 23 | if preemphasize: 24 | return signal.lfilter([1, -k], [1], wav) 25 | return wav 26 | 27 | def inv_preemphasis(wav, k, inv_preemphasize=True): 28 | if inv_preemphasize: 29 | return signal.lfilter([1], [1, -k], wav) 30 | return wav 31 | 32 | def get_hop_size(): 33 | hop_size = hp.hop_size 34 | if hop_size is None: 35 | assert hp.frame_shift_ms is not None 36 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 37 | return hop_size 38 | 39 | def linearspectrogram(wav): 40 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 41 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 42 | 43 | if hp.signal_normalization: 44 | return _normalize(S) 45 | return S 46 | 47 | def melspectrogram(wav): 48 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 49 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 50 | 51 | if hp.signal_normalization: 52 | return _normalize(S) 53 | return S 54 | 55 | def _lws_processor(): 56 | import lws 57 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 58 | 59 | def _stft(y): 60 | if hp.use_lws: 61 | return _lws_processor(hp).stft(y).T 62 | else: 63 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 64 | 65 | ########################################################## 66 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 67 | def num_frames(length, fsize, fshift): 68 | """Compute number of time frames of spectrogram 69 | """ 70 | pad = (fsize - fshift) 71 | if length % fshift == 0: 72 | M = (length + pad * 2 - fsize) // fshift + 1 73 | else: 74 | M = (length + pad * 2 - fsize) // fshift + 2 75 | return M 76 | 77 | 78 | def pad_lr(x, fsize, fshift): 79 | """Compute left and right padding 80 | """ 81 | M = num_frames(len(x), fsize, fshift) 82 | pad = (fsize - fshift) 83 | T = len(x) + 2 * pad 84 | r = (M - 1) * fshift + fsize - T 85 | return pad, pad + r 86 | ########################################################## 87 | #Librosa correct padding 88 | def librosa_pad_lr(x, fsize, fshift): 89 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 90 | 91 | # Conversions 92 | _mel_basis = None 93 | 94 | def _linear_to_mel(spectogram): 95 | global _mel_basis 96 | if _mel_basis is None: 97 | _mel_basis = _build_mel_basis() 98 | return np.dot(_mel_basis, spectogram) 99 | 100 | def _build_mel_basis(): 101 | assert hp.fmax <= hp.sample_rate // 2 102 | return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, 103 | fmin=hp.fmin, fmax=hp.fmax) 104 | 105 | def _amp_to_db(x): 106 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 107 | return 20 * np.log10(np.maximum(min_level, x)) 108 | 109 | def _db_to_amp(x): 110 | return np.power(10.0, (x) * 0.05) 111 | 112 | def _normalize(S): 113 | if hp.allow_clipping_in_normalization: 114 | if hp.symmetric_mels: 115 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 116 | -hp.max_abs_value, hp.max_abs_value) 117 | else: 118 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 119 | 120 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 121 | if hp.symmetric_mels: 122 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 123 | else: 124 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 125 | 126 | def _denormalize(D): 127 | if hp.allow_clipping_in_normalization: 128 | if hp.symmetric_mels: 129 | return (((np.clip(D, -hp.max_abs_value, 130 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 131 | + hp.min_level_db) 132 | else: 133 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 134 | 135 | if hp.symmetric_mels: 136 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 137 | else: 138 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 139 | -------------------------------------------------------------------------------- /helpers/batch_processors.py: -------------------------------------------------------------------------------- 1 | # This file is a part of https://github.com/pawansharmaaaa/Lip_Wise/ repository. 2 | 3 | import cv2 4 | import os 5 | 6 | import mediapipe as mp 7 | import numpy as np 8 | 9 | from concurrent.futures import ThreadPoolExecutor 10 | from functools import partial 11 | 12 | from helpers import preprocess_mp 13 | from helpers import file_check 14 | 15 | class BatchProcessors: 16 | def __init__(self, image_mode=False): 17 | 18 | self.npy_directory = file_check.NPY_FILES_DIR 19 | self.weights_directory = file_check.WEIGHTS_DIR 20 | 21 | self.helper = preprocess_mp.FaceHelpers(image_mode=image_mode) 22 | 23 | def extract_face_batch(self, frame_batch, frame_numbers): 24 | with ThreadPoolExecutor() as executor: 25 | extracted_faces, masks, inv_masks, centers, bboxes = zip(*list(executor.map(self.helper.extract_face, frame_batch, frame_numbers))) 26 | return extracted_faces, masks, inv_masks, centers, bboxes 27 | 28 | def align_crop_batch(self, extracted_faces, frame_numbers): 29 | with ThreadPoolExecutor() as executor: 30 | cropped_faces, aligned_bboxes, rotation_matrices = zip(*list(executor.map(self.helper.align_crop_face, extracted_faces, frame_numbers))) 31 | return cropped_faces, aligned_bboxes, rotation_matrices 32 | 33 | def gen_data_video_mode(self, cropped_faces_batch, mel_batch): 34 | """ 35 | Generates data for inference in video mode. 36 | Batches the data to be fed into the model. 37 | Batch of image includes several images of shape (96, 96, 6) stacked together. 38 | These images contain the half face and the full face. 39 | 40 | Args: 41 | cropped_faces: a batch of size batch_size of The cropped faces obtained from the crop_extracted_face function. 42 | mel_batch: a batch of size batch_size consisting of The mel chunks obtained from the audio. 43 | 44 | Returns: 45 | A batch of images of shape (96, 96, 6) and mel chunks. 46 | """ 47 | resized_cropped_faces_batch = [] 48 | # Resize face for wav2lip 49 | for cropped_face in cropped_faces_batch: 50 | cropped_face = cv2.resize(cropped_face, (96, 96), interpolation=cv2.INTER_AREA) 51 | resized_cropped_faces_batch.append(cropped_face) 52 | 53 | frame_batch = np.asarray(resized_cropped_faces_batch) 54 | 55 | img_masked = frame_batch.copy() 56 | img_masked[:, 96//2:] = 0 57 | 58 | frame_batch = np.concatenate((img_masked, frame_batch), axis=3) / 255. 59 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 60 | 61 | return frame_batch, mel_batch 62 | 63 | def face_resize_batch(self, restored_faces, cropped_faces_batch): 64 | size_batch = [] 65 | for cropped_face in cropped_faces_batch: 66 | height, width = cropped_face.shape[:2] 67 | size_batch.append((width, height)) 68 | 69 | resizer_partial = partial(cv2.resize, interpolation=cv2.INTER_LANCZOS4) 70 | with ThreadPoolExecutor() as executor: 71 | resized_restored_faces = list(executor.map(resizer_partial, restored_faces, size_batch)) 72 | return resized_restored_faces 73 | 74 | def paste_back_black_bg_batch(self, processed_face_batch, aligned_bboxes_batch, frame_batch, ml): 75 | paste_partial = partial(self.helper.paste_back_black_bg, ml=ml) 76 | with ThreadPoolExecutor() as executor: 77 | pasted_ready_faces = list(executor.map(paste_partial, processed_face_batch, aligned_bboxes_batch, frame_batch)) 78 | return pasted_ready_faces 79 | 80 | def unwarp_align_batch(self, pasted_ready_faces, rotation_matrices): 81 | with ThreadPoolExecutor() as executor: 82 | ready_to_paste = list(executor.map(self.helper.unwarp_align, pasted_ready_faces, rotation_matrices)) 83 | return ready_to_paste 84 | 85 | def paste_back_batch(self, ready_to_paste, frame_batch, face_masks, inv_masks, centers): 86 | with ThreadPoolExecutor() as executor: 87 | pasted_faces = list(executor.map(self.helper.paste_back, ready_to_paste, frame_batch, face_masks, inv_masks, centers)) 88 | return pasted_faces -------------------------------------------------------------------------------- /helpers/hparams.py: -------------------------------------------------------------------------------- 1 | # This Module is used to set the hyperparameters for the model. And is a part of https://github.com/Rudrabha/Wav2Lip repository. 2 | 3 | from glob import glob 4 | import os 5 | 6 | def get_image_list(data_root, split): 7 | filelist = [] 8 | 9 | with open('filelists/{}.txt'.format(split)) as f: 10 | for line in f: 11 | line = line.strip() 12 | if ' ' in line: line = line.split()[0] 13 | filelist.append(os.path.join(data_root, line)) 14 | 15 | return filelist 16 | 17 | class HParams: 18 | def __init__(self, **kwargs): 19 | self.data = {} 20 | 21 | for key, value in kwargs.items(): 22 | self.data[key] = value 23 | 24 | def __getattr__(self, key): 25 | if key not in self.data: 26 | raise AttributeError("'HParams' object has no attribute %s" % key) 27 | return self.data[key] 28 | 29 | def set_hparam(self, key, value): 30 | self.data[key] = value 31 | 32 | 33 | # Default hyperparameters 34 | hparams = HParams( 35 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 36 | # network 37 | rescale=True, # Whether to rescale audio prior to preprocessing 38 | rescaling_max=0.9, # Rescaling value 39 | 40 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 41 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 42 | # Does not work if n_ffit is not multiple of hop_size!! 43 | use_lws=False, 44 | 45 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 46 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 47 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 48 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 49 | 50 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 51 | 52 | # Mel and Linear spectrograms normalization/scaling and clipping 53 | signal_normalization=True, 54 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 55 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 56 | symmetric_mels=True, 57 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 58 | # faster and cleaner convergence) 59 | max_abs_value=4., 60 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 61 | # be too big to avoid gradient explosion, 62 | # not too small for fast convergence) 63 | # Contribution by @begeekmyfriend 64 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 65 | # levels. Also allows for better G&L phase reconstruction) 66 | preemphasize=True, # whether to apply filter 67 | preemphasis=0.97, # filter coefficient. 68 | 69 | # Limits 70 | min_level_db=-100, 71 | ref_level_db=20, 72 | fmin=55, 73 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 74 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 75 | fmax=7600, # To be increased/reduced depending on data. 76 | 77 | ###################### Our training parameters ################################# 78 | img_size=96, 79 | fps=25, 80 | 81 | batch_size=16, 82 | initial_learning_rate=1e-4, 83 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 84 | num_workers=16, 85 | checkpoint_interval=3000, 86 | eval_interval=3000, 87 | save_optimizer_state=True, 88 | 89 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 90 | syncnet_batch_size=64, 91 | syncnet_lr=1e-4, 92 | syncnet_eval_interval=10000, 93 | syncnet_checkpoint_interval=10000, 94 | 95 | disc_wt=0.07, 96 | disc_initial_learning_rate=1e-4, 97 | ) 98 | 99 | 100 | def hparams_debug_string(): 101 | values = hparams.values() 102 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 103 | return "Hyperparameters:\n" + "\n".join(hp) 104 | -------------------------------------------------------------------------------- /helpers/vars.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pawansharmaaaa/Lip_Wise/e414debb6f6b645908e71cc6737437cff95794d4/helpers/vars.py -------------------------------------------------------------------------------- /launch.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | call .lip-wise\Scripts\activate 3 | python launch.py 4 | pause -------------------------------------------------------------------------------- /launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for arg in "$@" 3 | do 4 | if [ "$arg" == "--colab" ] 5 | then 6 | python launch.py --colab 7 | exit 0 8 | else 9 | source .lip-wise/bin/activate 10 | python launch.py 11 | fi 12 | done -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wav2lip import Wav2Lip, Wav2Lip_disc_qual 2 | from .syncnet import SyncNet_color -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | # This is a part of https://github.com/Rudrabha/Wav2Lip repository. 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | class Conv2d(nn.Module): 8 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.conv_block = nn.Sequential( 11 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 12 | nn.BatchNorm2d(cout) 13 | ) 14 | self.act = nn.ReLU() 15 | self.residual = residual 16 | 17 | def forward(self, x): 18 | out = self.conv_block(x) 19 | if self.residual: 20 | out += x 21 | return self.act(out) 22 | 23 | class nonorm_Conv2d(nn.Module): 24 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | self.conv_block = nn.Sequential( 27 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 28 | ) 29 | self.act = nn.LeakyReLU(0.01, inplace=True) 30 | 31 | def forward(self, x): 32 | out = self.conv_block(x) 33 | return self.act(out) 34 | 35 | class Conv2dTranspose(nn.Module): 36 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | self.conv_block = nn.Sequential( 39 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 40 | nn.BatchNorm2d(cout) 41 | ) 42 | self.act = nn.ReLU() 43 | 44 | def forward(self, x): 45 | out = self.conv_block(x) 46 | return self.act(out) 47 | -------------------------------------------------------------------------------- /models/syncnet.py: -------------------------------------------------------------------------------- 1 | # This is a part of https://github.com/Rudrabha/Wav2Lip repository. 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .conv import Conv2d 8 | 9 | class SyncNet_color(nn.Module): 10 | def __init__(self): 11 | super(SyncNet_color, self).__init__() 12 | 13 | self.face_encoder = nn.Sequential( 14 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), 15 | 16 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), 17 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 18 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 19 | 20 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 22 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 23 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 24 | 25 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 26 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 27 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 28 | 29 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 30 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 31 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 32 | 33 | Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 34 | Conv2d(512, 512, kernel_size=3, stride=1, padding=0), 35 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 36 | 37 | self.audio_encoder = nn.Sequential( 38 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 39 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 41 | 42 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 43 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 44 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 45 | 46 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 47 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 48 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 49 | 50 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 51 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 52 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 53 | 54 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 55 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 56 | 57 | def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) 58 | face_embedding = self.face_encoder(face_sequences) 59 | audio_embedding = self.audio_encoder(audio_sequences) 60 | 61 | audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) 62 | face_embedding = face_embedding.view(face_embedding.size(0), -1) 63 | 64 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 65 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 66 | 67 | 68 | return audio_embedding, face_embedding 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gfpgan>=1.3.8 2 | gdown>=4.7.3 3 | gradio>=4.12.0 4 | ipywidgets>=8.1.2 5 | librosa>=0.9.1 6 | mediapipe>=0.10.9 7 | numba>=0.58.1 8 | numpy>=1.26.2 9 | opencv-contrib-python>=4.8.1.78 10 | opencv-python>=4.8.1.78 11 | realesrgan>=0.3.0 12 | torch>=2.1.2 13 | torchvision==0.16.2 14 | tqdm>=4.66.1 15 | -------------------------------------------------------------------------------- /setup-colab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Check if ffmpeg is installed 3 | if ! command -v ffmpeg &> /dev/null 4 | then 5 | sudo apt-get install ffmpeg 6 | fi 7 | 8 | # Install requirements 9 | pip install -r requirements.txt 10 | pip install --upgrade --no-cache-dir gdown 11 | 12 | # Find python version 13 | python_version=$(python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))") 14 | 15 | # Copy archs 16 | cp archs/* /usr/local/lib/python${python_version}/dist-packages/basicsr/archs 17 | 18 | # Run file_check.py 19 | python ./helpers/file_check.py -------------------------------------------------------------------------------- /setup.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Check if ffmpeg is installed 4 | where ffmpeg >nul 2>nul 5 | if %ERRORLEVEL% neq 0 ( 6 | echo ffmpeg is not installed. Checking for winget... 7 | 8 | REM Check if winget is installed 9 | where winget >nul 2>nul 10 | if %ERRORLEVEL% neq 0 ( 11 | echo winget is not installed. Please install winget or manually install ffmpeg. 12 | exit /b 13 | ) 14 | 15 | REM Install FFMPEG 16 | winget install ffmpeg 17 | ) 18 | 19 | REM Check if CUDA is installed 20 | if not defined CUDA_PATH ( 21 | echo CUDA is not installed. Please install CUDA. 22 | echo CPU will be used for inference. 23 | ) 24 | 25 | REM Create a virtual environment 26 | python -m venv .lip-wise 27 | 28 | REM Activate the virtual environment 29 | call .lip-wise\Scripts\activate 30 | 31 | REM Install requirements 32 | pip install -r requirements.txt 33 | 34 | REM Copy basicsr archs 35 | copy archs\* .lip-wise\Lib\site-packages\basicsr\archs 36 | 37 | REM Run file_check.py 38 | python .\helpers\file_check.py -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install FFMPEG 4 | os=$(uname -a) 5 | 6 | # Check if ffmpeg is installed 7 | if ! command -v ffmpeg &> /dev/null 8 | then 9 | if [[ $os == *"arch"* ]]; then 10 | # If the OS is Arch Linux, use pacman 11 | sudo pacman -S ffmpeg 12 | elif [[ $os == *"Ubuntu"* ]] || [[ $os == *"Debian"* ]]; then 13 | # If the OS is Ubuntu or Debian, use apt-get. Also install python3-venv because it is not installed by default in debian. 14 | sudo apt-get install ffmpeg 15 | sudo apt install python3-venv 16 | fi 17 | fi 18 | 19 | # Check if CUDA is installed 20 | if [ -z "$CUDA_PATH" ] 21 | then 22 | echo "CUDA is not installed. Please install CUDA." 23 | echo "CPU will be used for inference." 24 | fi 25 | 26 | # Create a virtual environment 27 | python3 -m venv .lip-wise 28 | 29 | wait 30 | 31 | # Activate the virtual environment 32 | source .lip-wise/bin/activate 33 | 34 | wait 35 | 36 | # Get Python version 37 | python_version=$(python -c 'import sys; print(".".join(map(str, sys.version_info[:2])))') 38 | 39 | # Install requirements 40 | pip install -r requirements.txt 41 | 42 | # Copy archs 43 | cp archs/* .lip-wise/lib/python${python_version}/site-packages/basicsr/archs/ 44 | 45 | # Run file_check.py 46 | python ./helpers/file_check.py -------------------------------------------------------------------------------- /todo.md: -------------------------------------------------------------------------------- 1 | ## :memo: **TO-DO** List: 2 | 3 | #### URGENT REQUIREMENTS 4 | - [x] Change mask in seamless clone and give it a try 5 | - [x] setup.bat / setup.sh 6 | - [x] create venv 7 | - [x] install requirements inside venv 8 | - [x] CodeFormer arch initialization 9 | - [x] Documentation 10 | 11 | #### PREPROCESS 12 | - [x] Add directory check in inference in the beginning. 13 | - [x] Make preprocessing optimal. 14 | - [x] Clear ram after no_face_filter. 15 | - [x] Make face coordinates reusable: 16 | - [x] Saving facial coordinates as .npy file. 17 | - [x] Alter code to also include eye coordinates. 18 | 19 | #### IMPROVING GAN UPSCALING 20 | - [x] Merge Data Pipeline with preprocessor: 21 | - [x] Remove need to recrop, realign and rewarp the image. 22 | 23 | #### IMPROVING WAV2LIP 24 | - [x] Merge all data Pipeline: 25 | - [x] Remove the need to recrop, realign, renormalizing etc. 26 | - [x] Devise a way to keep frames without face in the video. 27 | - [x] Understand Mels and working of wav2lip model. 28 | 29 | #### OPTIONAL 30 | - [x] Gradio UI 31 | - [x] A tab for Video, Audio and Output. 32 | - [x] A tab for Image, Audio and output. 33 | 34 | #### FURTHER IMPROVEMENTS 35 | - [x] Inference without restorer 36 | - [ ] Model Improvement 37 | - [ ] Implement no_face_filter too 38 | 39 | #### COLAB NOTEBOOK 40 | - [x] Make it intuitive with proper instructions. 41 | - [x] Optimize Inference. 42 | - [x] Implement Checks. 43 | 44 | #### FUTURE PLANS 45 | - [ ] Face and Audio wise Lipsync using face recognition. 46 | - [ ] A separate tab for TTS. 47 | 48 | --- --------------------------------------------------------------------------------