├── README.md ├── SCNet_arch.py ├── basicsr ├── __init__.py ├── archs │ ├── SCNet_arch.py │ ├── __init__.py │ ├── arch_util.py │ └── vgg_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── realesrgan_dataset.py │ ├── realesrgan_paired_dataset.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── losses │ ├── __init__.py │ ├── basic_loss.py │ ├── gan_loss.py │ └── loss_util.py ├── metrics │ ├── README.md │ ├── README_CN.md │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ ├── psnr_ssim.py │ └── test_metrics │ │ └── test_psnr_ssim.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── lr_scheduler.py │ └── sr_model.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py └── options ├── test └── SCNet │ ├── SCNet-T-x4-PS.yml │ └── SCNet-T-x4.yml └── train └── SCNet └── SCNet-T-x4.yml /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Fully 1x1 Convolutional Network for Lightweight Image Super-Resolution 4 | 5 | [Gang Wu](https://scholar.google.com/citations?user=JSqb7QIAAAAJ), [Junjun Jiang](http://homepage.hit.edu.cn/jiangjunjun), [Kui Jiang](https://github.com/kuijiang94), and [Xianming Liu](http://homepage.hit.edu.cn/xmliu) 6 | 7 | [AIIA Lab](https://aiialabhit.github.io/team/), Harbin Institute of Technology. 8 | 9 | --- 10 | 11 | 12 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-red.svg)](https://arxiv.org/abs/2307.16140) 13 | [![pretrained weights](https://img.shields.io/badge/Models-GoogleDrive-yellow.svg)](https://drive.google.com/drive/folders/1eUqL_8a9DQXZ2uCVyKeWB-6fO1ZdJciG?usp=sharing) 14 | [![pretrained weights](https://img.shields.io/badge/Models-BaiduNetdisk-blue.svg)](https://pan.baidu.com/s/13_syaIXmG3lVnoMgzOS2Ag?pwd=SCSR) 15 | [![visitors](https://hits.sh/github.com/Aitical/SCNet.svg)](https://hits.sh/github.com/Aitical/SCNet/) 16 |
17 | 18 | This repository is the official PyTorch implementation of "Fully 1×1 Convolutional Network for Lightweight Image Super-Resolution". If our work helps your research, please cite it. 19 | ``` 20 | @article{wu2023fully, 21 | title={Fully $1\times1$ Convolutional Network for Lightweight Image Super-Resolution}, 22 | author={Gang Wu and Junjun Jiang and Kui Jiang and Xianming Liu}, 23 | year={2023}, 24 | journal={Machine Intelligence Research}, 25 | doi={10.1007/s11633-024-1401-z}, 26 | } 27 | ``` 28 | >Wu, Gang, Junjun Jiang, Kui Jiang and Xianming Liu. “Fully 1×1 Convolutional Network for Lightweight Image Super-Resolution.” Machine Intelligence Research. 29 | 30 | ## News 31 | 32 | - [x] Updated the implementation code. 33 | 34 | - [x] Uploaded the pre-trained weights used in the manuscript. You can download them from [Google Drive](https://drive.google.com/drive/folders/1eUqL_8a9DQXZ2uCVyKeWB-6fO1ZdJciG?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/13_syaIXmG3lVnoMgzOS2Ag?pwd=SCSR) with password `SCSR`. 35 | 36 | ## Overview 37 | >Deep models have achieved significant progress on single image super-resolution (SISR) tasks, in particular large models with large kernels (3×3 or larger). However, the heavy computational footprint of such models prevents their deployment in real-time, resource-constrained environments. Conversely, 1×1 convolutions bring substantial computational efficiency, but struggle with aggregating local spatial representations, an essential capability for SISR models. In response to this dichotomy, we propose to harmonize the merits of both 3×3 and 1×1 kernels, and exploit their great potential for lightweight SISR tasks. Specifically, we propose a simple yet effective fully 1×1 convolutional network, named Shift-Conv-based Network (SCNet). By incorporating a parameter-free spatial-shift operation, it equips the fully 1×1 convolutional network with powerful representation capability while retaining impressive computational efficiency. Extensive experiments demonstrate that SCNet, despite its fully 1×1 convolutional structure, consistently matches or even surpasses the performance of existing lightweight SR models that employ regular convolutions. 38 | 39 |
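The key idea is that a parameter-free spatial shift lets stacks of 1×1 convolutions aggregate local context without any 3×3 kernels. The snippet below is a minimal, illustrative sketch of that shift-then-1×1 pattern using only four shift directions; it is not the module shipped here — the actual `Shift8` in `SCNet_arch.py` shifts eight channel groups toward the eight neighbouring positions.

```python
import torch
from torch import nn


def spatial_shift4(x: torch.Tensor, stride: int = 1) -> torch.Tensor:
    """Simplified parameter-free shift: each of 4 channel groups is translated
    by `stride` pixels in one direction and the vacated border is zero-filled."""
    b, c, h, w = x.shape
    assert c % 4 == 0, 'channels must split into 4 equal groups'
    g, s = c // 4, stride
    out = torch.zeros_like(x)
    out[:, 0 * g:1 * g, s:, :] = x[:, 0 * g:1 * g, :h - s, :]   # content moves down
    out[:, 1 * g:2 * g, :h - s, :] = x[:, 1 * g:2 * g, s:, :]   # content moves up
    out[:, 2 * g:3 * g, :, s:] = x[:, 2 * g:3 * g, :, :w - s]   # content moves right
    out[:, 3 * g:4 * g, :, :w - s] = x[:, 3 * g:4 * g, :, s:]   # content moves left
    return out


# shift + 1x1 convolutions let each output pixel see its spatial neighbourhood
feat = torch.randn(1, 64, 32, 32)
conv_a = nn.Conv2d(64, 64, kernel_size=1)
conv_b = nn.Conv2d(64, 64, kernel_size=1)
out = conv_b(torch.relu(spatial_shift4(conv_a(feat))))
print(out.shape)  # torch.Size([1, 64, 32, 32])
```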
40 | ![Overview of SCNet](overview_SCNet.png) 41 |
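For a quick sanity check outside of the BasicSR training/testing pipeline described below, here is a minimal inference sketch that mirrors the `__main__` block of `SCNet_arch.py`. It assumes the released checkpoint has been downloaded from the links above and saved as `SCNet-T-x4.pth`, and that the default constructor arguments match the chosen model; otherwise adjust them to the corresponding YAML config under `options/`.

```python
import torch
from SCNet_arch import SCNet  # standalone architecture file in the repository root (needs basicsr importable)

model = SCNet(upscale=4)  # defaults: num_feat=64, num_block=16; adjust to match the YAML config if needed
state = torch.load('SCNet-T-x4.pth', map_location='cpu')  # assumed checkpoint name from the model zoo
model.load_state_dict(state['params'])
model.eval()

lr = torch.rand(1, 3, 64, 64)  # dummy low-resolution RGB input in [0, 1]
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 256, 256]) for x4 upscaling
```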
42 | 43 | 44 | ## Train 45 | 46 | All experiments are evaluated based on [BasicSR](https://github.com/XPixelGroup/BasicSR), and we provide a minimal implementation in `SCNet_arch.py`. 47 | 48 | For training, you may refer to the following script: 49 | ``` 50 | python basicsr/train.py -opt options/train/SCNet/SCNet-T-x4.yml 51 | ``` 52 | And for testing: 53 | ``` 54 | python basicsr/test.py -opt options/test/SCNet/SCNet-T-x4.yml 55 | ``` 56 | ## License 57 | This code is licensed under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only. Please note that any commercial use of this code requires formal permission prior to use. 58 | 59 | ## Acknowledgement 60 | The codes are based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Thanks for their nice sharing. 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /SCNet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from basicsr.archs.arch_util import default_init_weights, make_layer 7 | 8 | 9 | class Shift8(nn.Module): 10 | def __init__(self, groups=4, stride=1, mode='constant') -> None: 11 | super().__init__() 12 | self.g = groups 13 | self.mode = mode 14 | self.stride = stride 15 | 16 | def forward(self, x): 17 | b, c, h, w = x.shape 18 | out = torch.zeros_like(x) 19 | 20 | pad_x = F.pad(x, pad=[self.stride for _ in range(4)], mode=self.mode) 21 | assert c == self.g * 8 22 | 23 | cx, cy = self.stride, self.stride 24 | stride = self.stride 25 | out[:,0*self.g:1*self.g, :, :] = pad_x[:, 0*self.g:1*self.g, cx-stride:cx-stride+h, cy:cy+w] 26 | out[:,1*self.g:2*self.g, :, :] = pad_x[:, 1*self.g:2*self.g, cx+stride:cx+stride+h, cy:cy+w] 27 | out[:,2*self.g:3*self.g, :, :] = pad_x[:, 2*self.g:3*self.g, cx:cx+h, cy-stride:cy-stride+w] 28 | out[:,3*self.g:4*self.g, :, :] = pad_x[:, 3*self.g:4*self.g, cx:cx+h, cy+stride:cy+stride+w] 29 | 30 | out[:,4*self.g:5*self.g, :, :] = pad_x[:, 4*self.g:5*self.g, cx+stride:cx+stride+h, cy+stride:cy+stride+w] 31 | out[:,5*self.g:6*self.g, :, :] = pad_x[:, 5*self.g:6*self.g, cx+stride:cx+stride+h, cy-stride:cy-stride+w] 32 | out[:,6*self.g:7*self.g, :, :] = pad_x[:, 6*self.g:7*self.g, cx-stride:cx-stride+h, cy+stride:cy+stride+w] 33 | out[:,7*self.g:8*self.g, :, :] = pad_x[:, 7*self.g:8*self.g, cx-stride:cx-stride+h, cy-stride:cy-stride+w] 34 | 35 | #out[:, 8*self.g:, :, :] = pad_x[:, 8*self.g:, cx:cx+h, cy:cy+w] 36 | return out 37 | 38 | 39 | class ResidualBlockShift(nn.Module): 40 | """Residual block without BN. 41 | 42 | It has a style of: 43 | ---Conv-Shift-ReLU-Conv-+- 44 | |________________| 45 | 46 | Args: 47 | num_feat (int): Channel number of intermediate features. 48 | Default: 64. 49 | res_scale (float): Residual scale. Default: 1. 50 | pytorch_init (bool): If set to True, use pytorch default init, 51 | otherwise, use default_init_weights. Default: False. 
52 | """ 53 | 54 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 55 | super(ResidualBlockShift, self).__init__() 56 | self.res_scale = res_scale 57 | self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=1) 58 | self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=1) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.shift = Shift8(groups=num_feat//8, stride=1) 61 | 62 | if not pytorch_init: 63 | default_init_weights([self.conv1, self.conv2], 0.1) 64 | 65 | def forward(self, x): 66 | identity = x 67 | out = self.conv2(self.relu(self.shift(self.conv1(x)))) 68 | return identity + out * self.res_scale 69 | 70 | 71 | class UpShiftPixelShuffle(nn.Module): 72 | def __init__(self, dim, scale=2) -> None: 73 | super().__init__() 74 | 75 | self.up_layer = nn.Sequential( 76 | nn.Conv2d(dim, dim, kernel_size=1), 77 | nn.LeakyReLU(0.02), 78 | Shift8(groups=dim//8), 79 | nn.Conv2d(dim, dim*scale*scale, kernel_size=1), 80 | nn.PixelShuffle(upscale_factor=scale) 81 | ) 82 | def forward(self, x): 83 | out = self.up_layer(x) 84 | return out 85 | 86 | class UpShiftMLP(nn.Module): 87 | def __init__(self, dim, mode='bilinear', scale=2) -> None: 88 | super().__init__() 89 | 90 | self.up_layer = nn.Sequential( 91 | nn.Upsample(scale_factor=scale, mode=mode, align_corners=False), 92 | nn.Conv2d(dim, dim, kernel_size=1), 93 | nn.LeakyReLU(0.02), 94 | Shift8(groups=dim//8), 95 | nn.Conv2d(dim, dim, kernel_size=1) 96 | ) 97 | def forward(self, x): 98 | out = self.up_layer(x) 99 | return out 100 | 101 | @ARCH_REGISTRY.register() 102 | class SCNet(nn.Module): 103 | """ SCNet (https://arxiv.org/abs/2307.16140) based on the Modified SRResNet. 104 | Args: 105 | num_in_ch (int): Channel number of inputs. Default: 3. 106 | num_out_ch (int): Channel number of outputs. Default: 3. 107 | num_feat (int): Channel number of intermediate features. Default: 64. 108 | num_block (int): Block number in the body network. Default: 16. 109 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 
110 | """ 111 | 112 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 113 | super(SCNet, self).__init__() 114 | self.upscale = upscale 115 | 116 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 1) 117 | self.body = make_layer(ResidualBlockShift, num_block, num_feat=num_feat) 118 | 119 | # upsampling 120 | if self.upscale in [2, 3]: 121 | self.upconv1 = UpShiftMLP(num_feat, scale=self.upscale) 122 | 123 | elif self.upscale == 4: 124 | self.upconv1 = UpShiftMLP(num_feat) 125 | self.upconv2 = UpShiftMLP(num_feat) 126 | elif self.upscale == 8: 127 | self.upconv1 = UpShiftMLP(num_feat) 128 | self.upconv2 = UpShiftMLP(num_feat) 129 | self.upconv3 = UpShiftMLP(num_feat) 130 | # freeze infrence 131 | self.pixel_shuffle = nn.Identity() 132 | 133 | self.conv_hr = nn.Conv2d(num_feat, num_feat, kernel_size=1) 134 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, kernel_size=1) 135 | 136 | # activation function 137 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 138 | 139 | # initialization 140 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 141 | if self.upscale == 4: 142 | default_init_weights(self.upconv2, 0.1) 143 | 144 | def forward(self, x): 145 | feat = self.lrelu(self.conv_first(x)) 146 | out = self.body(feat) 147 | 148 | if self.upscale == 4: 149 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 150 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 151 | elif self.upscale in [2, 3]: 152 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 153 | elif self.upscale == 8: 154 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 155 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 156 | out = self.lrelu(self.pixel_shuffle(self.upconv3(out))) 157 | 158 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 159 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 160 | out += base 161 | return out 162 | 163 | if __name__ == '__main__': 164 | model = SCNet(upscale=4) 165 | load_dict = torch.load('SCNet-T-x4.pth') 166 | model.load_state_dict(load_dict['params']) 167 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | -------------------------------------------------------------------------------- /basicsr/archs/SCNet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from basicsr.archs.arch_util import default_init_weights, make_layer 7 | 8 | 9 | class Shift8(nn.Module): 10 | def __init__(self, groups=4, stride=1, mode="constant") -> None: 11 | super().__init__() 12 | self.g = groups 13 | self.mode = mode 14 | self.stride = stride 15 | 16 | def forward(self, x): 17 | b, c, h, w = x.shape 18 | out = torch.zeros_like(x) 19 | 20 | pad_x = F.pad(x, pad=[self.stride for _ in range(4)], mode=self.mode) 21 | assert c == self.g * 8 22 | 23 | cx, cy = self.stride, self.stride 24 | stride = self.stride 25 | out[:, 0 * self.g : 1 * self.g, :, 
:] = pad_x[ 26 | :, 0 * self.g : 1 * self.g, cx - stride : cx - stride + h, cy : cy + w 27 | ] 28 | out[:, 1 * self.g : 2 * self.g, :, :] = pad_x[ 29 | :, 1 * self.g : 2 * self.g, cx + stride : cx + stride + h, cy : cy + w 30 | ] 31 | out[:, 2 * self.g : 3 * self.g, :, :] = pad_x[ 32 | :, 2 * self.g : 3 * self.g, cx : cx + h, cy - stride : cy - stride + w 33 | ] 34 | out[:, 3 * self.g : 4 * self.g, :, :] = pad_x[ 35 | :, 3 * self.g : 4 * self.g, cx : cx + h, cy + stride : cy + stride + w 36 | ] 37 | 38 | out[:, 4 * self.g : 5 * self.g, :, :] = pad_x[ 39 | :, 40 | 4 * self.g : 5 * self.g, 41 | cx + stride : cx + stride + h, 42 | cy + stride : cy + stride + w, 43 | ] 44 | out[:, 5 * self.g : 6 * self.g, :, :] = pad_x[ 45 | :, 46 | 5 * self.g : 6 * self.g, 47 | cx + stride : cx + stride + h, 48 | cy - stride : cy - stride + w, 49 | ] 50 | out[:, 6 * self.g : 7 * self.g, :, :] = pad_x[ 51 | :, 52 | 6 * self.g : 7 * self.g, 53 | cx - stride : cx - stride + h, 54 | cy + stride : cy + stride + w, 55 | ] 56 | out[:, 7 * self.g : 8 * self.g, :, :] = pad_x[ 57 | :, 58 | 7 * self.g : 8 * self.g, 59 | cx - stride : cx - stride + h, 60 | cy - stride : cy - stride + w, 61 | ] 62 | 63 | # out[:, 8*self.g:, :, :] = pad_x[:, 8*self.g:, cx:cx+h, cy:cy+w] 64 | return out 65 | 66 | 67 | class ResidualBlockShift(nn.Module): 68 | """Residual block without BN. 69 | 70 | It has a style of: 71 | ---Conv-Shift-ReLU-Conv-+- 72 | |________________| 73 | 74 | Args: 75 | num_feat (int): Channel number of intermediate features. 76 | Default: 64. 77 | res_scale (float): Residual scale. Default: 1. 78 | pytorch_init (bool): If set to True, use pytorch default init, 79 | otherwise, use default_init_weights. Default: False. 80 | """ 81 | 82 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 83 | super(ResidualBlockShift, self).__init__() 84 | self.res_scale = res_scale 85 | self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=1) 86 | self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=1) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.shift = Shift8(groups=num_feat // 8, stride=1) 89 | 90 | if not pytorch_init: 91 | default_init_weights([self.conv1, self.conv2], 0.1) 92 | 93 | def forward(self, x): 94 | identity = x 95 | out = self.conv2(self.relu(self.shift(self.conv1(x)))) 96 | return identity + out * self.res_scale 97 | 98 | 99 | class UpShiftPixelShuffle(nn.Module): 100 | def __init__(self, dim, scale=2) -> None: 101 | super().__init__() 102 | 103 | self.up_layer = nn.Sequential( 104 | nn.Conv2d(dim, dim, kernel_size=1), 105 | nn.LeakyReLU(0.02), 106 | Shift8(groups=dim // 8), 107 | nn.Conv2d(dim, dim * scale * scale, kernel_size=1), 108 | nn.PixelShuffle(upscale_factor=scale), 109 | ) 110 | 111 | def forward(self, x): 112 | out = self.up_layer(x) 113 | return out 114 | 115 | 116 | class UpShiftMLP(nn.Module): 117 | def __init__(self, dim, mode="bilinear", scale=2) -> None: 118 | super().__init__() 119 | 120 | self.up_layer = nn.Sequential( 121 | nn.Upsample(scale_factor=scale, mode=mode, align_corners=False), 122 | nn.Conv2d(dim, dim, kernel_size=1), 123 | nn.LeakyReLU(0.02), 124 | Shift8(groups=dim // 8), 125 | nn.Conv2d(dim, dim, kernel_size=1), 126 | ) 127 | 128 | def forward(self, x): 129 | out = self.up_layer(x) 130 | return out 131 | 132 | 133 | @ARCH_REGISTRY.register() 134 | class SCNet(nn.Module): 135 | """SCNet (https://arxiv.org/abs/2307.16140) based on the Modified SRResNet. 136 | Args: 137 | num_in_ch (int): Channel number of inputs. Default: 3. 
138 | num_out_ch (int): Channel number of outputs. Default: 3. 139 | num_feat (int): Channel number of intermediate features. Default: 64. 140 | num_block (int): Block number in the body network. Default: 16. 141 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 142 | use_pixelshuffle (bool): Upsampling with PixelShuffle operation. 143 | """ 144 | 145 | def __init__( 146 | self, 147 | num_in_ch=3, 148 | num_out_ch=3, 149 | num_feat=64, 150 | num_block=16, 151 | upscale=4, 152 | use_pixelshuffle=False, 153 | ): 154 | super(SCNet, self).__init__() 155 | self.upscale = upscale 156 | 157 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 1) 158 | self.body = make_layer(ResidualBlockShift, num_block, num_feat=num_feat) 159 | 160 | UpLayer = UpShiftPixelShuffle if use_pixelshuffle else UpShiftMLP 161 | # upsampling 162 | if self.upscale in [2, 3]: 163 | self.upconv1 = UpLayer(num_feat, scale=self.upscale) 164 | 165 | elif self.upscale == 4: 166 | self.upconv1 = UpLayer(num_feat) 167 | self.upconv2 = UpLayer(num_feat) 168 | elif self.upscale == 8: 169 | self.upconv1 = UpLayer(num_feat) 170 | self.upconv2 = UpLayer(num_feat) 171 | self.upconv3 = UpLayer(num_feat) 172 | # freeze infrence 173 | self.pixel_shuffle = nn.Identity() 174 | 175 | self.conv_hr = nn.Conv2d(num_feat, num_feat, kernel_size=1) 176 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, kernel_size=1) 177 | 178 | # activation function 179 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 180 | 181 | # initialization 182 | default_init_weights( 183 | [self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1 184 | ) 185 | if self.upscale == 4: 186 | default_init_weights(self.upconv2, 0.1) 187 | 188 | def forward(self, x): 189 | feat = self.lrelu(self.conv_first(x)) 190 | out = self.body(feat) 191 | 192 | if self.upscale == 4: 193 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 194 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 195 | elif self.upscale in [2, 3]: 196 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 197 | elif self.upscale == 8: 198 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 199 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 200 | out = self.lrelu(self.pixel_shuffle(self.upconv3(out))) 201 | 202 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 203 | base = F.interpolate( 204 | x, scale_factor=self.upscale, mode="bilinear", align_corners=False 205 | ) 206 | out += base 207 | return out 208 | 209 | 210 | if __name__ == "__main__": 211 | model = SCNet(upscale=4) 212 | load_dict = torch.load("SCNet-T-x4.pth") 213 | model.load_state_dict(load_dict["params"]) 214 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in 
arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether use normalization 59 | in the input feature and the type of vgg network. Note that the pretrained 60 | path must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 
73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | Tensor: Forward results. 
149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | 155 | output = {} 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 001 100 (720,1280,3) 3 | 002 100 (720,1280,3) 4 | 003 100 (720,1280,3) 5 | 004 100 (720,1280,3) 6 | 005 100 (720,1280,3) 7 | 006 100 (720,1280,3) 8 | 007 100 (720,1280,3) 9 | 008 100 (720,1280,3) 10 | 009 100 (720,1280,3) 11 | 010 100 (720,1280,3) 12 | 011 100 (720,1280,3) 13 | 012 100 (720,1280,3) 14 | 013 100 (720,1280,3) 15 | 014 100 (720,1280,3) 16 | 015 100 (720,1280,3) 17 | 016 100 (720,1280,3) 18 | 017 100 (720,1280,3) 19 | 018 100 (720,1280,3) 20 | 019 100 (720,1280,3) 21 | 020 100 (720,1280,3) 22 | 021 100 (720,1280,3) 23 | 022 100 (720,1280,3) 24 | 023 100 (720,1280,3) 25 | 024 100 (720,1280,3) 26 | 025 100 (720,1280,3) 27 | 026 100 (720,1280,3) 28 | 027 100 (720,1280,3) 29 | 028 100 (720,1280,3) 30 | 029 100 (720,1280,3) 31 | 030 100 (720,1280,3) 32 | 031 100 (720,1280,3) 33 | 032 100 (720,1280,3) 34 | 033 100 (720,1280,3) 35 | 034 100 (720,1280,3) 36 | 035 100 (720,1280,3) 37 | 036 100 (720,1280,3) 38 | 037 100 (720,1280,3) 39 | 038 100 (720,1280,3) 
40 | 039 100 (720,1280,3) 41 | 040 100 (720,1280,3) 42 | 041 100 (720,1280,3) 43 | 042 100 (720,1280,3) 44 | 043 100 (720,1280,3) 45 | 044 100 (720,1280,3) 46 | 045 100 (720,1280,3) 47 | 046 100 (720,1280,3) 48 | 047 100 (720,1280,3) 49 | 048 100 (720,1280,3) 50 | 049 100 (720,1280,3) 51 | 050 100 (720,1280,3) 52 | 051 100 (720,1280,3) 53 | 052 100 (720,1280,3) 54 | 053 100 (720,1280,3) 55 | 054 100 (720,1280,3) 56 | 055 100 (720,1280,3) 57 | 056 100 (720,1280,3) 58 | 057 100 (720,1280,3) 59 | 058 100 (720,1280,3) 60 | 059 100 (720,1280,3) 61 | 060 100 (720,1280,3) 62 | 061 100 (720,1280,3) 63 | 062 100 (720,1280,3) 64 | 063 100 (720,1280,3) 65 | 064 100 (720,1280,3) 66 | 065 100 (720,1280,3) 67 | 066 100 (720,1280,3) 68 | 067 100 (720,1280,3) 69 | 068 100 (720,1280,3) 70 | 069 100 (720,1280,3) 71 | 070 100 (720,1280,3) 72 | 071 100 (720,1280,3) 73 | 072 100 (720,1280,3) 74 | 073 100 (720,1280,3) 75 | 074 100 (720,1280,3) 76 | 075 100 (720,1280,3) 77 | 076 100 (720,1280,3) 78 | 077 100 (720,1280,3) 79 | 078 100 (720,1280,3) 80 | 079 100 (720,1280,3) 81 | 080 100 (720,1280,3) 82 | 081 100 (720,1280,3) 83 | 082 100 (720,1280,3) 84 | 083 100 (720,1280,3) 85 | 084 100 (720,1280,3) 86 | 085 100 (720,1280,3) 87 | 086 100 (720,1280,3) 88 | 087 100 (720,1280,3) 89 | 088 100 (720,1280,3) 90 | 089 100 (720,1280,3) 91 | 090 100 (720,1280,3) 92 | 091 100 (720,1280,3) 93 | 092 100 (720,1280,3) 94 | 093 100 (720,1280,3) 95 | 094 100 (720,1280,3) 96 | 095 100 (720,1280,3) 97 | 096 100 (720,1280,3) 98 | 097 100 (720,1280,3) 99 | 098 100 (720,1280,3) 100 | 099 100 (720,1280,3) 101 | 100 100 (720,1280,3) 102 | 101 100 (720,1280,3) 103 | 102 100 (720,1280,3) 104 | 103 100 (720,1280,3) 105 | 104 100 (720,1280,3) 106 | 105 100 (720,1280,3) 107 | 106 100 (720,1280,3) 108 | 107 100 (720,1280,3) 109 | 108 100 (720,1280,3) 110 | 109 100 (720,1280,3) 111 | 110 100 (720,1280,3) 112 | 111 100 (720,1280,3) 113 | 112 100 (720,1280,3) 114 | 113 100 (720,1280,3) 115 | 114 100 (720,1280,3) 116 | 115 100 (720,1280,3) 117 | 116 100 (720,1280,3) 118 | 117 100 (720,1280,3) 119 | 118 100 (720,1280,3) 120 | 119 100 (720,1280,3) 121 | 120 100 (720,1280,3) 122 | 121 100 (720,1280,3) 123 | 122 100 (720,1280,3) 124 | 123 100 (720,1280,3) 125 | 124 100 (720,1280,3) 126 | 125 100 (720,1280,3) 127 | 126 100 (720,1280,3) 128 | 127 100 (720,1280,3) 129 | 128 100 (720,1280,3) 130 | 129 100 (720,1280,3) 131 | 130 100 (720,1280,3) 132 | 131 100 (720,1280,3) 133 | 132 100 (720,1280,3) 134 | 133 100 (720,1280,3) 135 | 134 100 (720,1280,3) 136 | 135 100 (720,1280,3) 137 | 136 100 (720,1280,3) 138 | 137 100 (720,1280,3) 139 | 138 100 (720,1280,3) 140 | 139 100 (720,1280,3) 141 | 140 100 (720,1280,3) 142 | 141 100 (720,1280,3) 143 | 142 100 (720,1280,3) 144 | 143 100 (720,1280,3) 145 | 144 100 (720,1280,3) 146 | 145 100 (720,1280,3) 147 | 146 100 (720,1280,3) 148 | 147 100 (720,1280,3) 149 | 148 100 (720,1280,3) 150 | 149 100 (720,1280,3) 151 | 150 100 (720,1280,3) 152 | 151 100 (720,1280,3) 153 | 152 100 (720,1280,3) 154 | 153 100 (720,1280,3) 155 | 154 100 (720,1280,3) 156 | 155 100 (720,1280,3) 157 | 156 100 (720,1280,3) 158 | 157 100 (720,1280,3) 159 | 158 100 (720,1280,3) 160 | 159 100 (720,1280,3) 161 | 160 100 (720,1280,3) 162 | 161 100 (720,1280,3) 163 | 162 100 (720,1280,3) 164 | 163 100 (720,1280,3) 165 | 164 100 (720,1280,3) 166 | 165 100 (720,1280,3) 167 | 166 100 (720,1280,3) 168 | 167 100 (720,1280,3) 169 | 168 100 (720,1280,3) 170 | 169 100 (720,1280,3) 171 | 170 100 (720,1280,3) 172 | 171 100 (720,1280,3) 173 | 172 100 
(720,1280,3) 174 | 173 100 (720,1280,3) 175 | 174 100 (720,1280,3) 176 | 175 100 (720,1280,3) 177 | 176 100 (720,1280,3) 178 | 177 100 (720,1280,3) 179 | 178 100 (720,1280,3) 180 | 179 100 (720,1280,3) 181 | 180 100 (720,1280,3) 182 | 181 100 (720,1280,3) 183 | 182 100 (720,1280,3) 184 | 183 100 (720,1280,3) 185 | 184 100 (720,1280,3) 186 | 185 100 (720,1280,3) 187 | 186 100 (720,1280,3) 188 | 187 100 (720,1280,3) 189 | 188 100 (720,1280,3) 190 | 189 100 (720,1280,3) 191 | 190 100 (720,1280,3) 192 | 191 100 (720,1280,3) 193 | 192 100 (720,1280,3) 194 | 193 100 (720,1280,3) 195 | 194 100 (720,1280,3) 196 | 195 100 (720,1280,3) 197 | 196 100 (720,1280,3) 198 | 197 100 (720,1280,3) 199 | 198 100 (720,1280,3) 200 | 199 100 (720,1280,3) 201 | 200 100 (720,1280,3) 202 | 201 100 (720,1280,3) 203 | 202 100 (720,1280,3) 204 | 203 100 (720,1280,3) 205 | 204 100 (720,1280,3) 206 | 205 100 (720,1280,3) 207 | 206 100 (720,1280,3) 208 | 207 100 (720,1280,3) 209 | 208 100 (720,1280,3) 210 | 209 100 (720,1280,3) 211 | 210 100 (720,1280,3) 212 | 211 100 (720,1280,3) 213 | 212 100 (720,1280,3) 214 | 213 100 (720,1280,3) 215 | 214 100 (720,1280,3) 216 | 215 100 (720,1280,3) 217 | 216 100 (720,1280,3) 218 | 217 100 (720,1280,3) 219 | 218 100 (720,1280,3) 220 | 219 100 (720,1280,3) 221 | 220 100 (720,1280,3) 222 | 221 100 (720,1280,3) 223 | 222 100 (720,1280,3) 224 | 223 100 (720,1280,3) 225 | 224 100 (720,1280,3) 226 | 225 100 (720,1280,3) 227 | 226 100 (720,1280,3) 228 | 227 100 (720,1280,3) 229 | 228 100 (720,1280,3) 230 | 229 100 (720,1280,3) 231 | 230 100 (720,1280,3) 232 | 231 100 (720,1280,3) 233 | 232 100 (720,1280,3) 234 | 233 100 (720,1280,3) 235 | 234 100 (720,1280,3) 236 | 235 100 (720,1280,3) 237 | 236 100 (720,1280,3) 238 | 237 100 (720,1280,3) 239 | 238 100 (720,1280,3) 240 | 239 100 (720,1280,3) 241 | 240 100 (720,1280,3) 242 | 241 100 (720,1280,3) 243 | 242 100 (720,1280,3) 244 | 243 100 (720,1280,3) 245 | 244 100 (720,1280,3) 246 | 245 100 (720,1280,3) 247 | 246 100 (720,1280,3) 248 | 247 100 (720,1280,3) 249 | 248 100 (720,1280,3) 250 | 249 100 (720,1280,3) 251 | 250 100 (720,1280,3) 252 | 251 100 (720,1280,3) 253 | 252 100 (720,1280,3) 254 | 253 100 (720,1280,3) 255 | 254 100 (720,1280,3) 256 | 255 100 (720,1280,3) 257 | 256 100 (720,1280,3) 258 | 257 100 (720,1280,3) 259 | 258 100 (720,1280,3) 260 | 259 100 (720,1280,3) 261 | 260 100 (720,1280,3) 262 | 261 100 (720,1280,3) 263 | 262 100 (720,1280,3) 264 | 263 100 (720,1280,3) 265 | 264 100 (720,1280,3) 266 | 265 100 (720,1280,3) 267 | 266 100 (720,1280,3) 268 | 267 100 (720,1280,3) 269 | 268 100 (720,1280,3) 270 | 269 100 (720,1280,3) 271 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 
100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 18 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 19 | 2. **meta_info_file**: Use meta information file to generate paths. \ 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. **folder**: Scan folders to generate paths. The rest. 22 | 23 | Args: 24 | opt (dict): Config for train datasets. It contains the following keys: 25 | dataroot_gt (str): Data root path for gt. 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 29 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 30 | Default: '{}'. 31 | gt_size (int): Cropped patched size for gt patches. 32 | use_hflip (bool): Use horizontal flips. 33 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 34 | scale (bool): Scale, which will be added automatically. 35 | phase (str): 'train' or 'val'. 
36 | """ 37 | 38 | def __init__(self, opt): 39 | super(PairedImageDataset, self).__init__() 40 | self.opt = opt 41 | # file client (io backend) 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | self.mean = opt['mean'] if 'mean' in opt else None 45 | self.std = opt['std'] if 'std' in opt else None 46 | 47 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 48 | if 'filename_tmpl' in opt: 49 | self.filename_tmpl = opt['filename_tmpl'] 50 | else: 51 | self.filename_tmpl = '{}' 52 | 53 | if self.io_backend_opt['type'] == 'lmdb': 54 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 55 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 56 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 57 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 58 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 59 | self.opt['meta_info_file'], self.filename_tmpl) 60 | else: 61 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 62 | 63 | def __getitem__(self, index): 64 | if self.file_client is None: 65 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 66 | 67 | scale = self.opt['scale'] 68 | 69 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 70 | # image range: [0, 1], float32. 71 | gt_path = self.paths[index]['gt_path'] 72 | img_bytes = self.file_client.get(gt_path, 'gt') 73 | img_gt = imfrombytes(img_bytes, float32=True) 74 | lq_path = self.paths[index]['lq_path'] 75 | img_bytes = self.file_client.get(lq_path, 'lq') 76 | img_lq = imfrombytes(img_bytes, float32=True) 77 | 78 | # augmentation for training 79 | if self.opt['phase'] == 'train': 80 | gt_size = self.opt['gt_size'] 81 | # random crop 82 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 83 | # flip, rotation 84 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 85 | 86 | # color space transform 87 | if 'color' in self.opt and self.opt['color'] == 'y': 88 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 89 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 90 | 91 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 92 | # TODO: It is better to update the datasets, rather than force to crop 93 | if self.opt['phase'] != 'train': 94 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 
9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 
92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /basicsr/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 20 | 2. **meta_info_file**: Use meta information file to generate paths. \ 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. **folder**: Scan folders to generate paths. The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | scale (bool): Scale, which will be added automatically. 36 | phase (str): 'train' or 'val'. 
37 | """ 38 | 39 | def __init__(self, opt): 40 | super(RealESRGANPairedDataset, self).__init__() 41 | self.opt = opt 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | # mean and std for normalizing the input images 45 | self.mean = opt['mean'] if 'mean' in opt else None 46 | self.std = opt['std'] if 'std' in opt else None 47 | 48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 49 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 50 | 51 | # file client (lmdb io backend) 52 | if self.io_backend_opt['type'] == 'lmdb': 53 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 54 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 55 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 56 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 57 | # disk backend with meta_info 58 | # Each line in the meta_info describes the relative path to an image 59 | with open(self.opt['meta_info']) as fin: 60 | paths = [line.strip() for line in fin] 61 | self.paths = [] 62 | for path in paths: 63 | gt_path, lq_path = path.split(', ') 64 | gt_path = os.path.join(self.gt_folder, gt_path) 65 | lq_path = os.path.join(self.lq_folder, lq_path) 66 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 67 | else: 68 | # disk backend 69 | # it will scan the whole folder to get meta info 70 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 72 | 73 | def __getitem__(self, index): 74 | if self.file_client is None: 75 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 76 | 77 | scale = self.opt['scale'] 78 | 79 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 80 | # image range: [0, 1], float32. 
81 | gt_path = self.paths[index]['gt_path'] 82 | img_bytes = self.file_client.get(gt_path, 'gt') 83 | img_gt = imfrombytes(img_bytes, float32=True) 84 | lq_path = self.paths[index]['lq_path'] 85 | img_bytes = self.file_client.get(lq_path, 'lq') 86 | img_lq = imfrombytes(img_bytes, float32=True) 87 | 88 | # augmentation for training 89 | if self.opt['phase'] == 'train': 90 | gt_size = self.opt['gt_size'] 91 | # random crop 92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 93 | # flip, rotation 94 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) 107 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 
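A minimal test-time sketch of driving this dataset (the folder path is a placeholder):

```python
from torch.utils.data import DataLoader
from basicsr.data.single_image_dataset import SingleImageDataset

opt = {'dataroot_lq': 'datasets/demo_lq', 'io_backend': {'type': 'disk'}}  # placeholder path
dataset = SingleImageDataset(opt)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for data in loader:
    img_lq, lq_path = data['lq'], data['lq_path'][0]  # CHW float tensor and its source path
```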
25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[:h - h_remainder, :w - w_remainder, ...] 21 | else: 22 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 27 | """Paired random crop. Support Numpy array and Tensor inputs. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. Default: None. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 
45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | # determine input type: Numpy array or Tensor 53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 54 | 55 | if input_type == 'Tensor': 56 | h_lq, w_lq = img_lqs[0].size()[-2:] 57 | h_gt, w_gt = img_gts[0].size()[-2:] 58 | else: 59 | h_lq, w_lq = img_lqs[0].shape[0:2] 60 | h_gt, w_gt = img_gts[0].shape[0:2] 61 | lq_patch_size = gt_patch_size // scale 62 | 63 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 65 | f'multiplication of LQ ({h_lq}, {w_lq}).') 66 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 68 | f'({lq_patch_size}, {lq_patch_size}). ' 69 | f'Please remove {gt_path}.') 70 | 71 | # randomly choose top and left coordinates for lq patch 72 | top = random.randint(0, h_lq - lq_patch_size) 73 | left = random.randint(0, w_lq - lq_patch_size) 74 | 75 | # crop lq patch 76 | if input_type == 'Tensor': 77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 78 | else: 79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 80 | 81 | # crop corresponding gt patch 82 | top_gt, left_gt = int(top * scale), int(left * scale) 83 | if input_type == 'Tensor': 84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 85 | else: 86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 87 | if len(img_gts) == 1: 88 | img_gts = img_gts[0] 89 | if len(img_lqs) == 1: 90 | img_lqs = img_lqs[0] 91 | return img_gts, img_lqs 92 | 93 | 94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 96 | 97 | We use vertical flip and transpose for rotation implementation. 98 | All the images in the list use the same augmentation. 99 | 100 | Args: 101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 102 | is an ndarray, it will be transformed to a list. 103 | hflip (bool): Horizontal flip. Default: True. 104 | rotation (bool): Ratotation. Default: True. 105 | flows (list[ndarray]: Flows to be augmented. If the input is an 106 | ndarray, it will be transformed to a list. 107 | Dimension is (h, w, 2). Default: None. 108 | return_status (bool): Return the status of flip and rotation. 109 | Default: False. 110 | 111 | Returns: 112 | list[ndarray] | ndarray: Augmented images and flows. If returned 113 | results only have one element, just return ndarray. 
114 | 115 | """ 116 | hflip = hflip and random.random() < 0.5 117 | vflip = rotation and random.random() < 0.5 118 | rot90 = rotation and random.random() < 0.5 119 | 120 | def _augment(img): 121 | if hflip: # horizontal 122 | cv2.flip(img, 1, img) 123 | if vflip: # vertical 124 | cv2.flip(img, 0, img) 125 | if rot90: 126 | img = img.transpose(1, 0, 2) 127 | return img 128 | 129 | def _augment_flow(flow): 130 | if hflip: # horizontal 131 | cv2.flip(flow, 1, flow) 132 | flow[:, :, 0] *= -1 133 | if vflip: # vertical 134 | cv2.flip(flow, 0, flow) 135 | flow[:, :, 1] *= -1 136 | if rot90: 137 | flow = flow.transpose(1, 0, 2) 138 | flow = flow[:, :, [1, 0]] 139 | return flow 140 | 141 | if not isinstance(imgs, list): 142 | imgs = [imgs] 143 | imgs = [_augment(img) for img in imgs] 144 | if len(imgs) == 1: 145 | imgs = imgs[0] 146 | 147 | if flows is not None: 148 | if not isinstance(flows, list): 149 | flows = [flows] 150 | flows = [_augment_flow(flow) for flow in flows] 151 | if len(flows) == 1: 152 | flows = flows[0] 153 | return imgs, flows 154 | else: 155 | if return_status: 156 | return imgs, (hflip, vflip, rot90) 157 | else: 158 | return imgs 159 | 160 | 161 | def img_rotate(img, angle, center=None, scale=1.0): 162 | """Rotate image. 163 | 164 | Args: 165 | img (ndarray): Image to be rotated. 166 | angle (float): Rotation angle in degrees. Positive values mean 167 | counter-clockwise rotation. 168 | center (tuple[int]): Rotation center. If the center is None, 169 | initialize it as the center of the image. Default: None. 170 | scale (float): Isotropic scale factor. Default: 1.0. 171 | """ 172 | (h, w) = img.shape[:2] 173 | 174 | if center is None: 175 | center = (w // 2, h // 2) 176 | 177 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 178 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 179 | return rotated_img 180 | -------------------------------------------------------------------------------- /basicsr/data/vimeo90k_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from pathlib import Path 4 | from torch.utils import data as data 5 | 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Vimeo90KDataset(data.Dataset): 13 | """Vimeo90K dataset for training. 14 | 15 | The keys are generated from a meta info txt file. 16 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt 17 | 18 | Each line contains the following items, separated by a white space. 19 | 20 | 1. clip name; 21 | 2. frame number; 22 | 3. image shape 23 | 24 | Examples: 25 | 26 | :: 27 | 28 | 00001/0001 7 (256,448,3) 29 | 00001/0002 7 (256,448,3) 30 | 31 | - Key examples: "00001/0001" 32 | - GT (gt): Ground-Truth; 33 | - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. 34 | 35 | The neighboring frame list for different num_frame: 36 | 37 | :: 38 | 39 | num_frame | frame list 40 | 1 | 4 41 | 3 | 3,4,5 42 | 5 | 2,3,4,5,6 43 | 7 | 1,2,3,4,5,6,7 44 | 45 | Args: 46 | opt (dict): Config for train dataset. It contains the following keys: 47 | dataroot_gt (str): Data root path for gt. 48 | dataroot_lq (str): Data root path for lq. 49 | meta_info_file (str): Path for meta information file. 50 | io_backend (dict): IO backend type and other kwarg. 51 | num_frame (int): Window size for input frames. 
52 | gt_size (int): Cropped patched size for gt patches. 53 | random_reverse (bool): Random reverse input frames. 54 | use_hflip (bool): Use horizontal flips. 55 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 56 | scale (bool): Scale, which will be added automatically. 57 | """ 58 | 59 | def __init__(self, opt): 60 | super(Vimeo90KDataset, self).__init__() 61 | self.opt = opt 62 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq']) 63 | 64 | with open(opt['meta_info_file'], 'r') as fin: 65 | self.keys = [line.split(' ')[0] for line in fin] 66 | 67 | # file client (io backend) 68 | self.file_client = None 69 | self.io_backend_opt = opt['io_backend'] 70 | self.is_lmdb = False 71 | if self.io_backend_opt['type'] == 'lmdb': 72 | self.is_lmdb = True 73 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] 74 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 75 | 76 | # indices of input images 77 | self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])] 78 | 79 | # temporal augmentation configs 80 | self.random_reverse = opt['random_reverse'] 81 | logger = get_root_logger() 82 | logger.info(f'Random reverse is {self.random_reverse}.') 83 | 84 | def __getitem__(self, index): 85 | if self.file_client is None: 86 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 87 | 88 | # random reverse 89 | if self.random_reverse and random.random() < 0.5: 90 | self.neighbor_list.reverse() 91 | 92 | scale = self.opt['scale'] 93 | gt_size = self.opt['gt_size'] 94 | key = self.keys[index] 95 | clip, seq = key.split('/') # key example: 00001/0001 96 | 97 | # get the GT frame (im4.png) 98 | if self.is_lmdb: 99 | img_gt_path = f'{key}/im4' 100 | else: 101 | img_gt_path = self.gt_root / clip / seq / 'im4.png' 102 | img_bytes = self.file_client.get(img_gt_path, 'gt') 103 | img_gt = imfrombytes(img_bytes, float32=True) 104 | 105 | # get the neighboring LQ frames 106 | img_lqs = [] 107 | for neighbor in self.neighbor_list: 108 | if self.is_lmdb: 109 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 110 | else: 111 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 112 | img_bytes = self.file_client.get(img_lq_path, 'lq') 113 | img_lq = imfrombytes(img_bytes, float32=True) 114 | img_lqs.append(img_lq) 115 | 116 | # randomly crop 117 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path) 118 | 119 | # augmentation - flip, rotate 120 | img_lqs.append(img_gt) 121 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) 122 | 123 | img_results = img2tensor(img_results) 124 | img_lqs = torch.stack(img_results[0:-1], dim=0) 125 | img_gt = img_results[-1] 126 | 127 | # img_lqs: (t, c, h, w) 128 | # img_gt: (c, h, w) 129 | # key: str 130 | return {'lq': img_lqs, 'gt': img_gt, 'key': key} 131 | 132 | def __len__(self): 133 | return len(self.keys) 134 | 135 | 136 | @DATASET_REGISTRY.register() 137 | class Vimeo90KRecurrentDataset(Vimeo90KDataset): 138 | 139 | def __init__(self, opt): 140 | super(Vimeo90KRecurrentDataset, self).__init__(opt) 141 | 142 | self.flip_sequence = opt['flip_sequence'] 143 | self.neighbor_list = [1, 2, 3, 4, 5, 6, 7] 144 | 145 | def __getitem__(self, index): 146 | if self.file_client is None: 147 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 148 | 149 | # random reverse 150 | if self.random_reverse and random.random() < 0.5: 151 | 
self.neighbor_list.reverse() 152 | 153 | scale = self.opt['scale'] 154 | gt_size = self.opt['gt_size'] 155 | key = self.keys[index] 156 | clip, seq = key.split('/') # key example: 00001/0001 157 | 158 | # get the neighboring LQ and GT frames 159 | img_lqs = [] 160 | img_gts = [] 161 | for neighbor in self.neighbor_list: 162 | if self.is_lmdb: 163 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 164 | img_gt_path = f'{clip}/{seq}/im{neighbor}' 165 | else: 166 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 167 | img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png' 168 | # LQ 169 | img_bytes = self.file_client.get(img_lq_path, 'lq') 170 | img_lq = imfrombytes(img_bytes, float32=True) 171 | # GT 172 | img_bytes = self.file_client.get(img_gt_path, 'gt') 173 | img_gt = imfrombytes(img_bytes, float32=True) 174 | 175 | img_lqs.append(img_lq) 176 | img_gts.append(img_gt) 177 | 178 | # randomly crop 179 | img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path) 180 | 181 | # augmentation - flip, rotate 182 | img_lqs.extend(img_gts) 183 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) 184 | 185 | img_results = img2tensor(img_results) 186 | img_lqs = torch.stack(img_results[:7], dim=0) 187 | img_gts = torch.stack(img_results[7:], dim=0) 188 | 189 | if self.flip_sequence: # flip the sequence: 7 frames to 14 frames 190 | img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0) 191 | img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0) 192 | 193 | # img_lqs: (t, c, h, w) 194 | # img_gt: (c, h, w) 195 | # key: str 196 | return {'lq': img_lqs, 'gt': img_gts, 'key': key} 197 | 198 | def __len__(self): 199 | return len(self.keys) 200 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 
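For example, a pixel loss can be built from an options dict as sketched below (assuming the standard `L1Loss` registered in `basic_loss.py`):

```python
from basicsr.losses import build_loss

# 'type' selects the class from LOSS_REGISTRY; the remaining keys become its kwargs.
cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
```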
25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import autograd as autograd 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from basicsr.utils.registry import LOSS_REGISTRY 8 | 9 | 10 | @LOSS_REGISTRY.register() 11 | class GANLoss(nn.Module): 12 | """Define GAN loss. 13 | 14 | Args: 15 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. 16 | real_label_val (float): The value for real label. Default: 1.0. 17 | fake_label_val (float): The value for fake label. Default: 0.0. 18 | loss_weight (float): Loss weight. Default: 1.0. 19 | Note that loss_weight is only for generators; and it is always 1.0 20 | for discriminators. 21 | """ 22 | 23 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): 24 | super(GANLoss, self).__init__() 25 | self.gan_type = gan_type 26 | self.loss_weight = loss_weight 27 | self.real_label_val = real_label_val 28 | self.fake_label_val = fake_label_val 29 | 30 | if self.gan_type == 'vanilla': 31 | self.loss = nn.BCEWithLogitsLoss() 32 | elif self.gan_type == 'lsgan': 33 | self.loss = nn.MSELoss() 34 | elif self.gan_type == 'wgan': 35 | self.loss = self._wgan_loss 36 | elif self.gan_type == 'wgan_softplus': 37 | self.loss = self._wgan_softplus_loss 38 | elif self.gan_type == 'hinge': 39 | self.loss = nn.ReLU() 40 | else: 41 | raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') 42 | 43 | def _wgan_loss(self, input, target): 44 | """wgan loss. 45 | 46 | Args: 47 | input (Tensor): Input tensor. 48 | target (bool): Target label. 49 | 50 | Returns: 51 | Tensor: wgan loss. 52 | """ 53 | return -input.mean() if target else input.mean() 54 | 55 | def _wgan_softplus_loss(self, input, target): 56 | """wgan loss with soft plus. softplus is a smooth approximation to the 57 | ReLU function. 58 | 59 | In StyleGAN2, it is called: 60 | Logistic loss for discriminator; 61 | Non-saturating loss for generator. 62 | 63 | Args: 64 | input (Tensor): Input tensor. 65 | target (bool): Target label. 66 | 67 | Returns: 68 | Tensor: wgan loss. 69 | """ 70 | return F.softplus(-input).mean() if target else F.softplus(input).mean() 71 | 72 | def get_target_label(self, input, target_is_real): 73 | """Get target label. 74 | 75 | Args: 76 | input (Tensor): Input tensor. 77 | target_is_real (bool): Whether the target is real or fake. 78 | 79 | Returns: 80 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise, 81 | return Tensor. 82 | """ 83 | 84 | if self.gan_type in ['wgan', 'wgan_softplus']: 85 | return target_is_real 86 | target_val = (self.real_label_val if target_is_real else self.fake_label_val) 87 | return input.new_ones(input.size()) * target_val 88 | 89 | def forward(self, input, target_is_real, is_disc=False): 90 | """ 91 | Args: 92 | input (Tensor): The input for the loss module, i.e., the network 93 | prediction. 94 | target_is_real (bool): Whether the targe is real or fake. 95 | is_disc (bool): Whether the loss for discriminators or not. 96 | Default: False. 97 | 98 | Returns: 99 | Tensor: GAN loss value. 
100 | """ 101 | target_label = self.get_target_label(input, target_is_real) 102 | if self.gan_type == 'hinge': 103 | if is_disc: # for discriminators in hinge-gan 104 | input = -input if target_is_real else input 105 | loss = self.loss(1 + input).mean() 106 | else: # for generators in hinge-gan 107 | loss = -input.mean() 108 | else: # other gan types 109 | loss = self.loss(input, target_label) 110 | 111 | # loss_weight is always 1.0 for discriminators 112 | return loss if is_disc else loss * self.loss_weight 113 | 114 | 115 | @LOSS_REGISTRY.register() 116 | class MultiScaleGANLoss(GANLoss): 117 | """ 118 | MultiScaleGANLoss accepts a list of predictions 119 | """ 120 | 121 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): 122 | super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) 123 | 124 | def forward(self, input, target_is_real, is_disc=False): 125 | """ 126 | The input is a list of tensors, or a list of (a list of tensors) 127 | """ 128 | if isinstance(input, list): 129 | loss = 0 130 | for pred_i in input: 131 | if isinstance(pred_i, list): 132 | # Only compute GAN loss for the last layer 133 | # in case of multiscale feature matching 134 | pred_i = pred_i[-1] 135 | # Safe operation: 0-dim tensor calling self.mean() does nothing 136 | loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() 137 | loss += loss_tensor 138 | return loss / len(input) 139 | else: 140 | return super().forward(input, target_is_real, is_disc) 141 | 142 | 143 | def r1_penalty(real_pred, real_img): 144 | """R1 regularization for discriminator. The core idea is to 145 | penalize the gradient on real data alone: when the 146 | generator distribution produces the true data distribution 147 | and the discriminator is equal to 0 on the data manifold, the 148 | gradient penalty ensures that the discriminator cannot create 149 | a non-zero gradient orthogonal to the data manifold without 150 | suffering a loss in the GAN game. 151 | 152 | Reference: Eq. 9 in Which training methods for GANs do actually converge. 153 | """ 154 | grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] 155 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 156 | return grad_penalty 157 | 158 | 159 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 160 | noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) 161 | grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] 162 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 163 | 164 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 165 | 166 | path_penalty = (path_lengths - path_mean).pow(2).mean() 167 | 168 | return path_penalty, path_lengths.detach().mean(), path_mean.detach() 169 | 170 | 171 | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): 172 | """Calculate gradient penalty for wgan-gp. 173 | 174 | Args: 175 | discriminator (nn.Module): Network for the discriminator. 176 | real_data (Tensor): Real input data. 177 | fake_data (Tensor): Fake input data. 178 | weight (Tensor): Weight tensor. Default: None. 179 | 180 | Returns: 181 | Tensor: A tensor for gradient penalty. 
182 | """ 183 | 184 | batch_size = real_data.size(0) 185 | alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) 186 | 187 | # interpolate between real_data and fake_data 188 | interpolates = alpha * real_data + (1. - alpha) * fake_data 189 | interpolates = autograd.Variable(interpolates, requires_grad=True) 190 | 191 | disc_interpolates = discriminator(interpolates) 192 | gradients = autograd.grad( 193 | outputs=disc_interpolates, 194 | inputs=interpolates, 195 | grad_outputs=torch.ones_like(disc_interpolates), 196 | create_graph=True, 197 | retain_graph=True, 198 | only_inputs=True)[0] 199 | 200 | if weight is not None: 201 | gradients = gradients * weight 202 | 203 | gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() 204 | if weight is not None: 205 | gradients_penalty /= torch.mean(weight) 206 | 207 | return gradients_penalty 208 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 
67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 130 | 131 | Returns: 132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 133 | (calculated based on both local and global observations). 134 | """ 135 | 136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) 137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 138 | 139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) 140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) 141 | overall_weight = patch_level_weight * pixel_level_weight 142 | 143 | overall_weight[residual_sr < residual_ema] = 0 144 | 145 | return overall_weight 146 | -------------------------------------------------------------------------------- /basicsr/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 
一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/README_CN.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 
30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Model type. 16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 
52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: 54 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 55 | Stable version by Dougal J. Sutherland. 56 | 57 | Args: 58 | mu1 (np.array): The sample mean over activations. 59 | sigma1 (np.array): The covariance matrix over activations for generated samples. 60 | mu2 (np.array): The sample mean over activations, precalculated on an representative data set. 61 | sigma2 (np.array): The covariance matrix over activations, precalculated on an representative data set. 62 | 63 | Returns: 64 | float: The Frechet Distance. 65 | """ 66 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 67 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 68 | 69 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 70 | 71 | # Product might be almost singular 72 | if not np.isfinite(cov_sqrt).all(): 73 | print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') 74 | offset = np.eye(sigma1.shape[0]) * eps 75 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 76 | 77 | # Numerical error might give slight imaginary component 78 | if np.iscomplexobj(cov_sqrt): 79 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 80 | m = np.max(np.abs(cov_sqrt.imag)) 81 | raise ValueError(f'Imaginary component {m}') 82 | cov_sqrt = cov_sqrt.real 83 | 84 | mean_diff = mu1 - mu2 85 | mean_norm = mean_diff @ mean_diff 86 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 87 | fid = mean_norm + trace 88 | 89 | return fid 90 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 
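These helpers are the layout and colour-space normalisation used by the PSNR/SSIM routines; a quick shape check with random data:

```python
import numpy as np
from basicsr.metrics.metric_util import reorder_image, to_y_channel

img_chw = np.random.randint(0, 256, (3, 64, 64)).astype(np.float32)
img_hwc = reorder_image(img_chw, input_order='CHW')  # -> (64, 64, 3)
img_y = to_y_channel(img_hwc)                        # -> (64, 64, 1), float, still in [0, 255]
```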
46 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aitical/SCNet/c0f8678f2f50e1f97e00c3e018e904a273f0f39a/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/metrics/test_metrics/test_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from basicsr.metrics import calculate_psnr, calculate_ssim 5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt 6 | from basicsr.utils import img2tensor 7 | 8 | 9 | def test(img_path, img_path2, crop_border, test_y_channel=False): 10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) 12 | 13 | # --------------------- Numpy --------------------- 14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 17 | 18 | # --------------------- PyTorch (CPU) --------------------- 19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 21 | 22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 25 | 26 | # --------------------- PyTorch (GPU) --------------------- 27 | img = img.cuda() 28 | img2 = img2.cuda() 29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 32 | 33 | psnr_pth = calculate_psnr_pt( 34 | torch.repeat_interleave(img, 2, dim=0), 35 | torch.repeat_interleave(img2, 2, dim=0), 36 | crop_border=crop_border, 37 | test_y_channel=test_y_channel) 38 | ssim_pth = calculate_ssim_pt( 39 | torch.repeat_interleave(img, 2, dim=0), 40 | torch.repeat_interleave(img2, 2, dim=0), 41 | crop_border=crop_border, 42 | test_y_channel=test_y_channel) 43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' 44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') 45 | 46 | 47 | if __name__ == '__main__': 48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) 49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) 50 | 51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) 52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) 53 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from 
basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 
59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aitical/SCNet/c0f8678f2f50e1f97e00c3e018e904a273f0f39a/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/src/deform_conv_ext.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define WITH_CUDA // always use cuda 11 | #ifdef WITH_CUDA 12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 13 | at::Tensor offset, at::Tensor output, 14 | at::Tensor columns, at::Tensor ones, int kW, 15 | int kH, int dW, int dH, int padW, int padH, 16 | int dilationW, int dilationH, int group, 17 | int deformable_group, int im2col_step); 18 | 19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 20 | at::Tensor gradOutput, at::Tensor gradInput, 21 | at::Tensor gradOffset, at::Tensor weight, 22 | at::Tensor columns, int kW, int kH, int dW, 23 | int dH, int padW, int padH, int dilationW, 24 | int dilationH, int group, 25 | int deformable_group, int im2col_step); 26 | 27 | int 
deform_conv_backward_parameters_cuda( 28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 29 | at::Tensor gradWeight, // at::Tensor gradBias, 30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 31 | int padW, int padH, int dilationW, int dilationH, int group, 32 | int deformable_group, float scale, int im2col_step); 33 | 34 | void modulated_deform_conv_cuda_forward( 35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 38 | const int pad_h, const int pad_w, const int dilation_h, 39 | const int dilation_w, const int group, const int deformable_group, 40 | const bool with_bias); 41 | 42 | void modulated_deform_conv_cuda_backward( 43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 44 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 48 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 49 | const bool with_bias); 50 | #endif 51 | 52 | int deform_conv_forward(at::Tensor input, at::Tensor weight, 53 | at::Tensor offset, at::Tensor output, 54 | at::Tensor columns, at::Tensor ones, int kW, 55 | int kH, int dW, int dH, int padW, int padH, 56 | int dilationW, int dilationH, int group, 57 | int deformable_group, int im2col_step) { 58 | if (input.device().is_cuda()) { 59 | #ifdef WITH_CUDA 60 | return deform_conv_forward_cuda(input, weight, offset, output, columns, 61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, 62 | deformable_group, im2col_step); 63 | #else 64 | AT_ERROR("deform conv is not compiled with GPU support"); 65 | #endif 66 | } 67 | AT_ERROR("deform conv is not implemented on CPU"); 68 | } 69 | 70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset, 71 | at::Tensor gradOutput, at::Tensor gradInput, 72 | at::Tensor gradOffset, at::Tensor weight, 73 | at::Tensor columns, int kW, int kH, int dW, 74 | int dH, int padW, int padH, int dilationW, 75 | int dilationH, int group, 76 | int deformable_group, int im2col_step) { 77 | if (input.device().is_cuda()) { 78 | #ifdef WITH_CUDA 79 | return deform_conv_backward_input_cuda(input, offset, gradOutput, 80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, 81 | dilationW, dilationH, group, deformable_group, im2col_step); 82 | #else 83 | AT_ERROR("deform conv is not compiled with GPU support"); 84 | #endif 85 | } 86 | AT_ERROR("deform conv is not implemented on CPU"); 87 | } 88 | 89 | int deform_conv_backward_parameters( 90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 91 | at::Tensor gradWeight, // at::Tensor gradBias, 92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 93 | int padW, int padH, int dilationW, int dilationH, int group, 94 | int deformable_group, float scale, int im2col_step) { 95 | if (input.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput, 98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, 99 | dilationH, group, deformable_group, scale, im2col_step); 100 | #else 101 | AT_ERROR("deform conv is not compiled with GPU support"); 102 | #endif 103 | } 104 | AT_ERROR("deform 
conv is not implemented on CPU"); 105 | } 106 | 107 | void modulated_deform_conv_forward( 108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 111 | const int pad_h, const int pad_w, const int dilation_h, 112 | const int dilation_w, const int group, const int deformable_group, 113 | const bool with_bias) { 114 | if (input.device().is_cuda()) { 115 | #ifdef WITH_CUDA 116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones, 117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h, 118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 119 | deformable_group, with_bias); 120 | #else 121 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 122 | #endif 123 | } 124 | AT_ERROR("modulated deform conv is not implemented on CPU"); 125 | } 126 | 127 | void modulated_deform_conv_backward( 128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 129 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 131 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 134 | const bool with_bias) { 135 | if (input.device().is_cuda()) { 136 | #ifdef WITH_CUDA 137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones, 138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, 139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, 140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, 141 | with_bias); 142 | #else 143 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 144 | #endif 145 | } 146 | AT_ERROR("modulated deform conv is not implemented on CPU"); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("deform_conv_forward", &deform_conv_forward, 152 | "deform forward"); 153 | m.def("deform_conv_backward_input", &deform_conv_backward_input, 154 | "deform_conv_backward_input"); 155 | m.def("deform_conv_backward_parameters", 156 | &deform_conv_backward_parameters, 157 | "deform_conv_backward_parameters"); 158 | m.def("modulated_deform_conv_forward", 159 | &modulated_deform_conv_forward, 160 | "modulated deform conv forward"); 161 | m.def("modulated_deform_conv_backward", 162 | &modulated_deform_conv_backward, 163 | "modulated deform conv backward"); 164 | } 165 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | 
module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include <torch/extension.h> 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 
19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include <torch/types.h> 9 | 10 | #include <ATen/ATen.h> 11 | #include <ATen/AccumulateType.h> 12 | #include <ATen/cuda/CUDAContext.h> 13 | #include <ATen/cuda/CUDAApplyUtils.cuh> 14 | 15 | #include <cuda.h> 16 | #include <cuda_runtime.h> 17 | 18 | 19 | template <typename scalar_t> 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 82 | y.data_ptr<scalar_t>(), 83 | x.data_ptr<scalar_t>(), 84 | b.data_ptr<scalar_t>(), 85 | ref.data_ptr<scalar_t>(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include <torch/extension.h> 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_ext = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import upfirdn2d_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class UpFirDn2dBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 34 | 35 | up_x, up_y = up 36 | down_x, down_y = down 37 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 38 | 39 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 40 | 41 | grad_input = upfirdn2d_ext.upfirdn2d( 42 | grad_output, 43 | grad_kernel, 44 | down_x, 45 | down_y, 46 | up_x, 47 | up_y, 48 | g_pad_x0, 49 | g_pad_x1, 50 | g_pad_y0, 51 | g_pad_y1, 52 | ) 53 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 54 | 55 | ctx.save_for_backward(kernel) 56 | 57 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 58 | 59 | ctx.up_x = up_x 60 | ctx.up_y = up_y 61 | ctx.down_x = down_x 62 | ctx.down_y = down_y 63 | ctx.pad_x0 = pad_x0 64 | ctx.pad_x1 = pad_x1 65 | ctx.pad_y0 = pad_y0 66 | ctx.pad_y1 = pad_y1 67 | ctx.in_size = in_size 68 | ctx.out_size = out_size 69 | 70 | return grad_input 71 | 72 | @staticmethod 73 | def backward(ctx, gradgrad_input): 74 | kernel, = ctx.saved_tensors 75 | 76 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 77 | 78 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 79 | gradgrad_input, 80 | kernel, 81 | ctx.up_x, 82 | ctx.up_y, 83 | ctx.down_x, 84 | ctx.down_y, 85 | ctx.pad_x0, 86 | ctx.pad_x1, 87 | ctx.pad_y0, 88 | ctx.pad_y1, 89 | ) 90 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 91 | # ctx.out_size[1], ctx.in_size[3]) 92 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 93 | 94 | return gradgrad_out, None, None, None, None, None, None, None, None 95 | 96 | 97 | class UpFirDn2d(Function): 98 | 99 | @staticmethod 100 | def forward(ctx, input, kernel, up, down, pad): 101 | up_x, up_y = up 102 | down_x, down_y = down 103 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 104 | 105 | kernel_h, kernel_w = kernel.shape 106 | _, channel, in_h, in_w = input.shape 107 | ctx.in_size = input.shape 108 | 109 | input = input.reshape(-1, in_h, in_w, 1) 110 | 111 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 112 | 113 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 114 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 115 | ctx.out_size = (out_h, out_w) 116 | 117 | ctx.up = (up_x, up_y) 118 | ctx.down = (down_x, down_y) 119 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 120 | 121 | g_pad_x0 = kernel_w - pad_x0 - 1 122 | g_pad_y0 = kernel_h - pad_y0 - 1 123 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 124 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 125 | 126 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 127 | 128 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 129 | # out = out.view(major, out_h, out_w, minor) 130 | out = out.view(-1, channel, out_h, out_w) 131 | 132 | return out 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | kernel, grad_kernel = ctx.saved_tensors 137 | 138 | grad_input = UpFirDn2dBackward.apply( 139 | grad_output, 140 | kernel, 141 | grad_kernel, 142 | ctx.up, 143 | ctx.down, 144 | ctx.pad, 145 | ctx.g_pad, 146 | ctx.in_size, 147 | ctx.out_size, 148 | ) 149 | 150 | return grad_input, None, None, None, None 151 | 152 | 153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 154 | if input.device.type == 'cpu': 155 | out = upfirdn2d_native(input, kernel, 
up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 156 | else: 157 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 158 | 159 | return out 160 | 161 | 162 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 174 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 175 | 176 | out = out.permute(0, 3, 1, 2) 177 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 178 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 179 | out = F.conv2d(out, w) 180 | out = out.reshape( 181 | -1, 182 | minor, 183 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 184 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 185 | ) 186 | out = out.permute(0, 2, 3, 1) 187 | out = out[:, ::down_y, ::down_x, :] 188 | 189 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 190 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 191 | 192 | return out.view(-1, channel, out_h, out_w) 193 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set random seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | 
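`basicsr/test.py` above is the evaluation entry point. Below is a minimal usage sketch, not part of the repository: it assumes the standard BasicSR convention that `parse_options` reads the config path from an `-opt` command-line argument, and it uses one of the SCNet test configs listed under `options/test/SCNet` in the repository tree.

```python
# Minimal sketch (assumptions noted above): run the test pipeline on an SCNet config.
# Equivalent to:  python basicsr/test.py -opt options/test/SCNet/SCNet-T-x4.yml
import sys
from os import path as osp

from basicsr.test import test_pipeline

if __name__ == '__main__':
    # parse_options reads sys.argv, so inject the config path here (assumed `-opt` flag).
    sys.argv += ['-opt', 'options/test/SCNet/SCNet-T-x4.yml']
    # test_pipeline expects the repository root, mirroring the __main__ block of basicsr/test.py;
    # this sketch assumes it is placed at the repository root.
    test_pipeline(osp.abspath(osp.dirname(__file__)))
```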
-------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /basicsr/utils/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def rgb2ycbcr(img, y_only=False): 6 | """Convert a RGB image to YCbCr image. 7 | 8 | This function produces the same results as Matlab's `rgb2ycbcr` function. 9 | It implements the ITU-R BT.601 conversion for standard-definition 10 | television. See more details in 11 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 12 | 13 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. 14 | In OpenCV, it implements a JPEG conversion. See more details in 15 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 16 | 17 | Args: 18 | img (ndarray): The input image. It accepts: 19 | 1. np.uint8 type with range [0, 255]; 20 | 2. np.float32 type with range [0, 1]. 21 | y_only (bool): Whether to only return Y channel. Default: False. 22 | 23 | Returns: 24 | ndarray: The converted YCbCr image. The output image has the same type 25 | and range as input image. 26 | """ 27 | img_type = img.dtype 28 | img = _convert_input_type_range(img) 29 | if y_only: 30 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 31 | else: 32 | out_img = np.matmul( 33 | img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] 34 | out_img = _convert_output_type_range(out_img, img_type) 35 | return out_img 36 | 37 | 38 | def bgr2ycbcr(img, y_only=False): 39 | """Convert a BGR image to YCbCr image. 40 | 41 | The bgr version of rgb2ycbcr. 42 | It implements the ITU-R BT.601 conversion for standard-definition 43 | television. See more details in 44 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 45 | 46 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 47 | In OpenCV, it implements a JPEG conversion. See more details in 48 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 
49 | 50 | Args: 51 | img (ndarray): The input image. It accepts: 52 | 1. np.uint8 type with range [0, 255]; 53 | 2. np.float32 type with range [0, 1]. 54 | y_only (bool): Whether to only return Y channel. Default: False. 55 | 56 | Returns: 57 | ndarray: The converted YCbCr image. The output image has the same type 58 | and range as input image. 59 | """ 60 | img_type = img.dtype 61 | img = _convert_input_type_range(img) 62 | if y_only: 63 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 64 | else: 65 | out_img = np.matmul( 66 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] 67 | out_img = _convert_output_type_range(out_img, img_type) 68 | return out_img 69 | 70 | 71 | def ycbcr2rgb(img): 72 | """Convert a YCbCr image to RGB image. 73 | 74 | This function produces the same results as Matlab's ycbcr2rgb function. 75 | It implements the ITU-R BT.601 conversion for standard-definition 76 | television. See more details in 77 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 78 | 79 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. 80 | In OpenCV, it implements a JPEG conversion. See more details in 81 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 82 | 83 | Args: 84 | img (ndarray): The input image. It accepts: 85 | 1. np.uint8 type with range [0, 255]; 86 | 2. np.float32 type with range [0, 1]. 87 | 88 | Returns: 89 | ndarray: The converted RGB image. The output image has the same type 90 | and range as input image. 91 | """ 92 | img_type = img.dtype 93 | img = _convert_input_type_range(img) * 255 94 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 95 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 96 | out_img = _convert_output_type_range(out_img, img_type) 97 | return out_img 98 | 99 | 100 | def ycbcr2bgr(img): 101 | """Convert a YCbCr image to BGR image. 102 | 103 | The bgr version of ycbcr2rgb. 104 | It implements the ITU-R BT.601 conversion for standard-definition 105 | television. See more details in 106 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 107 | 108 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. 109 | In OpenCV, it implements a JPEG conversion. See more details in 110 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 111 | 112 | Args: 113 | img (ndarray): The input image. It accepts: 114 | 1. np.uint8 type with range [0, 255]; 115 | 2. np.float32 type with range [0, 1]. 116 | 117 | Returns: 118 | ndarray: The converted BGR image. The output image has the same type 119 | and range as input image. 120 | """ 121 | img_type = img.dtype 122 | img = _convert_input_type_range(img) * 255 123 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], 124 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 125 | out_img = _convert_output_type_range(out_img, img_type) 126 | return out_img 127 | 128 | 129 | def _convert_input_type_range(img): 130 | """Convert the type and range of the input image. 131 | 132 | It converts the input image to np.float32 type and range of [0, 1]. 133 | It is mainly used for pre-processing the input image in colorspace 134 | conversion functions such as rgb2ycbcr and ycbcr2rgb. 135 | 136 | Args: 137 | img (ndarray): The input image. It accepts: 138 | 1. np.uint8 type with range [0, 255]; 139 | 2. np.float32 type with range [0, 1]. 
140 | 141 | Returns: 142 | (ndarray): The converted image with type of np.float32 and range of 143 | [0, 1]. 144 | """ 145 | img_type = img.dtype 146 | img = img.astype(np.float32) 147 | if img_type == np.float32: 148 | pass 149 | elif img_type == np.uint8: 150 | img /= 255. 151 | else: 152 | raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') 153 | return img 154 | 155 | 156 | def _convert_output_type_range(img, dst_type): 157 | """Convert the type and range of the image according to dst_type. 158 | 159 | It converts the image to desired type and range. If `dst_type` is np.uint8, 160 | images will be converted to np.uint8 type with range [0, 255]. If 161 | `dst_type` is np.float32, it converts the image to np.float32 type with 162 | range [0, 1]. 163 | It is mainly used for post-processing images in colorspace conversion 164 | functions such as rgb2ycbcr and ycbcr2rgb. 165 | 166 | Args: 167 | img (ndarray): The image to be converted with np.float32 type and 168 | range [0, 255]. 169 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 170 | converts the image to np.uint8 type with range [0, 255]. If 171 | dst_type is np.float32, it converts the image to np.float32 type 172 | with range [0, 1]. 173 | 174 | Returns: 175 | (ndarray): The converted image with desired type and range. 176 | """ 177 | if dst_type not in (np.uint8, np.float32): 178 | raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') 179 | if dst_type == np.uint8: 180 | img = img.round() 181 | else: 182 | img /= 255. 183 | return img.astype(dst_type) 184 | 185 | 186 | def rgb2ycbcr_pt(img, y_only=False): 187 | """Convert RGB images to YCbCr images (PyTorch version). 188 | 189 | It implements the ITU-R BT.601 conversion for standard-definition television. See more details in 190 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 191 | 192 | Args: 193 | img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. 194 | y_only (bool): Whether to only return Y channel. Default: False. 195 | 196 | Returns: 197 | (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. 198 | """ 199 | if y_only: 200 | weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) 201 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 202 | else: 203 | weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) 204 | bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) 205 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 206 | 207 | out_img = out_img / 255. 
208 | return out_img 209 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 
13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 
7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 
92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 
14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 
109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. 
Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | 22 | :: 23 | 24 | example.lmdb 25 | ├── data.mdb 26 | ├── lock.mdb 27 | ├── meta_info.txt 28 | 29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 30 | https://lmdb.readthedocs.io/en/release/ for more details. 31 | 32 | The meta_info.txt is a specified txt file to record the meta information 33 | of our datasets. It will be automatically created when preparing 34 | datasets by our provided dataset tools. 35 | Each line in the txt file records 1)image name (with extension), 36 | 2)image shape, and 3)compression level, separated by a white space. 37 | 38 | For example, the meta information could be: 39 | `000_00000000.png (720,1280,3) 1`, which means: 40 | 1) image name (with extension): 000_00000000.png; 41 | 2) image shape: (720,1280,3); 42 | 3) compression level: 1 43 | 44 | We use the image name without extension as the lmdb key. 45 | 46 | If `multiprocessing_read` is True, it will read all the images to memory 47 | using multiprocessing. Thus, your server needs to have enough memory. 48 | 49 | Args: 50 | data_path (str): Data path for reading images. 51 | lmdb_path (str): Lmdb save path. 52 | img_path_list (str): Image path list. 53 | keys (str): Used for lmdb keys. 54 | batch (int): After processing batch images, lmdb commits. 55 | Default: 5000. 56 | compress_level (int): Compress level when encoding images. Default: 1. 57 | multiprocessing_read (bool): Whether use multiprocessing to read all 58 | the images to memory. Default: False. 59 | n_thread (int): For multiprocessing. 60 | map_size (int | None): Map size for lmdb env. If None, use the 61 | estimated size from images. 
Default: None 62 | """ 63 | 64 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 65 | f'but got {len(img_path_list)} and {len(keys)}') 66 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 67 | print(f'Totoal images: {len(img_path_list)}') 68 | if not lmdb_path.endswith('.lmdb'): 69 | raise ValueError("lmdb_path must end with '.lmdb'.") 70 | if osp.exists(lmdb_path): 71 | print(f'Folder {lmdb_path} already exists. Exit.') 72 | sys.exit(1) 73 | 74 | if multiprocessing_read: 75 | # read all the images to memory (multiprocessing) 76 | dataset = {} # use dict to keep the order for multiprocessing 77 | shapes = {} 78 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 79 | pbar = tqdm(total=len(img_path_list), unit='image') 80 | 81 | def callback(arg): 82 | """get the image data and update pbar.""" 83 | key, dataset[key], shapes[key] = arg 84 | pbar.update(1) 85 | pbar.set_description(f'Read {key}') 86 | 87 | pool = Pool(n_thread) 88 | for path, key in zip(img_path_list, keys): 89 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 90 | pool.close() 91 | pool.join() 92 | pbar.close() 93 | print(f'Finish reading {len(img_path_list)} images.') 94 | 95 | # create lmdb environment 96 | if map_size is None: 97 | # obtain data size for one image 98 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 99 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 100 | data_size_per_img = img_byte.nbytes 101 | print('Data size per image is: ', data_size_per_img) 102 | data_size = data_size_per_img * len(img_path_list) 103 | map_size = data_size * 10 104 | 105 | env = lmdb.open(lmdb_path, map_size=map_size) 106 | 107 | # write data to lmdb 108 | pbar = tqdm(total=len(img_path_list), unit='chunk') 109 | txn = env.begin(write=True) 110 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 111 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 112 | pbar.update(1) 113 | pbar.set_description(f'Write {key}') 114 | key_byte = key.encode('ascii') 115 | if multiprocessing_read: 116 | img_byte = dataset[key] 117 | h, w, c = shapes[key] 118 | else: 119 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 120 | h, w, c = img_shape 121 | 122 | txn.put(key_byte, img_byte) 123 | # write meta information 124 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 125 | if idx % batch == 0: 126 | txn.commit() 127 | txn = env.begin(write=True) 128 | pbar.close() 129 | txn.commit() 130 | env.close() 131 | txt_file.close() 132 | print('\nFinish writing lmdb.') 133 | 134 | 135 | def read_img_worker(path, key, compress_level): 136 | """Read image worker. 137 | 138 | Args: 139 | path (str): Image path. 140 | key (str): Image key. 141 | compress_level (int): Compress level when encoding images. 142 | 143 | Returns: 144 | str: Image key. 145 | byte: Image byte. 146 | tuple[int]: Image shape. 147 | """ 148 | 149 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 150 | if img.ndim == 2: 151 | h, w = img.shape 152 | c = 1 153 | else: 154 | h, w, c = img.shape 155 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 156 | return (key, img_byte, (h, w, c)) 157 | 158 | 159 | class LmdbMaker(): 160 | """LMDB Maker. 161 | 162 | Args: 163 | lmdb_path (str): Lmdb save path. 164 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 
165 | batch (int): After processing batch images, lmdb commits. 166 | Default: 5000. 167 | compress_level (int): Compress level when encoding images. Default: 1. 168 | """ 169 | 170 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 171 | if not lmdb_path.endswith('.lmdb'): 172 | raise ValueError("lmdb_path must end with '.lmdb'.") 173 | if osp.exists(lmdb_path): 174 | print(f'Folder {lmdb_path} already exists. Exit.') 175 | sys.exit(1) 176 | 177 | self.lmdb_path = lmdb_path 178 | self.batch = batch 179 | self.compress_level = compress_level 180 | self.env = lmdb.open(lmdb_path, map_size=map_size) 181 | self.txn = self.env.begin(write=True) 182 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 183 | self.counter = 0 184 | 185 | def put(self, img_byte, key, img_shape): 186 | self.counter += 1 187 | key_byte = key.encode('ascii') 188 | self.txn.put(key_byte, img_byte) 189 | # write meta information 190 | h, w, c = img_shape 191 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 192 | if self.counter % self.batch == 0: 193 | self.txn.commit() 194 | self.txn = self.env.begin(write=True) 195 | 196 | def close(self): 197 | self.txn.commit() 198 | self.env.close() 199 | self.txt_file.close() 200 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class AvgTimer(): 11 | 12 | def __init__(self, window=200): 13 | self.window = window # average window 14 | self.current_time = 0 15 | self.total_time = 0 16 | self.count = 0 17 | self.avg_time = 0 18 | self.start() 19 | 20 | def start(self): 21 | self.start_time = self.tic = time.time() 22 | 23 | def record(self): 24 | self.count += 1 25 | self.toc = time.time() 26 | self.current_time = self.toc - self.tic 27 | self.total_time += self.current_time 28 | # calculate average time 29 | self.avg_time = self.total_time / self.count 30 | 31 | # reset 32 | if self.count > self.window: 33 | self.count = 0 34 | self.total_time = 0 35 | 36 | self.tic = time.time() 37 | 38 | def get_current_time(self): 39 | return self.current_time 40 | 41 | def get_avg_time(self): 42 | return self.avg_time 43 | 44 | 45 | class MessageLogger(): 46 | """Message logger for printing. 47 | 48 | Args: 49 | opt (dict): Config. It contains the following keys: 50 | name (str): Exp name. 51 | logger (dict): Contains 'print_freq' (str) for logger interval. 52 | train (dict): Contains 'total_iter' (int) for total iters. 53 | use_tb_logger (bool): Use tensorboard logger. 54 | start_iter (int): Start iter. Default: 1. 55 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 56 | """ 57 | 58 | def __init__(self, opt, start_iter=1, tb_logger=None): 59 | self.exp_name = opt['name'] 60 | self.interval = opt['logger']['print_freq'] 61 | self.start_iter = start_iter 62 | self.max_iters = opt['train']['total_iter'] 63 | self.use_tb_logger = opt['logger']['use_tb_logger'] 64 | self.tb_logger = tb_logger 65 | self.start_time = time.time() 66 | self.logger = get_root_logger() 67 | 68 | def reset_start_time(self): 69 | self.start_time = time.time() 70 | 71 | @master_only 72 | def __call__(self, log_vars): 73 | """Format logging message. 
74 | 75 | Args: 76 | log_vars (dict): It contains the following keys: 77 | epoch (int): Epoch number. 78 | iter (int): Current iter. 79 | lrs (list): List for learning rates. 80 | 81 | time (float): Iter time. 82 | data_time (float): Data time for each iter. 83 | """ 84 | # epoch, iter, learning rates 85 | epoch = log_vars.pop('epoch') 86 | current_iter = log_vars.pop('iter') 87 | lrs = log_vars.pop('lrs') 88 | 89 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') 90 | for v in lrs: 91 | message += f'{v:.3e},' 92 | message += ')] ' 93 | 94 | # time and estimated time 95 | if 'time' in log_vars.keys(): 96 | iter_time = log_vars.pop('time') 97 | data_time = log_vars.pop('data_time') 98 | 99 | total_time = time.time() - self.start_time 100 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 101 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 102 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 103 | message += f'[eta: {eta_str}, ' 104 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 105 | 106 | # other items, especially losses 107 | for k, v in log_vars.items(): 108 | message += f'{k}: {v:.4e} ' 109 | # tensorboard logger 110 | if self.use_tb_logger and 'debug' not in self.exp_name: 111 | if k.startswith('l_'): 112 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 113 | else: 114 | self.tb_logger.add_scalar(k, v, current_iter) 115 | self.logger.info(message) 116 | 117 | 118 | @master_only 119 | def init_tb_logger(log_dir): 120 | from torch.utils.tensorboard import SummaryWriter 121 | tb_logger = SummaryWriter(log_dir=log_dir) 122 | return tb_logger 123 | 124 | 125 | @master_only 126 | def init_wandb_logger(opt): 127 | """We now only use wandb to sync tensorboard log.""" 128 | import wandb 129 | logger = get_root_logger() 130 | 131 | project = opt['logger']['wandb']['project'] 132 | resume_id = opt['logger']['wandb'].get('resume_id') 133 | if resume_id: 134 | wandb_id = resume_id 135 | resume = 'allow' 136 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 137 | else: 138 | wandb_id = wandb.util.generate_id() 139 | resume = 'never' 140 | 141 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 142 | 143 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 144 | 145 | 146 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 147 | """Get the root logger. 148 | 149 | The logger will be initialized if it has not been initialized. By default a 150 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 151 | also be added. 152 | 153 | Args: 154 | logger_name (str): root logger name. Default: 'basicsr'. 155 | log_file (str | None): The log filename. If specified, a FileHandler 156 | will be added to the root logger. 157 | log_level (int): The root logger level. Note that only the process of 158 | rank 0 is affected, while other processes will set the level to 159 | "Error" and be silent most of the time. 160 | 161 | Returns: 162 | logging.Logger: The root logger. 
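    Example:
        An illustrative call (not from the original docstring; the log path
        below is only a placeholder):

        >>> logger = get_root_logger(log_file='experiments/demo/train.log')
        >>> logger.info('Logger initialized.')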
163 | """ 164 | logger = logging.getLogger(logger_name) 165 | # if the logger has been initialized, just return it 166 | if logger_name in initialized_logger: 167 | return logger 168 | 169 | format_str = '%(asctime)s %(levelname)s: %(message)s' 170 | stream_handler = logging.StreamHandler() 171 | stream_handler.setFormatter(logging.Formatter(format_str)) 172 | logger.addHandler(stream_handler) 173 | logger.propagate = False 174 | rank, _ = get_dist_info() 175 | if rank != 0: 176 | logger.setLevel('ERROR') 177 | elif log_file is not None: 178 | logger.setLevel(log_level) 179 | # add file handler 180 | file_handler = logging.FileHandler(log_file, 'w') 181 | file_handler.setFormatter(logging.Formatter(format_str)) 182 | file_handler.setLevel(log_level) 183 | logger.addHandler(file_handler) 184 | initialized_logger[logger_name] = True 185 | return logger 186 | 187 | 188 | def get_env_info(): 189 | """Get environment information. 190 | 191 | Currently, only log the software version. 192 | """ 193 | import torch 194 | import torchvision 195 | 196 | from basicsr.version import __version__ 197 | msg = r""" 198 | ____ _ _____ ____ 199 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 200 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 201 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 202 | /_____/ \__,_//____//_/ \___//____//_/ |_| 203 | ______ __ __ __ __ 204 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 205 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 206 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 207 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 208 | """ 209 | msg += ('\nVersion Information: ' 210 | f'\n\tBasicSR: {__version__}' 211 | f'\n\tPyTorch: {torch.__version__}' 212 | f'\n\tTorchVision: {torchvision.__version__}') 213 | return msg 214 | -------------------------------------------------------------------------------- /basicsr/utils/matlab_functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def cubic(x): 7 | """cubic function used for calculate_weights_indices.""" 8 | absx = torch.abs(x) 9 | absx2 = absx**2 10 | absx3 = absx**3 11 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ( 12 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * 13 | (absx <= 2)).type_as(absx)) 14 | 15 | 16 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 17 | """Calculate weights and indices, used for imresize function. 18 | 19 | Args: 20 | in_length (int): Input length. 21 | out_length (int): Output length. 22 | scale (float): Scale factor. 23 | kernel_width (int): Kernel width. 24 | antialisaing (bool): Whether to apply anti-aliasing when downsampling. 25 | """ 26 | 27 | if (scale < 1) and antialiasing: 28 | # Use a modified kernel (larger kernel width) to simultaneously 29 | # interpolate and antialias 30 | kernel_width = kernel_width / scale 31 | 32 | # Output-space coordinates 33 | x = torch.linspace(1, out_length, out_length) 34 | 35 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 36 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 37 | # space maps to 1.5 in input space. 38 | u = x / scale + 0.5 * (1 - 1 / scale) 39 | 40 | # What is the left-most pixel that can be involved in the computation? 
41 | left = torch.floor(u - kernel_width / 2) 42 | 43 | # What is the maximum number of pixels that can be involved in the 44 | # computation? Note: it's OK to use an extra pixel here; if the 45 | # corresponding weights are all zero, it will be eliminated at the end 46 | # of this function. 47 | p = math.ceil(kernel_width) + 2 48 | 49 | # The indices of the input pixels involved in computing the k-th output 50 | # pixel are in row k of the indices matrix. 51 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 52 | out_length, p) 53 | 54 | # The weights used to compute the k-th output pixel are in row k of the 55 | # weights matrix. 56 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 57 | 58 | # apply cubic kernel 59 | if (scale < 1) and antialiasing: 60 | weights = scale * cubic(distance_to_center * scale) 61 | else: 62 | weights = cubic(distance_to_center) 63 | 64 | # Normalize the weights matrix so that each row sums to 1. 65 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 66 | weights = weights / weights_sum.expand(out_length, p) 67 | 68 | # If a column in weights is all zero, get rid of it. only consider the 69 | # first and last column. 70 | weights_zero_tmp = torch.sum((weights == 0), 0) 71 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 72 | indices = indices.narrow(1, 1, p - 2) 73 | weights = weights.narrow(1, 1, p - 2) 74 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 75 | indices = indices.narrow(1, 0, p - 2) 76 | weights = weights.narrow(1, 0, p - 2) 77 | weights = weights.contiguous() 78 | indices = indices.contiguous() 79 | sym_len_s = -indices.min() + 1 80 | sym_len_e = indices.max() - in_length 81 | indices = indices + sym_len_s - 1 82 | return weights, indices, int(sym_len_s), int(sym_len_e) 83 | 84 | 85 | @torch.no_grad() 86 | def imresize(img, scale, antialiasing=True): 87 | """imresize function same as MATLAB. 88 | 89 | It now only supports bicubic. 90 | The same scale applies for both height and width. 91 | 92 | Args: 93 | img (Tensor | Numpy array): 94 | Tensor: Input image with shape (c, h, w), [0, 1] range. 95 | Numpy: Input image with shape (h, w, c), [0, 1] range. 96 | scale (float): Scale factor. The same scale applies for both height 97 | and width. 98 | antialisaing (bool): Whether to apply anti-aliasing when downsampling. 99 | Default: True. 100 | 101 | Returns: 102 | Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. 
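    Example:
        An illustrative call (not from the original docstring; the random
        array merely stands in for a real image in the [0, 1] range):

        >>> import numpy as np
        >>> img = np.random.rand(64, 64, 3).astype(np.float32)
        >>> out = imresize(img, 0.25)  # MATLAB-style bicubic, scale 0.25
        >>> out.shape
        (16, 16, 3)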
103 | """ 104 | squeeze_flag = False 105 | if type(img).__module__ == np.__name__: # numpy type 106 | numpy_type = True 107 | if img.ndim == 2: 108 | img = img[:, :, None] 109 | squeeze_flag = True 110 | img = torch.from_numpy(img.transpose(2, 0, 1)).float() 111 | else: 112 | numpy_type = False 113 | if img.ndim == 2: 114 | img = img.unsqueeze(0) 115 | squeeze_flag = True 116 | 117 | in_c, in_h, in_w = img.size() 118 | out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) 119 | kernel_width = 4 120 | kernel = 'cubic' 121 | 122 | # get weights and indices 123 | weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, 124 | antialiasing) 125 | weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, 126 | antialiasing) 127 | # process H dimension 128 | # symmetric copying 129 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 130 | img_aug.narrow(1, sym_len_hs, in_h).copy_(img) 131 | 132 | sym_patch = img[:, :sym_len_hs, :] 133 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 134 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 135 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 136 | 137 | sym_patch = img[:, -sym_len_he:, :] 138 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 139 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 140 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 141 | 142 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 143 | kernel_width = weights_h.size(1) 144 | for i in range(out_h): 145 | idx = int(indices_h[i][0]) 146 | for j in range(in_c): 147 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 148 | 149 | # process W dimension 150 | # symmetric copying 151 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 152 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 153 | 154 | sym_patch = out_1[:, :, :sym_len_ws] 155 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 156 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 157 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 158 | 159 | sym_patch = out_1[:, :, -sym_len_we:] 160 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 161 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 162 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 163 | 164 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 165 | kernel_width = weights_w.size(1) 166 | for i in range(out_w): 167 | idx = int(indices_w[i][0]) 168 | for j in range(in_c): 169 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 170 | 171 | if squeeze_flag: 172 | out_2 = out_2.squeeze(0) 173 | if numpy_type: 174 | out_2 = out_2.numpy() 175 | if not squeeze_flag: 176 | out_2 = out_2.transpose(1, 2, 0) 177 | 178 | return out_2 179 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def 
get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 
100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import torch 5 | import yaml 6 | from collections import OrderedDict 7 | from os import path as osp 8 | 9 | from basicsr.utils import set_random_seed 10 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 11 | 12 | 13 | def ordered_yaml(): 14 | """Support OrderedDict for yaml. 15 | 16 | Returns: 17 | tuple: yaml Loader and Dumper. 18 | """ 19 | try: 20 | from yaml import CDumper as Dumper 21 | from yaml import CLoader as Loader 22 | except ImportError: 23 | from yaml import Dumper, Loader 24 | 25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 26 | 27 | def dict_representer(dumper, data): 28 | return dumper.represent_dict(data.items()) 29 | 30 | def dict_constructor(loader, node): 31 | return OrderedDict(loader.construct_pairs(node)) 32 | 33 | Dumper.add_representer(OrderedDict, dict_representer) 34 | Loader.add_constructor(_mapping_tag, dict_constructor) 35 | return Loader, Dumper 36 | 37 | 38 | def yaml_load(f): 39 | """Load yaml file or string. 40 | 41 | Args: 42 | f (str): File path or a python string. 43 | 44 | Returns: 45 | dict: Loaded dict. 46 | """ 47 | if os.path.isfile(f): 48 | with open(f, 'r') as f: 49 | return yaml.load(f, Loader=ordered_yaml()[0]) 50 | else: 51 | return yaml.load(f, Loader=ordered_yaml()[0]) 52 | 53 | 54 | def dict2str(opt, indent_level=1): 55 | """dict to string for printing options. 56 | 57 | Args: 58 | opt (dict): Option dict. 59 | indent_level (int): Indent level. Default: 1. 60 | 61 | Return: 62 | (str): Option string for printing. 
63 | """ 64 | msg = '\n' 65 | for k, v in opt.items(): 66 | if isinstance(v, dict): 67 | msg += ' ' * (indent_level * 2) + k + ':[' 68 | msg += dict2str(v, indent_level + 1) 69 | msg += ' ' * (indent_level * 2) + ']\n' 70 | else: 71 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 72 | return msg 73 | 74 | 75 | def _postprocess_yml_value(value): 76 | # None 77 | if value == '~' or value.lower() == 'none': 78 | return None 79 | # bool 80 | if value.lower() == 'true': 81 | return True 82 | elif value.lower() == 'false': 83 | return False 84 | # !!float number 85 | if value.startswith('!!float'): 86 | return float(value.replace('!!float', '')) 87 | # number 88 | if value.isdigit(): 89 | return int(value) 90 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 91 | return float(value) 92 | # list 93 | if value.startswith('['): 94 | return eval(value) 95 | # str 96 | return value 97 | 98 | 99 | def parse_options(root_path, is_train=True): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 102 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 103 | parser.add_argument('--auto_resume', action='store_true') 104 | parser.add_argument('--debug', action='store_true') 105 | parser.add_argument('--local_rank', type=int, default=0) 106 | parser.add_argument( 107 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') 108 | args = parser.parse_args() 109 | 110 | # parse yml to dict 111 | opt = yaml_load(args.opt) 112 | 113 | # distributed settings 114 | if args.launcher == 'none': 115 | opt['dist'] = False 116 | print('Disable distributed.', flush=True) 117 | else: 118 | opt['dist'] = True 119 | if args.launcher == 'slurm' and 'dist_params' in opt: 120 | init_dist(args.launcher, **opt['dist_params']) 121 | else: 122 | init_dist(args.launcher) 123 | opt['rank'], opt['world_size'] = get_dist_info() 124 | 125 | # random seed 126 | seed = opt.get('manual_seed') 127 | if seed is None: 128 | seed = random.randint(1, 10000) 129 | opt['manual_seed'] = seed 130 | set_random_seed(seed + opt['rank']) 131 | 132 | # force to update yml options 133 | if args.force_yml is not None: 134 | for entry in args.force_yml: 135 | # now do not support creating new keys 136 | keys, value = entry.split('=') 137 | keys, value = keys.strip(), value.strip() 138 | value = _postprocess_yml_value(value) 139 | eval_str = 'opt' 140 | for key in keys.split(':'): 141 | eval_str += f'["{key}"]' 142 | eval_str += '=value' 143 | # using exec function 144 | exec(eval_str) 145 | 146 | opt['auto_resume'] = args.auto_resume 147 | opt['is_train'] = is_train 148 | 149 | # debug setting 150 | if args.debug and not opt['name'].startswith('debug'): 151 | opt['name'] = 'debug_' + opt['name'] 152 | 153 | if opt['num_gpu'] == 'auto': 154 | opt['num_gpu'] = torch.cuda.device_count() 155 | 156 | # datasets 157 | for phase, dataset in opt['datasets'].items(): 158 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 159 | phase = phase.split('_')[0] 160 | dataset['phase'] = phase 161 | if 'scale' in opt: 162 | dataset['scale'] = opt['scale'] 163 | if dataset.get('dataroot_gt') is not None: 164 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 165 | if dataset.get('dataroot_lq') is not None: 166 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 167 | 168 | # paths 169 | for key, val in 
opt['path'].items(): 170 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 171 | opt['path'][key] = osp.expanduser(val) 172 | 173 | if is_train: 174 | experiments_root = opt['path'].get('experiments_root') 175 | if experiments_root is None: 176 | experiments_root = osp.join(root_path, 'experiments') 177 | experiments_root = osp.join(experiments_root, opt['name']) 178 | 179 | opt['path']['experiments_root'] = experiments_root 180 | opt['path']['models'] = osp.join(experiments_root, 'models') 181 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 182 | opt['path']['log'] = experiments_root 183 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 184 | 185 | # change some options for debug mode 186 | if 'debug' in opt['name']: 187 | if 'val' in opt: 188 | opt['val']['val_freq'] = 8 189 | opt['logger']['print_freq'] = 1 190 | opt['logger']['save_checkpoint_freq'] = 8 191 | else: # test 192 | results_root = opt['path'].get('results_root') 193 | if results_root is None: 194 | results_root = osp.join(root_path, 'results') 195 | results_root = osp.join(results_root, opt['name']) 196 | 197 | opt['path']['results_root'] = results_root 198 | opt['path']['log'] = results_root 199 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 200 | 201 | return opt, args 202 | 203 | 204 | @master_only 205 | def copy_opt_file(opt_file, experiments_root): 206 | # copy the yml file to the experiment root 207 | import sys 208 | import time 209 | from shutil import copyfile 210 | cmd = ' '.join(sys.argv) 211 | filename = osp.join(experiments_root, osp.basename(opt_file)) 212 | copyfile(opt_file, filename) 213 | 214 | with open(filename, 'r+') as f: 215 | lines = f.readlines() 216 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 217 | f.seek(0) 218 | f.writelines(lines) 219 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 
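    Example:
        An illustrative call. The log-line format and the regular expression
        below are assumptions for demonstration, not a format guaranteed by
        this repository:

        >>> # matches lines such as "iter: 5000, psnr: 28.1234"
        >>> steps, values = read_data_from_txt_2v(
        ...     'train.log', r'.*iter:\s*(\d+).*psnr:\s*([\d.]+)')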
30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /options/test/SCNet/SCNet-T-x4-PS.yml: -------------------------------------------------------------------------------- 1 | # Modified SRResNet w/o BN from: 2 | # Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network 3 | 4 | # general settings 5 | name: SCNet-T-x4_D64B16_PS 6 | model_type: SRModel 7 | scale: 4 8 | num_gpu: 1 # set num_gpu: 0 for cpu mode 9 | manual_seed: 0 10 | 11 | # dataset and data loader settings 12 | datasets: 13 | test_1: # the 1st test dataset 14 | name: Set5 15 | type: PairedImageDataset 16 | dataroot_gt: datasets/benchmark/Set5/HR 17 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 18 | filename_tmpl: '{}x4' 19 | io_backend: 20 | type: disk 21 | test_2: # the 2nd test dataset 22 | name: Set14 23 | type: PairedImageDataset 24 | dataroot_gt: datasets/benchmark/Set14/HR 25 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 26 | filename_tmpl: '{}x4' 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | name: B100 32 | type: PairedImageDataset 33 | dataroot_gt: datasets/benchmark/B100/HR 34 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 35 | filename_tmpl: '{}x4' 36 | io_backend: 37 | type: disk 38 | test_4: 39 | name: Urban100 40 | type: PairedImageDataset 41 | dataroot_gt: datasets/benchmark/Urban100/HR 42 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 43 | filename_tmpl: '{}x4' 44 | io_backend: 45 | type: disk 46 | test_5: 47 | name: Manga109 48 | type: PairedImageDataset 49 | dataroot_gt: datasets/benchmark/Manga109/HR 50 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 51 | filename_tmpl: '{}x4' 52 | io_backend: 53 | type: disk 54 | 55 | # network structures 56 | network_g: 57 | type: SCNet 58 | num_in_ch: 3 59 | num_out_ch: 3 60 | num_feat: 64 61 | num_block: 16 62 | upscale: 4 63 | use_pixelshuffle: true 64 | 65 | # Upsampling with pixelshuffle operation brings better performance 66 | # Ablations can be found at Section 4.3 and Table 7 67 | 68 | # path 69 | path: 70 | pretrain_network_g: model_zoo/SCNet/SCNet-T_x4_D64B16_PS.pth 71 | strict_load_g: true 72 | resume_state: ~ 73 | 74 | # validation settings 75 | val: 76 | save_img: true 77 | suffix: ~ # add suffix to saved images, if None, use exp name 78 | 79 | metrics: 80 | psnr: # metric name, can be arbitrary 81 | type: calculate_psnr 82 | crop_border: 4 83 | test_y_channel: true 84 | ssim: 85 | type: 
calculate_ssim 86 | crop_border: 4 87 | test_y_channel: true 88 | -------------------------------------------------------------------------------- /options/test/SCNet/SCNet-T-x4.yml: -------------------------------------------------------------------------------- 1 | # Modified SRResNet w/o BN from: 2 | # Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network 3 | 4 | # general settings 5 | name: SCNet-T-x4_D64B16 6 | model_type: SRModel 7 | scale: 4 8 | num_gpu: 1 # set num_gpu: 0 for cpu mode 9 | manual_seed: 0 10 | 11 | # dataset and data loader settings 12 | datasets: 13 | test_1: # the 1st test dataset 14 | name: Set5 15 | type: PairedImageDataset 16 | dataroot_gt: datasets/benchmark/Set5/HR 17 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 18 | filename_tmpl: '{}x4' 19 | io_backend: 20 | type: disk 21 | test_2: # the 2nd test dataset 22 | name: Set14 23 | type: PairedImageDataset 24 | dataroot_gt: datasets/benchmark/Set14/HR 25 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 26 | filename_tmpl: '{}x4' 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | name: B100 32 | type: PairedImageDataset 33 | dataroot_gt: datasets/benchmark/B100/HR 34 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 35 | filename_tmpl: '{}x4' 36 | io_backend: 37 | type: disk 38 | test_4: 39 | name: Urban100 40 | type: PairedImageDataset 41 | dataroot_gt: datasets/benchmark/Urban100/HR 42 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 43 | filename_tmpl: '{}x4' 44 | io_backend: 45 | type: disk 46 | test_5: 47 | name: Manga109 48 | type: PairedImageDataset 49 | dataroot_gt: datasets/benchmark/Manga109/HR 50 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 51 | filename_tmpl: '{}x4' 52 | io_backend: 53 | type: disk 54 | 55 | # network structures 56 | network_g: 57 | type: SCNet 58 | num_in_ch: 3 59 | num_out_ch: 3 60 | num_feat: 64 61 | num_block: 16 62 | upscale: 4 63 | # path 64 | path: 65 | pretrain_network_g: model_zoo/SCNet/SCNet-T-x4.pth 66 | strict_load_g: true 67 | resume_state: ~ 68 | 69 | # validation settings 70 | val: 71 | save_img: true 72 | suffix: ~ # add suffix to saved images, if None, use exp name 73 | 74 | metrics: 75 | psnr: # metric name, can be arbitrary 76 | type: calculate_psnr 77 | crop_border: 4 78 | test_y_channel: true 79 | ssim: 80 | type: calculate_ssim 81 | crop_border: 4 82 | test_y_channel: true 83 | -------------------------------------------------------------------------------- /options/train/SCNet/SCNet-T-x4.yml: -------------------------------------------------------------------------------- 1 | # Modified from SRResNet w/o BN config in BasicSR: 2 | 3 | # general settings 4 | name: SCNet-T-x4 5 | model_type: SRModel 6 | scale: 4 7 | num_gpu: 2 # set num_gpu: 0 for cpu mode 8 | manual_seed: 0 9 | 10 | # dataset and data loader settings 11 | datasets: 12 | train: 13 | name: DF2K 14 | type: PairedImageDataset 15 | dataroot_gt: Path to your data 16 | dataroot_lq: Path to your data 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: lmdb 20 | 21 | gt_size: 256 22 | use_hflip: true 23 | use_rot: true 24 | 25 | # data loader 26 | use_shuffle: true 27 | num_worker_per_gpu: 4 28 | batch_size_per_gpu: 8 29 | dataset_enlarge_ratio: 1 30 | pin_memory: true 31 | 32 | val: 33 | name: Set14 34 | type: PairedImageDataset 35 | dataroot_gt: datasets/benchmark/Set14/HR 36 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 37 | filename_tmpl: '{}x4' 38 | io_backend: 39 | type: disk 40 | 41 | # network structures 42 | network_g: 
43 | type: SCNet 44 | num_in_ch: 3 45 | num_out_ch: 3 46 | num_feat: 64 47 | num_block: 16 48 | upscale: 4 49 | 50 | # path 51 | path: 52 | pretrain_network_g: ~ 53 | strict_load_g: false 54 | resume_state: ~ 55 | # training settings 56 | train: 57 | ema_decay: 0.999 58 | optim_g: 59 | type: Adam 60 | lr: !!float 2e-4 61 | weight_decay: 0 62 | betas: [0.9, 0.99] 63 | 64 | scheduler: 65 | type: MultiStepLR 66 | milestones: [200000] 67 | gamma: 0.5 68 | 69 | total_iter: 300000 70 | warmup_iter: -1 # no warm up 71 | 72 | # losses 73 | pixel_opt: 74 | type: L1Loss 75 | loss_weight: 1.0 76 | reduction: mean 77 | 78 | # validation settings 79 | val: 80 | val_freq: !!float 5e3 81 | save_img: false 82 | 83 | metrics: 84 | psnr: # metric name, can be arbitrary 85 | type: calculate_psnr 86 | crop_border: 4 87 | test_y_channel: true 88 | ssim: # metric name, can be arbitrary 89 | type: calculate_ssim 90 | crop_border: 4 91 | test_y_channel: true 92 | # logging settings 93 | logger: 94 | print_freq: 200 95 | save_checkpoint_freq: !!float 5e3 96 | use_tb_logger: true 97 | wandb: 98 | project: ~ 99 | resume_id: ~ 100 | 101 | # dist training settings 102 | dist_params: 103 | backend: nccl 104 | port: 29500 105 | --------------------------------------------------------------------------------
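As a quick sanity check of the option files above, the YAML configs can be loaded with the helpers from `basicsr/utils/options.py`. The snippet below is a minimal illustrative sketch (not part of the repository); it assumes it is run from the repository root with the `basicsr` package importable.

```python
# Load one of the shipped SCNet test configs and inspect the generator settings.
from basicsr.utils.options import dict2str, yaml_load

opt = yaml_load('options/test/SCNet/SCNet-T-x4.yml')
print(opt['network_g']['type'])    # expected: SCNet
print(dict2str(opt['network_g']))  # pretty-print the network_g sub-dict
```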