├── README.md
├── SCNet_arch.py
├── basicsr
│   ├── __init__.py
│   ├── archs
│   │   ├── SCNet_arch.py
│   │   ├── __init__.py
│   │   ├── arch_util.py
│   │   └── vgg_arch.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── degradations.py
│   │   ├── ffhq_dataset.py
│   │   ├── meta_info
│   │   │   ├── meta_info_DIV2K800sub_GT.txt
│   │   │   ├── meta_info_REDS4_test_GT.txt
│   │   │   ├── meta_info_REDS_GT.txt
│   │   │   ├── meta_info_REDSofficial4_test_GT.txt
│   │   │   ├── meta_info_REDSval_official_test_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_fast_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_medium_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_slow_GT.txt
│   │   │   └── meta_info_Vimeo90K_train_GT.txt
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── realesrgan_dataset.py
│   │   ├── realesrgan_paired_dataset.py
│   │   ├── reds_dataset.py
│   │   ├── single_image_dataset.py
│   │   ├── transforms.py
│   │   ├── video_test_dataset.py
│   │   └── vimeo90k_dataset.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── basic_loss.py
│   │   ├── gan_loss.py
│   │   └── loss_util.py
│   ├── metrics
│   │   ├── README.md
│   │   ├── README_CN.md
│   │   ├── __init__.py
│   │   ├── fid.py
│   │   ├── metric_util.py
│   │   ├── niqe.py
│   │   ├── niqe_pris_params.npz
│   │   ├── psnr_ssim.py
│   │   └── test_metrics
│   │       └── test_psnr_ssim.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   └── sr_model.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── dcn
│   │   │   ├── __init__.py
│   │   │   ├── deform_conv.py
│   │   │   └── src
│   │   │       ├── deform_conv_cuda.cpp
│   │   │       ├── deform_conv_cuda_kernel.cu
│   │   │       └── deform_conv_ext.cpp
│   │   ├── fused_act
│   │   │   ├── __init__.py
│   │   │   ├── fused_act.py
│   │   │   └── src
│   │   │       ├── fused_bias_act.cpp
│   │   │       └── fused_bias_act_kernel.cu
│   │   └── upfirdn2d
│   │       ├── __init__.py
│   │       ├── src
│   │       │   ├── upfirdn2d.cpp
│   │       │   └── upfirdn2d_kernel.cu
│   │       └── upfirdn2d.py
│   ├── test.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── color_util.py
│       ├── diffjpeg.py
│       ├── dist_util.py
│       ├── download_util.py
│       ├── file_client.py
│       ├── flow_util.py
│       ├── img_process_util.py
│       ├── img_util.py
│       ├── lmdb_util.py
│       ├── logger.py
│       ├── matlab_functions.py
│       ├── misc.py
│       ├── options.py
│       ├── plot_util.py
│       └── registry.py
└── options
    ├── test
    │   └── SCNet
    │       ├── SCNet-T-x4-PS.yml
    │       └── SCNet-T-x4.yml
    └── train
        └── SCNet
            └── SCNet-T-x4.yml
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Fully 1x1 Convolutional Network for Lightweight Image Super-Resolution
4 |
5 | [Gang Wu](https://scholar.google.com/citations?user=JSqb7QIAAAAJ), [Junjun Jiang](http://homepage.hit.edu.cn/jiangjunjun), [Kui Jiang](https://github.com/kuijiang94), and [Xianming Liu](http://homepage.hit.edu.cn/xmliu)
6 |
7 | [AIIA Lab](https://aiialabhit.github.io/team/), Harbin Institute of Technology.
8 |
9 | ---
10 |
11 |
12 | [arXiv](https://arxiv.org/abs/2307.16140)
13 | [Pre-trained weights (Google Drive)](https://drive.google.com/drive/folders/1eUqL_8a9DQXZ2uCVyKeWB-6fO1ZdJciG?usp=sharing)
14 | [Pre-trained weights (Baidu Netdisk)](https://pan.baidu.com/s/13_syaIXmG3lVnoMgzOS2Ag?pwd=SCSR)
15 | [Hits](https://hits.sh/github.com/Aitical/SCNet/)
16 |
17 |
18 | This repository is the official PyTorch implementation of "Fully 1×1 Convolutional Network for Lightweight Image Super-Resolution". If this work is helpful to your research, please cite it.
19 | ```
20 | @article{wu2023fully,
21 | title={Fully $1\times1$ Convolutional Network for Lightweight Image Super-Resolution},
22 | author={Gang Wu and Junjun Jiang and Kui Jiang and Xianming Liu},
23 | year={2023},
24 | journal={Machine Intelligence Research},
25 | doi={10.1007/s11633-024-1401-z},
26 | }
27 | ```
28 | >Wu, Gang, Junjun Jiang, Kui Jiang and Xianming Liu. “Fully 1×1 Convolutional Network for Lightweight Image Super-Resolution.” Machine Intelligence Research.
29 |
30 | ## News
31 |
32 | - [x] Updated the implementation code.
33 |
34 | - [x] Uploaded the pre-trained weights used in the manuscript. You can download them from [Google Drive](https://drive.google.com/drive/folders/1eUqL_8a9DQXZ2uCVyKeWB-6fO1ZdJciG?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/13_syaIXmG3lVnoMgzOS2Ag?pwd=SCSR) with the password `SCSR`.
35 |
36 | ## Overview
37 | >Deep models have achieved significant progress on single image super-resolution (SISR) tasks, in particular large models with large kernels (3×3 or more). However, the heavy computational footprint of such models prevents their deployment in real-time, resource-constrained environments. Conversely, 1×1 convolutions bring substantial computational efficiency, but struggle to aggregate local spatial representations, an essential capability for SISR models. In response to this dichotomy, we propose to harmonize the merits of both 3×3 and 1×1 kernels and exploit their great potential for lightweight SISR tasks. Specifically, we propose a simple yet effective fully 1×1 convolutional network, named Shift-Conv-based Network (SCNet). By incorporating a parameter-free spatial-shift operation, it equips the fully 1×1 convolutional network with powerful representation capability and impressive computational efficiency. Extensive experiments demonstrate that SCNet, despite its fully 1×1 convolutional structure, consistently matches or even surpasses the performance of existing lightweight SR models that employ regular convolutions.
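
A minimal sketch of the core idea, for illustration only (the repository's actual `Shift8` module uses zero padding and slicing, whereas this sketch uses the circular `torch.roll`): one eighth of the channels is shifted by one pixel toward each of the eight neighbouring directions, so that the following 1×1 convolution can mix spatial context without any 3×3 kernels.

```
import torch
from torch import nn


class ShiftConvSketch(nn.Module):
    """Illustrative shift + 1x1 convolution (a sketch, not the repository's Shift8)."""

    def __init__(self, channels):
        super().__init__()
        assert channels % 8 == 0
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        # one (dy, dx) offset per channel group: 4 axial + 4 diagonal directions
        self.offsets = [(-1, 0), (1, 0), (0, -1), (0, 1), (1, 1), (1, -1), (-1, 1), (-1, -1)]

    def forward(self, x):
        g = x.size(1) // 8
        shifted = [torch.roll(x[:, i * g:(i + 1) * g], shifts=off, dims=(2, 3))
                   for i, off in enumerate(self.offsets)]
        # the shift itself is parameter-free; all learning happens in the 1x1 conv
        return self.conv(torch.cat(shifted, dim=1))
```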
38 |
39 |
40 |

41 |
42 |
43 |
44 | ## Train and Test
45 |
46 | All experiments are conducted with [BasicSR](https://github.com/XPixelGroup/BasicSR), and we provide a minimal implementation of the architecture in `SCNet_arch.py`.
47 |
48 | For training, you may refer to the following script:
49 | ```
50 | python basicsr/train.py -opt options/train/SCNet/SCNet-T-x4.yml
51 | ```
52 | And for testing:
53 | ```
54 | python basicsr/test.py -opt options/test/SCNet/SCNet-T-x4.yml
55 | ```
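
For a quick sanity check outside the full BasicSR pipeline, the network can also be instantiated directly. The snippet below is only a sketch: it uses the class defaults (`num_feat=64`, `num_block=16`), which are not necessarily the exact SCNet-T setting (see the YAML files under `options/`), and it assumes the repository's `basicsr` package is importable.

```
import torch
from basicsr.archs.SCNet_arch import SCNet

# Class defaults; the released SCNet-T weights may use a different width/depth.
model = SCNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).eval()

lr = torch.rand(1, 3, 64, 64)   # dummy low-resolution input
with torch.no_grad():
    sr = model(lr)              # x4 output: shape (1, 3, 256, 256)
print(sr.shape)
```
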
56 | ## License
57 | This code is licensed under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) license for non-commercial use only. Please note that any commercial use of this code requires formal permission in advance.
58 |
59 | ## Acknowledgement
60 | The code is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Thanks for the excellent open-source work.
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/SCNet_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 | from basicsr.utils.registry import ARCH_REGISTRY
6 | from basicsr.archs.arch_util import default_init_weights, make_layer
7 |
8 |
9 | class Shift8(nn.Module):
10 | def __init__(self, groups=4, stride=1, mode='constant') -> None:
11 | super().__init__()
12 | self.g = groups
13 | self.mode = mode
14 | self.stride = stride
15 |
16 | def forward(self, x):
17 | b, c, h, w = x.shape
18 | out = torch.zeros_like(x)
19 |
20 | pad_x = F.pad(x, pad=[self.stride for _ in range(4)], mode=self.mode)
21 | assert c == self.g * 8
22 |
23 | cx, cy = self.stride, self.stride
24 | stride = self.stride
25 | out[:,0*self.g:1*self.g, :, :] = pad_x[:, 0*self.g:1*self.g, cx-stride:cx-stride+h, cy:cy+w]
26 | out[:,1*self.g:2*self.g, :, :] = pad_x[:, 1*self.g:2*self.g, cx+stride:cx+stride+h, cy:cy+w]
27 | out[:,2*self.g:3*self.g, :, :] = pad_x[:, 2*self.g:3*self.g, cx:cx+h, cy-stride:cy-stride+w]
28 | out[:,3*self.g:4*self.g, :, :] = pad_x[:, 3*self.g:4*self.g, cx:cx+h, cy+stride:cy+stride+w]
29 |
30 | out[:,4*self.g:5*self.g, :, :] = pad_x[:, 4*self.g:5*self.g, cx+stride:cx+stride+h, cy+stride:cy+stride+w]
31 | out[:,5*self.g:6*self.g, :, :] = pad_x[:, 5*self.g:6*self.g, cx+stride:cx+stride+h, cy-stride:cy-stride+w]
32 | out[:,6*self.g:7*self.g, :, :] = pad_x[:, 6*self.g:7*self.g, cx-stride:cx-stride+h, cy+stride:cy+stride+w]
33 | out[:,7*self.g:8*self.g, :, :] = pad_x[:, 7*self.g:8*self.g, cx-stride:cx-stride+h, cy-stride:cy-stride+w]
34 |
35 | #out[:, 8*self.g:, :, :] = pad_x[:, 8*self.g:, cx:cx+h, cy:cy+w]
36 | return out
37 |
38 |
39 | class ResidualBlockShift(nn.Module):
40 | """Residual block without BN.
41 |
42 | It has a style of:
43 | ---Conv-Shift-ReLU-Conv-+-
44 | |________________|
45 |
46 | Args:
47 | num_feat (int): Channel number of intermediate features.
48 | Default: 64.
49 | res_scale (float): Residual scale. Default: 1.
50 | pytorch_init (bool): If set to True, use pytorch default init,
51 | otherwise, use default_init_weights. Default: False.
52 | """
53 |
54 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
55 | super(ResidualBlockShift, self).__init__()
56 | self.res_scale = res_scale
57 | self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=1)
58 | self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=1)
59 | self.relu = nn.ReLU(inplace=True)
60 | self.shift = Shift8(groups=num_feat//8, stride=1)
61 |
62 | if not pytorch_init:
63 | default_init_weights([self.conv1, self.conv2], 0.1)
64 |
65 | def forward(self, x):
66 | identity = x
67 | out = self.conv2(self.relu(self.shift(self.conv1(x))))
68 | return identity + out * self.res_scale
69 |
70 |
71 | class UpShiftPixelShuffle(nn.Module):
72 | def __init__(self, dim, scale=2) -> None:
73 | super().__init__()
74 |
75 | self.up_layer = nn.Sequential(
76 | nn.Conv2d(dim, dim, kernel_size=1),
77 | nn.LeakyReLU(0.02),
78 | Shift8(groups=dim//8),
79 | nn.Conv2d(dim, dim*scale*scale, kernel_size=1),
80 | nn.PixelShuffle(upscale_factor=scale)
81 | )
82 | def forward(self, x):
83 | out = self.up_layer(x)
84 | return out
85 |
86 | class UpShiftMLP(nn.Module):
87 | def __init__(self, dim, mode='bilinear', scale=2) -> None:
88 | super().__init__()
89 |
90 | self.up_layer = nn.Sequential(
91 | nn.Upsample(scale_factor=scale, mode=mode, align_corners=False),
92 | nn.Conv2d(dim, dim, kernel_size=1),
93 | nn.LeakyReLU(0.02),
94 | Shift8(groups=dim//8),
95 | nn.Conv2d(dim, dim, kernel_size=1)
96 | )
97 | def forward(self, x):
98 | out = self.up_layer(x)
99 | return out
100 |
101 | @ARCH_REGISTRY.register()
102 | class SCNet(nn.Module):
103 | """ SCNet (https://arxiv.org/abs/2307.16140) based on the Modified SRResNet.
104 | Args:
105 | num_in_ch (int): Channel number of inputs. Default: 3.
106 | num_out_ch (int): Channel number of outputs. Default: 3.
107 | num_feat (int): Channel number of intermediate features. Default: 64.
108 | num_block (int): Block number in the body network. Default: 16.
109 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
110 | """
111 |
112 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
113 | super(SCNet, self).__init__()
114 | self.upscale = upscale
115 |
116 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 1)
117 | self.body = make_layer(ResidualBlockShift, num_block, num_feat=num_feat)
118 |
119 | # upsampling
120 | if self.upscale in [2, 3]:
121 | self.upconv1 = UpShiftMLP(num_feat, scale=self.upscale)
122 |
123 | elif self.upscale == 4:
124 | self.upconv1 = UpShiftMLP(num_feat)
125 | self.upconv2 = UpShiftMLP(num_feat)
126 | elif self.upscale == 8:
127 | self.upconv1 = UpShiftMLP(num_feat)
128 | self.upconv2 = UpShiftMLP(num_feat)
129 | self.upconv3 = UpShiftMLP(num_feat)
130 |         # kept as an identity: upsampling is already handled inside the up layers
131 | self.pixel_shuffle = nn.Identity()
132 |
133 | self.conv_hr = nn.Conv2d(num_feat, num_feat, kernel_size=1)
134 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, kernel_size=1)
135 |
136 | # activation function
137 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
138 |
139 | # initialization
140 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
141 | if self.upscale == 4:
142 | default_init_weights(self.upconv2, 0.1)
143 |
144 | def forward(self, x):
145 | feat = self.lrelu(self.conv_first(x))
146 | out = self.body(feat)
147 |
148 | if self.upscale == 4:
149 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
150 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
151 | elif self.upscale in [2, 3]:
152 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
153 | elif self.upscale == 8:
154 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
155 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
156 | out = self.lrelu(self.pixel_shuffle(self.upconv3(out)))
157 |
158 | out = self.conv_last(self.lrelu(self.conv_hr(out)))
159 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
160 | out += base
161 | return out
162 |
163 | if __name__ == '__main__':
164 | model = SCNet(upscale=4)
165 | load_dict = torch.load('SCNet-T-x4.pth')
166 | model.load_state_dict(load_dict['params'])
167 |
--------------------------------------------------------------------------------
/basicsr/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/xinntao/BasicSR
2 | # flake8: noqa
3 | from .archs import *
4 | from .data import *
5 | from .losses import *
6 | from .metrics import *
7 | from .models import *
8 | from .ops import *
9 | from .test import *
10 | from .train import *
11 | from .utils import *
12 |
--------------------------------------------------------------------------------
/basicsr/archs/SCNet_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 |
5 | from basicsr.utils.registry import ARCH_REGISTRY
6 | from basicsr.archs.arch_util import default_init_weights, make_layer
7 |
8 |
9 | class Shift8(nn.Module):
10 | def __init__(self, groups=4, stride=1, mode="constant") -> None:
11 | super().__init__()
12 | self.g = groups
13 | self.mode = mode
14 | self.stride = stride
15 |
16 | def forward(self, x):
17 | b, c, h, w = x.shape
18 | out = torch.zeros_like(x)
19 |
20 | pad_x = F.pad(x, pad=[self.stride for _ in range(4)], mode=self.mode)
21 | assert c == self.g * 8
22 |
23 | cx, cy = self.stride, self.stride
24 | stride = self.stride
25 | out[:, 0 * self.g : 1 * self.g, :, :] = pad_x[
26 | :, 0 * self.g : 1 * self.g, cx - stride : cx - stride + h, cy : cy + w
27 | ]
28 | out[:, 1 * self.g : 2 * self.g, :, :] = pad_x[
29 | :, 1 * self.g : 2 * self.g, cx + stride : cx + stride + h, cy : cy + w
30 | ]
31 | out[:, 2 * self.g : 3 * self.g, :, :] = pad_x[
32 | :, 2 * self.g : 3 * self.g, cx : cx + h, cy - stride : cy - stride + w
33 | ]
34 | out[:, 3 * self.g : 4 * self.g, :, :] = pad_x[
35 | :, 3 * self.g : 4 * self.g, cx : cx + h, cy + stride : cy + stride + w
36 | ]
37 |
38 | out[:, 4 * self.g : 5 * self.g, :, :] = pad_x[
39 | :,
40 | 4 * self.g : 5 * self.g,
41 | cx + stride : cx + stride + h,
42 | cy + stride : cy + stride + w,
43 | ]
44 | out[:, 5 * self.g : 6 * self.g, :, :] = pad_x[
45 | :,
46 | 5 * self.g : 6 * self.g,
47 | cx + stride : cx + stride + h,
48 | cy - stride : cy - stride + w,
49 | ]
50 | out[:, 6 * self.g : 7 * self.g, :, :] = pad_x[
51 | :,
52 | 6 * self.g : 7 * self.g,
53 | cx - stride : cx - stride + h,
54 | cy + stride : cy + stride + w,
55 | ]
56 | out[:, 7 * self.g : 8 * self.g, :, :] = pad_x[
57 | :,
58 | 7 * self.g : 8 * self.g,
59 | cx - stride : cx - stride + h,
60 | cy - stride : cy - stride + w,
61 | ]
62 |
63 | # out[:, 8*self.g:, :, :] = pad_x[:, 8*self.g:, cx:cx+h, cy:cy+w]
64 | return out
65 |
66 |
67 | class ResidualBlockShift(nn.Module):
68 | """Residual block without BN.
69 |
70 | It has a style of:
71 | ---Conv-Shift-ReLU-Conv-+-
72 | |________________|
73 |
74 | Args:
75 | num_feat (int): Channel number of intermediate features.
76 | Default: 64.
77 | res_scale (float): Residual scale. Default: 1.
78 | pytorch_init (bool): If set to True, use pytorch default init,
79 | otherwise, use default_init_weights. Default: False.
80 | """
81 |
82 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
83 | super(ResidualBlockShift, self).__init__()
84 | self.res_scale = res_scale
85 | self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=1)
86 | self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=1)
87 | self.relu = nn.ReLU(inplace=True)
88 | self.shift = Shift8(groups=num_feat // 8, stride=1)
89 |
90 | if not pytorch_init:
91 | default_init_weights([self.conv1, self.conv2], 0.1)
92 |
93 | def forward(self, x):
94 | identity = x
95 | out = self.conv2(self.relu(self.shift(self.conv1(x))))
96 | return identity + out * self.res_scale
97 |
98 |
99 | class UpShiftPixelShuffle(nn.Module):
100 | def __init__(self, dim, scale=2) -> None:
101 | super().__init__()
102 |
103 | self.up_layer = nn.Sequential(
104 | nn.Conv2d(dim, dim, kernel_size=1),
105 | nn.LeakyReLU(0.02),
106 | Shift8(groups=dim // 8),
107 | nn.Conv2d(dim, dim * scale * scale, kernel_size=1),
108 | nn.PixelShuffle(upscale_factor=scale),
109 | )
110 |
111 | def forward(self, x):
112 | out = self.up_layer(x)
113 | return out
114 |
115 |
116 | class UpShiftMLP(nn.Module):
117 | def __init__(self, dim, mode="bilinear", scale=2) -> None:
118 | super().__init__()
119 |
120 | self.up_layer = nn.Sequential(
121 | nn.Upsample(scale_factor=scale, mode=mode, align_corners=False),
122 | nn.Conv2d(dim, dim, kernel_size=1),
123 | nn.LeakyReLU(0.02),
124 | Shift8(groups=dim // 8),
125 | nn.Conv2d(dim, dim, kernel_size=1),
126 | )
127 |
128 | def forward(self, x):
129 | out = self.up_layer(x)
130 | return out
131 |
132 |
133 | @ARCH_REGISTRY.register()
134 | class SCNet(nn.Module):
135 | """SCNet (https://arxiv.org/abs/2307.16140) based on the Modified SRResNet.
136 | Args:
137 | num_in_ch (int): Channel number of inputs. Default: 3.
138 | num_out_ch (int): Channel number of outputs. Default: 3.
139 | num_feat (int): Channel number of intermediate features. Default: 64.
140 | num_block (int): Block number in the body network. Default: 16.
141 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
142 | use_pixelshuffle (bool): Upsampling with PixelShuffle operation.
143 | """
144 |
145 | def __init__(
146 | self,
147 | num_in_ch=3,
148 | num_out_ch=3,
149 | num_feat=64,
150 | num_block=16,
151 | upscale=4,
152 | use_pixelshuffle=False,
153 | ):
154 | super(SCNet, self).__init__()
155 | self.upscale = upscale
156 |
157 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 1)
158 | self.body = make_layer(ResidualBlockShift, num_block, num_feat=num_feat)
159 |
160 | UpLayer = UpShiftPixelShuffle if use_pixelshuffle else UpShiftMLP
161 | # upsampling
162 | if self.upscale in [2, 3]:
163 | self.upconv1 = UpLayer(num_feat, scale=self.upscale)
164 |
165 | elif self.upscale == 4:
166 | self.upconv1 = UpLayer(num_feat)
167 | self.upconv2 = UpLayer(num_feat)
168 | elif self.upscale == 8:
169 | self.upconv1 = UpLayer(num_feat)
170 | self.upconv2 = UpLayer(num_feat)
171 | self.upconv3 = UpLayer(num_feat)
172 |         # kept as an identity: upsampling is already handled inside the up layers
173 | self.pixel_shuffle = nn.Identity()
174 |
175 | self.conv_hr = nn.Conv2d(num_feat, num_feat, kernel_size=1)
176 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, kernel_size=1)
177 |
178 | # activation function
179 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
180 |
181 | # initialization
182 | default_init_weights(
183 | [self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1
184 | )
185 | if self.upscale == 4:
186 | default_init_weights(self.upconv2, 0.1)
187 |
188 | def forward(self, x):
189 | feat = self.lrelu(self.conv_first(x))
190 | out = self.body(feat)
191 |
192 | if self.upscale == 4:
193 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
194 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
195 | elif self.upscale in [2, 3]:
196 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
197 | elif self.upscale == 8:
198 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
199 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
200 | out = self.lrelu(self.pixel_shuffle(self.upconv3(out)))
201 |
202 | out = self.conv_last(self.lrelu(self.conv_hr(out)))
203 | base = F.interpolate(
204 | x, scale_factor=self.upscale, mode="bilinear", align_corners=False
205 | )
206 | out += base
207 | return out
208 |
209 |
210 | if __name__ == "__main__":
211 | model = SCNet(upscale=4)
212 | load_dict = torch.load("SCNet-T-x4.pth")
213 | model.load_state_dict(load_dict["params"])
214 |
--------------------------------------------------------------------------------
/basicsr/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import ARCH_REGISTRY
7 |
8 | __all__ = ['build_network']
9 |
10 | # automatically scan and import arch modules for registry
11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py'
12 | arch_folder = osp.dirname(osp.abspath(__file__))
13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
14 | # import all the arch modules
15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
16 |
17 |
18 | def build_network(opt):
19 | opt = deepcopy(opt)
20 | network_type = opt.pop('type')
21 | net = ARCH_REGISTRY.get(network_type)(**opt)
22 | logger = get_root_logger()
23 | logger.info(f'Network [{net.__class__.__name__}] is created.')
24 | return net
25 |
--------------------------------------------------------------------------------
/basicsr/archs/vgg_arch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from torch import nn as nn
5 | from torchvision.models import vgg as vgg
6 |
7 | from basicsr.utils.registry import ARCH_REGISTRY
8 |
9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10 | NAMES = {
11 | 'vgg11': [
12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14 | 'pool5'
15 | ],
16 | 'vgg13': [
17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20 | ],
21 | 'vgg16': [
22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25 | 'pool5'
26 | ],
27 | 'vgg19': [
28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32 | ]
33 | }
34 |
35 |
36 | def insert_bn(names):
37 | """Insert bn layer after each conv.
38 |
39 | Args:
40 | names (list): The list of layer names.
41 |
42 | Returns:
43 | list: The list of layer names with bn layers.
44 | """
45 | names_bn = []
46 | for name in names:
47 | names_bn.append(name)
48 | if 'conv' in name:
49 | position = name.replace('conv', '')
50 | names_bn.append('bn' + position)
51 | return names_bn
52 |
53 |
54 | @ARCH_REGISTRY.register()
55 | class VGGFeatureExtractor(nn.Module):
56 | """VGG network for feature extraction.
57 |
58 |     In this implementation, we allow users to choose whether to use normalization
59 |     of the input feature and the type of vgg network. Note that the pretrained
60 |     path must match the vgg type.
61 |
62 | Args:
63 | layer_name_list (list[str]): Forward function returns the corresponding
64 | features according to the layer_name_list.
65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67 | use_input_norm (bool): If True, normalize the input image. Importantly,
68 |             the input feature must be in the range [0, 1]. Default: True.
69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
70 | Default: False.
71 | requires_grad (bool): If true, the parameters of VGG network will be
72 | optimized. Default: False.
73 | remove_pooling (bool): If true, the max pooling operations in VGG net
74 | will be removed. Default: False.
75 | pooling_stride (int): The stride of max pooling operation. Default: 2.
76 | """
77 |
78 | def __init__(self,
79 | layer_name_list,
80 | vgg_type='vgg19',
81 | use_input_norm=True,
82 | range_norm=False,
83 | requires_grad=False,
84 | remove_pooling=False,
85 | pooling_stride=2):
86 | super(VGGFeatureExtractor, self).__init__()
87 |
88 | self.layer_name_list = layer_name_list
89 | self.use_input_norm = use_input_norm
90 | self.range_norm = range_norm
91 |
92 | self.names = NAMES[vgg_type.replace('_bn', '')]
93 | if 'bn' in vgg_type:
94 | self.names = insert_bn(self.names)
95 |
96 | # only borrow layers that will be used to avoid unused params
97 | max_idx = 0
98 | for v in layer_name_list:
99 | idx = self.names.index(v)
100 | if idx > max_idx:
101 | max_idx = idx
102 |
103 | if os.path.exists(VGG_PRETRAIN_PATH):
104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106 | vgg_net.load_state_dict(state_dict)
107 | else:
108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109 |
110 | features = vgg_net.features[:max_idx + 1]
111 |
112 | modified_net = OrderedDict()
113 | for k, v in zip(self.names, features):
114 | if 'pool' in k:
115 | # if remove_pooling is true, pooling operation will be removed
116 | if remove_pooling:
117 | continue
118 | else:
119 | # in some cases, we may want to change the default stride
120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121 | else:
122 | modified_net[k] = v
123 |
124 | self.vgg_net = nn.Sequential(modified_net)
125 |
126 | if not requires_grad:
127 | self.vgg_net.eval()
128 | for param in self.parameters():
129 | param.requires_grad = False
130 | else:
131 | self.vgg_net.train()
132 | for param in self.parameters():
133 | param.requires_grad = True
134 |
135 | if self.use_input_norm:
136 | # the mean is for image with range [0, 1]
137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138 | # the std is for image with range [0, 1]
139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140 |
141 | def forward(self, x):
142 | """Forward function.
143 |
144 | Args:
145 | x (Tensor): Input tensor with shape (n, c, h, w).
146 |
147 | Returns:
148 | Tensor: Forward results.
149 | """
150 | if self.range_norm:
151 | x = (x + 1) / 2
152 | if self.use_input_norm:
153 | x = (x - self.mean) / self.std
154 |
155 | output = {}
156 | for key, layer in self.vgg_net._modules.items():
157 | x = layer(x)
158 | if key in self.layer_name_list:
159 | output[key] = x.clone()
160 |
161 | return output
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from copy import deepcopy
7 | from functools import partial
8 | from os import path as osp
9 |
10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11 | from basicsr.utils import get_root_logger, scandir
12 | from basicsr.utils.dist_util import get_dist_info
13 | from basicsr.utils.registry import DATASET_REGISTRY
14 |
15 | __all__ = ['build_dataset', 'build_dataloader']
16 |
17 | # automatically scan and import dataset modules for registry
18 | # scan all the files under the data folder with '_dataset' in file names
19 | data_folder = osp.dirname(osp.abspath(__file__))
20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21 | # import all the dataset modules
22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23 |
24 |
25 | def build_dataset(dataset_opt):
26 | """Build dataset from options.
27 |
28 | Args:
29 | dataset_opt (dict): Configuration for dataset. It must contain:
30 | name (str): Dataset name.
31 | type (str): Dataset type.
32 | """
33 | dataset_opt = deepcopy(dataset_opt)
34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35 | logger = get_root_logger()
36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37 | return dataset
38 |
39 |
40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41 | """Build dataloader.
42 |
43 | Args:
44 | dataset (torch.utils.data.Dataset): Dataset.
45 | dataset_opt (dict): Dataset options. It contains the following keys:
46 | phase (str): 'train' or 'val'.
47 | num_worker_per_gpu (int): Number of workers for each GPU.
48 | batch_size_per_gpu (int): Training batch size for each GPU.
49 | num_gpu (int): Number of GPUs. Used only in the train phase.
50 | Default: 1.
51 | dist (bool): Whether in distributed training. Used only in the train
52 | phase. Default: False.
53 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
54 | seed (int | None): Seed. Default: None
55 | """
56 | phase = dataset_opt['phase']
57 | rank, _ = get_dist_info()
58 | if phase == 'train':
59 | if dist: # distributed training
60 | batch_size = dataset_opt['batch_size_per_gpu']
61 | num_workers = dataset_opt['num_worker_per_gpu']
62 | else: # non-distributed training
63 | multiplier = 1 if num_gpu == 0 else num_gpu
64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66 | dataloader_args = dict(
67 | dataset=dataset,
68 | batch_size=batch_size,
69 | shuffle=False,
70 | num_workers=num_workers,
71 | sampler=sampler,
72 | drop_last=True)
73 | if sampler is None:
74 | dataloader_args['shuffle'] = True
75 | dataloader_args['worker_init_fn'] = partial(
76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77 | elif phase in ['val', 'test']: # validation
78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79 | else:
80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
81 |
82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84 |
85 | prefetch_mode = dataset_opt.get('prefetch_mode')
86 | if prefetch_mode == 'cpu': # CPUPrefetcher
87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88 | logger = get_root_logger()
89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91 | else:
92 | # prefetch_mode=None: Normal dataloader
93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94 | return torch.utils.data.DataLoader(**dataloader_args)
95 |
96 |
97 | def worker_init_fn(worker_id, num_workers, rank, seed):
98 | # Set the worker seed to num_workers * rank + worker_id + seed
99 | worker_seed = num_workers * rank + worker_id + seed
100 | np.random.seed(worker_seed)
101 | random.seed(worker_seed)
102 |
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 |     Support enlarging the dataset for iteration-based training, saving the
11 |     time of restarting the dataloader after each epoch.
12 |
13 | Args:
14 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 | num_replicas (int | None): Number of processes participating in
16 | the training. It is usually the world_size.
17 | rank (int | None): Rank of the current process within num_replicas.
18 | ratio (int): Enlarging ratio. Default: 1.
19 | """
20 |
21 | def __init__(self, dataset, num_replicas, rank, ratio=1):
22 | self.dataset = dataset
23 | self.num_replicas = num_replicas
24 | self.rank = rank
25 | self.epoch = 0
26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27 | self.total_size = self.num_samples * self.num_replicas
28 |
29 | def __iter__(self):
30 | # deterministically shuffle based on epoch
31 | g = torch.Generator()
32 | g.manual_seed(self.epoch)
33 | indices = torch.randperm(self.total_size, generator=g).tolist()
34 |
35 | dataset_size = len(self.dataset)
36 | indices = [v % dataset_size for v in indices]
37 |
38 | # subsample
39 | indices = indices[self.rank:self.total_size:self.num_replicas]
40 | assert len(indices) == self.num_samples
41 |
42 | return iter(indices)
43 |
44 | def __len__(self):
45 | return self.num_samples
46 |
47 | def set_epoch(self, epoch):
48 | self.epoch = epoch
49 |
--------------------------------------------------------------------------------
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from os import path as osp
4 | from torch.utils import data as data
5 | from torchvision.transforms.functional import normalize
6 |
7 | from basicsr.data.transforms import augment
8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9 | from basicsr.utils.registry import DATASET_REGISTRY
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class FFHQDataset(data.Dataset):
14 | """FFHQ dataset for StyleGAN.
15 |
16 | Args:
17 | opt (dict): Config for train datasets. It contains the following keys:
18 | dataroot_gt (str): Data root path for gt.
19 | io_backend (dict): IO backend type and other kwarg.
20 | mean (list | tuple): Image mean.
21 | std (list | tuple): Image std.
22 | use_hflip (bool): Whether to horizontally flip.
23 |
24 | """
25 |
26 | def __init__(self, opt):
27 | super(FFHQDataset, self).__init__()
28 | self.opt = opt
29 | # file client (io backend)
30 | self.file_client = None
31 | self.io_backend_opt = opt['io_backend']
32 |
33 | self.gt_folder = opt['dataroot_gt']
34 | self.mean = opt['mean']
35 | self.std = opt['std']
36 |
37 | if self.io_backend_opt['type'] == 'lmdb':
38 | self.io_backend_opt['db_paths'] = self.gt_folder
39 | if not self.gt_folder.endswith('.lmdb'):
40 |                 raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
42 | self.paths = [line.split('.')[0] for line in fin]
43 | else:
44 | # FFHQ has 70000 images in total
45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
46 |
47 | def __getitem__(self, index):
48 | if self.file_client is None:
49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
50 |
51 | # load gt image
52 | gt_path = self.paths[index]
53 | # avoid errors caused by high latency in reading files
54 | retry = 3
55 | while retry > 0:
56 | try:
57 | img_bytes = self.file_client.get(gt_path)
58 | except Exception as e:
59 | logger = get_root_logger()
60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
61 | # change another file to read
62 |                 index = random.randint(0, self.__len__() - 1)
63 | gt_path = self.paths[index]
64 | time.sleep(1) # sleep 1s for occasional server congestion
65 | else:
66 | break
67 | finally:
68 | retry -= 1
69 | img_gt = imfrombytes(img_bytes, float32=True)
70 |
71 | # random horizontal flip
72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
73 | # BGR to RGB, HWC to CHW, numpy to tensor
74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
75 | # normalize
76 | normalize(img_gt, self.mean, self.std, inplace=True)
77 | return {'gt': img_gt, 'gt_path': gt_path}
78 |
79 | def __len__(self):
80 | return len(self.paths)
81 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 011 100 (720,1280,3)
3 | 015 100 (720,1280,3)
4 | 020 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 001 100 (720,1280,3)
3 | 002 100 (720,1280,3)
4 | 003 100 (720,1280,3)
5 | 004 100 (720,1280,3)
6 | 005 100 (720,1280,3)
7 | 006 100 (720,1280,3)
8 | 007 100 (720,1280,3)
9 | 008 100 (720,1280,3)
10 | 009 100 (720,1280,3)
11 | 010 100 (720,1280,3)
12 | 011 100 (720,1280,3)
13 | 012 100 (720,1280,3)
14 | 013 100 (720,1280,3)
15 | 014 100 (720,1280,3)
16 | 015 100 (720,1280,3)
17 | 016 100 (720,1280,3)
18 | 017 100 (720,1280,3)
19 | 018 100 (720,1280,3)
20 | 019 100 (720,1280,3)
21 | 020 100 (720,1280,3)
22 | 021 100 (720,1280,3)
23 | 022 100 (720,1280,3)
24 | 023 100 (720,1280,3)
25 | 024 100 (720,1280,3)
26 | 025 100 (720,1280,3)
27 | 026 100 (720,1280,3)
28 | 027 100 (720,1280,3)
29 | 028 100 (720,1280,3)
30 | 029 100 (720,1280,3)
31 | 030 100 (720,1280,3)
32 | 031 100 (720,1280,3)
33 | 032 100 (720,1280,3)
34 | 033 100 (720,1280,3)
35 | 034 100 (720,1280,3)
36 | 035 100 (720,1280,3)
37 | 036 100 (720,1280,3)
38 | 037 100 (720,1280,3)
39 | 038 100 (720,1280,3)
40 | 039 100 (720,1280,3)
41 | 040 100 (720,1280,3)
42 | 041 100 (720,1280,3)
43 | 042 100 (720,1280,3)
44 | 043 100 (720,1280,3)
45 | 044 100 (720,1280,3)
46 | 045 100 (720,1280,3)
47 | 046 100 (720,1280,3)
48 | 047 100 (720,1280,3)
49 | 048 100 (720,1280,3)
50 | 049 100 (720,1280,3)
51 | 050 100 (720,1280,3)
52 | 051 100 (720,1280,3)
53 | 052 100 (720,1280,3)
54 | 053 100 (720,1280,3)
55 | 054 100 (720,1280,3)
56 | 055 100 (720,1280,3)
57 | 056 100 (720,1280,3)
58 | 057 100 (720,1280,3)
59 | 058 100 (720,1280,3)
60 | 059 100 (720,1280,3)
61 | 060 100 (720,1280,3)
62 | 061 100 (720,1280,3)
63 | 062 100 (720,1280,3)
64 | 063 100 (720,1280,3)
65 | 064 100 (720,1280,3)
66 | 065 100 (720,1280,3)
67 | 066 100 (720,1280,3)
68 | 067 100 (720,1280,3)
69 | 068 100 (720,1280,3)
70 | 069 100 (720,1280,3)
71 | 070 100 (720,1280,3)
72 | 071 100 (720,1280,3)
73 | 072 100 (720,1280,3)
74 | 073 100 (720,1280,3)
75 | 074 100 (720,1280,3)
76 | 075 100 (720,1280,3)
77 | 076 100 (720,1280,3)
78 | 077 100 (720,1280,3)
79 | 078 100 (720,1280,3)
80 | 079 100 (720,1280,3)
81 | 080 100 (720,1280,3)
82 | 081 100 (720,1280,3)
83 | 082 100 (720,1280,3)
84 | 083 100 (720,1280,3)
85 | 084 100 (720,1280,3)
86 | 085 100 (720,1280,3)
87 | 086 100 (720,1280,3)
88 | 087 100 (720,1280,3)
89 | 088 100 (720,1280,3)
90 | 089 100 (720,1280,3)
91 | 090 100 (720,1280,3)
92 | 091 100 (720,1280,3)
93 | 092 100 (720,1280,3)
94 | 093 100 (720,1280,3)
95 | 094 100 (720,1280,3)
96 | 095 100 (720,1280,3)
97 | 096 100 (720,1280,3)
98 | 097 100 (720,1280,3)
99 | 098 100 (720,1280,3)
100 | 099 100 (720,1280,3)
101 | 100 100 (720,1280,3)
102 | 101 100 (720,1280,3)
103 | 102 100 (720,1280,3)
104 | 103 100 (720,1280,3)
105 | 104 100 (720,1280,3)
106 | 105 100 (720,1280,3)
107 | 106 100 (720,1280,3)
108 | 107 100 (720,1280,3)
109 | 108 100 (720,1280,3)
110 | 109 100 (720,1280,3)
111 | 110 100 (720,1280,3)
112 | 111 100 (720,1280,3)
113 | 112 100 (720,1280,3)
114 | 113 100 (720,1280,3)
115 | 114 100 (720,1280,3)
116 | 115 100 (720,1280,3)
117 | 116 100 (720,1280,3)
118 | 117 100 (720,1280,3)
119 | 118 100 (720,1280,3)
120 | 119 100 (720,1280,3)
121 | 120 100 (720,1280,3)
122 | 121 100 (720,1280,3)
123 | 122 100 (720,1280,3)
124 | 123 100 (720,1280,3)
125 | 124 100 (720,1280,3)
126 | 125 100 (720,1280,3)
127 | 126 100 (720,1280,3)
128 | 127 100 (720,1280,3)
129 | 128 100 (720,1280,3)
130 | 129 100 (720,1280,3)
131 | 130 100 (720,1280,3)
132 | 131 100 (720,1280,3)
133 | 132 100 (720,1280,3)
134 | 133 100 (720,1280,3)
135 | 134 100 (720,1280,3)
136 | 135 100 (720,1280,3)
137 | 136 100 (720,1280,3)
138 | 137 100 (720,1280,3)
139 | 138 100 (720,1280,3)
140 | 139 100 (720,1280,3)
141 | 140 100 (720,1280,3)
142 | 141 100 (720,1280,3)
143 | 142 100 (720,1280,3)
144 | 143 100 (720,1280,3)
145 | 144 100 (720,1280,3)
146 | 145 100 (720,1280,3)
147 | 146 100 (720,1280,3)
148 | 147 100 (720,1280,3)
149 | 148 100 (720,1280,3)
150 | 149 100 (720,1280,3)
151 | 150 100 (720,1280,3)
152 | 151 100 (720,1280,3)
153 | 152 100 (720,1280,3)
154 | 153 100 (720,1280,3)
155 | 154 100 (720,1280,3)
156 | 155 100 (720,1280,3)
157 | 156 100 (720,1280,3)
158 | 157 100 (720,1280,3)
159 | 158 100 (720,1280,3)
160 | 159 100 (720,1280,3)
161 | 160 100 (720,1280,3)
162 | 161 100 (720,1280,3)
163 | 162 100 (720,1280,3)
164 | 163 100 (720,1280,3)
165 | 164 100 (720,1280,3)
166 | 165 100 (720,1280,3)
167 | 166 100 (720,1280,3)
168 | 167 100 (720,1280,3)
169 | 168 100 (720,1280,3)
170 | 169 100 (720,1280,3)
171 | 170 100 (720,1280,3)
172 | 171 100 (720,1280,3)
173 | 172 100 (720,1280,3)
174 | 173 100 (720,1280,3)
175 | 174 100 (720,1280,3)
176 | 175 100 (720,1280,3)
177 | 176 100 (720,1280,3)
178 | 177 100 (720,1280,3)
179 | 178 100 (720,1280,3)
180 | 179 100 (720,1280,3)
181 | 180 100 (720,1280,3)
182 | 181 100 (720,1280,3)
183 | 182 100 (720,1280,3)
184 | 183 100 (720,1280,3)
185 | 184 100 (720,1280,3)
186 | 185 100 (720,1280,3)
187 | 186 100 (720,1280,3)
188 | 187 100 (720,1280,3)
189 | 188 100 (720,1280,3)
190 | 189 100 (720,1280,3)
191 | 190 100 (720,1280,3)
192 | 191 100 (720,1280,3)
193 | 192 100 (720,1280,3)
194 | 193 100 (720,1280,3)
195 | 194 100 (720,1280,3)
196 | 195 100 (720,1280,3)
197 | 196 100 (720,1280,3)
198 | 197 100 (720,1280,3)
199 | 198 100 (720,1280,3)
200 | 199 100 (720,1280,3)
201 | 200 100 (720,1280,3)
202 | 201 100 (720,1280,3)
203 | 202 100 (720,1280,3)
204 | 203 100 (720,1280,3)
205 | 204 100 (720,1280,3)
206 | 205 100 (720,1280,3)
207 | 206 100 (720,1280,3)
208 | 207 100 (720,1280,3)
209 | 208 100 (720,1280,3)
210 | 209 100 (720,1280,3)
211 | 210 100 (720,1280,3)
212 | 211 100 (720,1280,3)
213 | 212 100 (720,1280,3)
214 | 213 100 (720,1280,3)
215 | 214 100 (720,1280,3)
216 | 215 100 (720,1280,3)
217 | 216 100 (720,1280,3)
218 | 217 100 (720,1280,3)
219 | 218 100 (720,1280,3)
220 | 219 100 (720,1280,3)
221 | 220 100 (720,1280,3)
222 | 221 100 (720,1280,3)
223 | 222 100 (720,1280,3)
224 | 223 100 (720,1280,3)
225 | 224 100 (720,1280,3)
226 | 225 100 (720,1280,3)
227 | 226 100 (720,1280,3)
228 | 227 100 (720,1280,3)
229 | 228 100 (720,1280,3)
230 | 229 100 (720,1280,3)
231 | 230 100 (720,1280,3)
232 | 231 100 (720,1280,3)
233 | 232 100 (720,1280,3)
234 | 233 100 (720,1280,3)
235 | 234 100 (720,1280,3)
236 | 235 100 (720,1280,3)
237 | 236 100 (720,1280,3)
238 | 237 100 (720,1280,3)
239 | 238 100 (720,1280,3)
240 | 239 100 (720,1280,3)
241 | 240 100 (720,1280,3)
242 | 241 100 (720,1280,3)
243 | 242 100 (720,1280,3)
244 | 243 100 (720,1280,3)
245 | 244 100 (720,1280,3)
246 | 245 100 (720,1280,3)
247 | 246 100 (720,1280,3)
248 | 247 100 (720,1280,3)
249 | 248 100 (720,1280,3)
250 | 249 100 (720,1280,3)
251 | 250 100 (720,1280,3)
252 | 251 100 (720,1280,3)
253 | 252 100 (720,1280,3)
254 | 253 100 (720,1280,3)
255 | 254 100 (720,1280,3)
256 | 255 100 (720,1280,3)
257 | 256 100 (720,1280,3)
258 | 257 100 (720,1280,3)
259 | 258 100 (720,1280,3)
260 | 259 100 (720,1280,3)
261 | 260 100 (720,1280,3)
262 | 261 100 (720,1280,3)
263 | 262 100 (720,1280,3)
264 | 263 100 (720,1280,3)
265 | 264 100 (720,1280,3)
266 | 265 100 (720,1280,3)
267 | 266 100 (720,1280,3)
268 | 267 100 (720,1280,3)
269 | 268 100 (720,1280,3)
270 | 269 100 (720,1280,3)
271 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 246 100 (720,1280,3)
4 | 257 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 242 100 (720,1280,3)
4 | 243 100 (720,1280,3)
5 | 244 100 (720,1280,3)
6 | 245 100 (720,1280,3)
7 | 246 100 (720,1280,3)
8 | 247 100 (720,1280,3)
9 | 248 100 (720,1280,3)
10 | 249 100 (720,1280,3)
11 | 250 100 (720,1280,3)
12 | 251 100 (720,1280,3)
13 | 252 100 (720,1280,3)
14 | 253 100 (720,1280,3)
15 | 254 100 (720,1280,3)
16 | 255 100 (720,1280,3)
17 | 256 100 (720,1280,3)
18 | 257 100 (720,1280,3)
19 | 258 100 (720,1280,3)
20 | 259 100 (720,1280,3)
21 | 260 100 (720,1280,3)
22 | 261 100 (720,1280,3)
23 | 262 100 (720,1280,3)
24 | 263 100 (720,1280,3)
25 | 264 100 (720,1280,3)
26 | 265 100 (720,1280,3)
27 | 266 100 (720,1280,3)
28 | 267 100 (720,1280,3)
29 | 268 100 (720,1280,3)
30 | 269 100 (720,1280,3)
31 |
--------------------------------------------------------------------------------
/basicsr/data/paired_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from torchvision.transforms.functional import normalize
3 |
4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5 | from basicsr.data.transforms import augment, paired_random_crop
6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
7 | from basicsr.utils.registry import DATASET_REGISTRY
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class PairedImageDataset(data.Dataset):
12 | """Paired image dataset for image restoration.
13 |
14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
15 |
16 | There are three modes:
17 |
18 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
19 | 2. **meta_info_file**: Use meta information file to generate paths. \
20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
21 | 3. **folder**: Scan folders to generate paths. The rest.
22 |
23 | Args:
24 | opt (dict): Config for train datasets. It contains the following keys:
25 | dataroot_gt (str): Data root path for gt.
26 | dataroot_lq (str): Data root path for lq.
27 | meta_info_file (str): Path for meta information file.
28 | io_backend (dict): IO backend type and other kwarg.
29 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
30 | Default: '{}'.
31 |             gt_size (int): Cropped patch size for gt patches.
32 |             use_hflip (bool): Use horizontal flips.
33 |             use_rot (bool): Use rotation (implemented by vertical flip and transposing h and w).
34 |             scale (int): Scale factor, which will be added automatically.
35 | phase (str): 'train' or 'val'.
36 | """
37 |
38 | def __init__(self, opt):
39 | super(PairedImageDataset, self).__init__()
40 | self.opt = opt
41 | # file client (io backend)
42 | self.file_client = None
43 | self.io_backend_opt = opt['io_backend']
44 | self.mean = opt['mean'] if 'mean' in opt else None
45 | self.std = opt['std'] if 'std' in opt else None
46 |
47 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
48 | if 'filename_tmpl' in opt:
49 | self.filename_tmpl = opt['filename_tmpl']
50 | else:
51 | self.filename_tmpl = '{}'
52 |
53 | if self.io_backend_opt['type'] == 'lmdb':
54 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
55 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
56 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
57 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
58 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
59 | self.opt['meta_info_file'], self.filename_tmpl)
60 | else:
61 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
62 |
63 | def __getitem__(self, index):
64 | if self.file_client is None:
65 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
66 |
67 | scale = self.opt['scale']
68 |
69 | # Load gt and lq images. Dimension order: HWC; channel order: BGR;
70 | # image range: [0, 1], float32.
71 | gt_path = self.paths[index]['gt_path']
72 | img_bytes = self.file_client.get(gt_path, 'gt')
73 | img_gt = imfrombytes(img_bytes, float32=True)
74 | lq_path = self.paths[index]['lq_path']
75 | img_bytes = self.file_client.get(lq_path, 'lq')
76 | img_lq = imfrombytes(img_bytes, float32=True)
77 |
78 | # augmentation for training
79 | if self.opt['phase'] == 'train':
80 | gt_size = self.opt['gt_size']
81 | # random crop
82 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
83 | # flip, rotation
84 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
85 |
86 | # color space transform
87 | if 'color' in self.opt and self.opt['color'] == 'y':
88 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
89 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
90 |
91 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
92 | # TODO: It is better to update the datasets, rather than force to crop
93 | if self.opt['phase'] != 'train':
94 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
95 |
96 | # BGR to RGB, HWC to CHW, numpy to tensor
97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
98 | # normalize
99 | if self.mean is not None or self.std is not None:
100 | normalize(img_lq, self.mean, self.std, inplace=True)
101 | normalize(img_gt, self.mean, self.std, inplace=True)
102 |
103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
104 |
105 | def __len__(self):
106 | return len(self.paths)
107 |
--------------------------------------------------------------------------------
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
11 |
12 | Args:
13 | generator: Python generator.
14 | num_prefetch_queue (int): Number of prefetch queue.
15 | """
16 |
17 | def __init__(self, generator, num_prefetch_queue):
18 | threading.Thread.__init__(self)
19 | self.queue = Queue.Queue(num_prefetch_queue)
20 | self.generator = generator
21 | self.daemon = True
22 | self.start()
23 |
24 | def run(self):
25 | for item in self.generator:
26 | self.queue.put(item)
27 | self.queue.put(None)
28 |
29 | def __next__(self):
30 | next_item = self.queue.get()
31 | if next_item is None:
32 | raise StopIteration
33 | return next_item
34 |
35 | def __iter__(self):
36 | return self
37 |
38 |
39 | class PrefetchDataLoader(DataLoader):
40 | """Prefetch version of dataloader.
41 |
42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
43 |
44 | TODO:
45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
46 | ddp.
47 |
48 | Args:
49 | num_prefetch_queue (int): Number of prefetch queue.
50 | kwargs (dict): Other arguments for dataloader.
51 | """
52 |
53 | def __init__(self, num_prefetch_queue, **kwargs):
54 | self.num_prefetch_queue = num_prefetch_queue
55 | super(PrefetchDataLoader, self).__init__(**kwargs)
56 |
57 | def __iter__(self):
58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
59 |
60 |
61 | class CPUPrefetcher():
62 | """CPU prefetcher.
63 |
64 | Args:
65 | loader: Dataloader.
66 | """
67 |
68 | def __init__(self, loader):
69 | self.ori_loader = loader
70 | self.loader = iter(loader)
71 |
72 | def next(self):
73 | try:
74 | return next(self.loader)
75 | except StopIteration:
76 | return None
77 |
78 | def reset(self):
79 | self.loader = iter(self.ori_loader)
80 |
81 |
82 | class CUDAPrefetcher():
83 | """CUDA prefetcher.
84 |
85 | Reference: https://github.com/NVIDIA/apex/issues/304#
86 |
87 | It may consume more GPU memory.
88 |
89 | Args:
90 | loader: Dataloader.
91 | opt (dict): Options.
92 | """
93 |
94 | def __init__(self, loader, opt):
95 | self.ori_loader = loader
96 | self.loader = iter(loader)
97 | self.opt = opt
98 | self.stream = torch.cuda.Stream()
99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
100 | self.preload()
101 |
102 | def preload(self):
103 | try:
104 | self.batch = next(self.loader) # self.batch is a dict
105 | except StopIteration:
106 | self.batch = None
107 | return None
108 | # put tensors to gpu
109 | with torch.cuda.stream(self.stream):
110 | for k, v in self.batch.items():
111 | if torch.is_tensor(v):
112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
113 |
114 | def next(self):
115 | torch.cuda.current_stream().wait_stream(self.stream)
116 | batch = self.batch
117 | self.preload()
118 | return batch
119 |
120 | def reset(self):
121 | self.loader = iter(self.ori_loader)
122 | self.preload()
123 |
--------------------------------------------------------------------------------
/basicsr/data/realesrgan_paired_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
6 | from basicsr.data.transforms import augment, paired_random_crop
7 | from basicsr.utils import FileClient, imfrombytes, img2tensor
8 | from basicsr.utils.registry import DATASET_REGISTRY
9 |
10 |
11 | @DATASET_REGISTRY.register(suffix='basicsr')
12 | class RealESRGANPairedDataset(data.Dataset):
13 | """Paired image dataset for image restoration.
14 |
15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
16 |
17 | There are three modes:
18 |
19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
20 | 2. **meta_info_file**: Use meta information file to generate paths. \
21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
22 | 3. **folder**: Scan folders to generate paths. The rest.
23 |
24 | Args:
25 | opt (dict): Config for train datasets. It contains the following keys:
26 | dataroot_gt (str): Data root path for gt.
27 | dataroot_lq (str): Data root path for lq.
28 | meta_info (str): Path for meta information file.
29 | io_backend (dict): IO backend type and other kwarg.
30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
31 | Default: '{}'.
32 |             gt_size (int): Cropped patch size for gt patches.
33 |             use_hflip (bool): Use horizontal flips.
34 |             use_rot (bool): Use rotation (implemented by vertical flip and transposing h and w).
35 |             scale (int): Scale factor, which will be added automatically.
36 | phase (str): 'train' or 'val'.
37 | """
38 |
39 | def __init__(self, opt):
40 | super(RealESRGANPairedDataset, self).__init__()
41 | self.opt = opt
42 | self.file_client = None
43 | self.io_backend_opt = opt['io_backend']
44 | # mean and std for normalizing the input images
45 | self.mean = opt['mean'] if 'mean' in opt else None
46 | self.std = opt['std'] if 'std' in opt else None
47 |
48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
49 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
50 |
51 | # file client (lmdb io backend)
52 | if self.io_backend_opt['type'] == 'lmdb':
53 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
54 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
55 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
56 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
57 | # disk backend with meta_info
58 | # Each line in the meta_info describes the relative path to an image
59 | with open(self.opt['meta_info']) as fin:
60 | paths = [line.strip() for line in fin]
61 | self.paths = []
62 | for path in paths:
63 | gt_path, lq_path = path.split(', ')
64 | gt_path = os.path.join(self.gt_folder, gt_path)
65 | lq_path = os.path.join(self.lq_folder, lq_path)
66 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
67 | else:
68 | # disk backend
69 | # it will scan the whole folder to get meta info
 70 |             # it will be time-consuming for folders with too many files. It is recommended to use an extra meta txt file
71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
72 |
73 | def __getitem__(self, index):
74 | if self.file_client is None:
75 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
76 |
77 | scale = self.opt['scale']
78 |
79 | # Load gt and lq images. Dimension order: HWC; channel order: BGR;
80 | # image range: [0, 1], float32.
81 | gt_path = self.paths[index]['gt_path']
82 | img_bytes = self.file_client.get(gt_path, 'gt')
83 | img_gt = imfrombytes(img_bytes, float32=True)
84 | lq_path = self.paths[index]['lq_path']
85 | img_bytes = self.file_client.get(lq_path, 'lq')
86 | img_lq = imfrombytes(img_bytes, float32=True)
87 |
88 | # augmentation for training
89 | if self.opt['phase'] == 'train':
90 | gt_size = self.opt['gt_size']
91 | # random crop
92 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
93 | # flip, rotation
94 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
95 |
96 | # BGR to RGB, HWC to CHW, numpy to tensor
97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
98 | # normalize
99 | if self.mean is not None or self.std is not None:
100 | normalize(img_lq, self.mean, self.std, inplace=True)
101 | normalize(img_gt, self.mean, self.std, inplace=True)
102 |
103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
104 |
105 | def __len__(self):
106 | return len(self.paths)
107 |
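A minimal sketch of instantiating this dataset in the folder mode. The dataset root paths below are placeholders for illustration only; the keys mirror the docstring above:

```python
from basicsr.data.realesrgan_paired_dataset import RealESRGANPairedDataset

# Placeholder paths: point dataroot_gt/dataroot_lq at your own paired folders.
opt = dict(
    dataroot_gt='datasets/DIV2K/HR',
    dataroot_lq='datasets/DIV2K/LR_bicubic/X4',
    io_backend=dict(type='disk'),
    gt_size=128,        # GT patch size; the LQ patch becomes gt_size // scale
    use_hflip=True,
    use_rot=True,
    scale=4,
    phase='train',
)
dataset = RealESRGANPairedDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape)  # torch.Size([3, 32, 32]) torch.Size([3, 128, 128])
```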
--------------------------------------------------------------------------------
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
7 | from basicsr.utils.registry import DATASET_REGISTRY
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class SingleImageDataset(data.Dataset):
12 | """Read only lq images in the test phase.
13 |
14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
15 |
16 | There are two modes:
17 | 1. 'meta_info_file': Use meta information file to generate paths.
18 | 2. 'folder': Scan folders to generate paths.
19 |
20 | Args:
 21 |         opt (dict): Config for the dataset. It contains the following keys:
22 | dataroot_lq (str): Data root path for lq.
23 | meta_info_file (str): Path for meta information file.
24 | io_backend (dict): IO backend type and other kwarg.
25 | """
26 |
27 | def __init__(self, opt):
28 | super(SingleImageDataset, self).__init__()
29 | self.opt = opt
30 | # file client (io backend)
31 | self.file_client = None
32 | self.io_backend_opt = opt['io_backend']
33 | self.mean = opt['mean'] if 'mean' in opt else None
34 | self.std = opt['std'] if 'std' in opt else None
35 | self.lq_folder = opt['dataroot_lq']
36 |
37 | if self.io_backend_opt['type'] == 'lmdb':
38 | self.io_backend_opt['db_paths'] = [self.lq_folder]
39 | self.io_backend_opt['client_keys'] = ['lq']
40 | self.paths = paths_from_lmdb(self.lq_folder)
41 | elif 'meta_info_file' in self.opt:
42 | with open(self.opt['meta_info_file'], 'r') as fin:
43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
44 | else:
45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
46 |
47 | def __getitem__(self, index):
48 | if self.file_client is None:
49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
50 |
51 | # load lq image
52 | lq_path = self.paths[index]
53 | img_bytes = self.file_client.get(lq_path, 'lq')
54 | img_lq = imfrombytes(img_bytes, float32=True)
55 |
56 | # color space transform
57 | if 'color' in self.opt and self.opt['color'] == 'y':
58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
59 |
60 | # BGR to RGB, HWC to CHW, numpy to tensor
61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
62 | # normalize
63 | if self.mean is not None or self.std is not None:
64 | normalize(img_lq, self.mean, self.std, inplace=True)
65 | return {'lq': img_lq, 'lq_path': lq_path}
66 |
67 | def __len__(self):
68 | return len(self.paths)
69 |
--------------------------------------------------------------------------------
/basicsr/data/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import torch
4 |
5 |
6 | def mod_crop(img, scale):
7 | """Mod crop images, used during testing.
8 |
9 | Args:
10 | img (ndarray): Input image.
11 | scale (int): Scale factor.
12 |
13 | Returns:
14 | ndarray: Result image.
15 | """
16 | img = img.copy()
17 | if img.ndim in (2, 3):
18 | h, w = img.shape[0], img.shape[1]
19 | h_remainder, w_remainder = h % scale, w % scale
20 | img = img[:h - h_remainder, :w - w_remainder, ...]
21 | else:
22 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
23 | return img
24 |
25 |
26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
27 | """Paired random crop. Support Numpy array and Tensor inputs.
28 |
29 | It crops lists of lq and gt images with corresponding locations.
30 |
31 | Args:
32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
33 | should have the same shape. If the input is an ndarray, it will
34 | be transformed to a list containing itself.
35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
36 | should have the same shape. If the input is an ndarray, it will
37 | be transformed to a list containing itself.
38 | gt_patch_size (int): GT patch size.
39 | scale (int): Scale factor.
40 | gt_path (str): Path to ground-truth. Default: None.
41 |
42 | Returns:
43 | list[ndarray] | ndarray: GT images and LQ images. If returned results
44 | only have one element, just return ndarray.
45 | """
46 |
47 | if not isinstance(img_gts, list):
48 | img_gts = [img_gts]
49 | if not isinstance(img_lqs, list):
50 | img_lqs = [img_lqs]
51 |
52 | # determine input type: Numpy array or Tensor
53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
54 |
55 | if input_type == 'Tensor':
56 | h_lq, w_lq = img_lqs[0].size()[-2:]
57 | h_gt, w_gt = img_gts[0].size()[-2:]
58 | else:
59 | h_lq, w_lq = img_lqs[0].shape[0:2]
60 | h_gt, w_gt = img_gts[0].shape[0:2]
61 | lq_patch_size = gt_patch_size // scale
62 |
63 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
 64 |         raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
 65 |                          f'multiplication of LQ ({h_lq}, {w_lq}).')
66 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
68 | f'({lq_patch_size}, {lq_patch_size}). '
69 | f'Please remove {gt_path}.')
70 |
71 | # randomly choose top and left coordinates for lq patch
72 | top = random.randint(0, h_lq - lq_patch_size)
73 | left = random.randint(0, w_lq - lq_patch_size)
74 |
75 | # crop lq patch
76 | if input_type == 'Tensor':
77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
78 | else:
79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
80 |
81 | # crop corresponding gt patch
82 | top_gt, left_gt = int(top * scale), int(left * scale)
83 | if input_type == 'Tensor':
84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
85 | else:
86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
87 | if len(img_gts) == 1:
88 | img_gts = img_gts[0]
89 | if len(img_lqs) == 1:
90 | img_lqs = img_lqs[0]
91 | return img_gts, img_lqs
92 |
93 |
94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
96 |
97 | We use vertical flip and transpose for rotation implementation.
98 | All the images in the list use the same augmentation.
99 |
100 | Args:
101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
102 | is an ndarray, it will be transformed to a list.
103 | hflip (bool): Horizontal flip. Default: True.
104 |         rotation (bool): Rotation. Default: True.
105 |         flows (list[ndarray]): Flows to be augmented. If the input is an
106 | ndarray, it will be transformed to a list.
107 | Dimension is (h, w, 2). Default: None.
108 | return_status (bool): Return the status of flip and rotation.
109 | Default: False.
110 |
111 | Returns:
112 | list[ndarray] | ndarray: Augmented images and flows. If returned
113 | results only have one element, just return ndarray.
114 |
115 | """
116 | hflip = hflip and random.random() < 0.5
117 | vflip = rotation and random.random() < 0.5
118 | rot90 = rotation and random.random() < 0.5
119 |
120 | def _augment(img):
121 | if hflip: # horizontal
122 | cv2.flip(img, 1, img)
123 | if vflip: # vertical
124 | cv2.flip(img, 0, img)
125 | if rot90:
126 | img = img.transpose(1, 0, 2)
127 | return img
128 |
129 | def _augment_flow(flow):
130 | if hflip: # horizontal
131 | cv2.flip(flow, 1, flow)
132 | flow[:, :, 0] *= -1
133 | if vflip: # vertical
134 | cv2.flip(flow, 0, flow)
135 | flow[:, :, 1] *= -1
136 | if rot90:
137 | flow = flow.transpose(1, 0, 2)
138 | flow = flow[:, :, [1, 0]]
139 | return flow
140 |
141 | if not isinstance(imgs, list):
142 | imgs = [imgs]
143 | imgs = [_augment(img) for img in imgs]
144 | if len(imgs) == 1:
145 | imgs = imgs[0]
146 |
147 | if flows is not None:
148 | if not isinstance(flows, list):
149 | flows = [flows]
150 | flows = [_augment_flow(flow) for flow in flows]
151 | if len(flows) == 1:
152 | flows = flows[0]
153 | return imgs, flows
154 | else:
155 | if return_status:
156 | return imgs, (hflip, vflip, rot90)
157 | else:
158 | return imgs
159 |
160 |
161 | def img_rotate(img, angle, center=None, scale=1.0):
162 | """Rotate image.
163 |
164 | Args:
165 | img (ndarray): Image to be rotated.
166 | angle (float): Rotation angle in degrees. Positive values mean
167 | counter-clockwise rotation.
168 | center (tuple[int]): Rotation center. If the center is None,
169 | initialize it as the center of the image. Default: None.
170 | scale (float): Isotropic scale factor. Default: 1.0.
171 | """
172 | (h, w) = img.shape[:2]
173 |
174 | if center is None:
175 | center = (w // 2, h // 2)
176 |
177 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
178 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
179 | return rotated_img
180 |
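A small self-contained sketch of `paired_random_crop` and `augment` on random numpy arrays, illustrating the GT/LQ shape relation described in the docstrings above:

```python
import numpy as np
from basicsr.data.transforms import augment, paired_random_crop

scale, gt_patch_size = 4, 128
img_gt = np.random.rand(480, 480, 3).astype(np.float32)  # HWC, range [0, 1]
img_lq = np.random.rand(120, 120, 3).astype(np.float32)  # exactly 1/scale of the GT size

gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size, scale)
print(gt_patch.shape, lq_patch.shape)  # (128, 128, 3) (32, 32, 3)

# The same flip/rotation decision is applied to both patches.
gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
```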
--------------------------------------------------------------------------------
/basicsr/data/vimeo90k_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from pathlib import Path
4 | from torch.utils import data as data
5 |
6 | from basicsr.data.transforms import augment, paired_random_crop
7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
8 | from basicsr.utils.registry import DATASET_REGISTRY
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class Vimeo90KDataset(data.Dataset):
13 | """Vimeo90K dataset for training.
14 |
15 | The keys are generated from a meta info txt file.
16 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
17 |
18 | Each line contains the following items, separated by a white space.
19 |
20 | 1. clip name;
21 | 2. frame number;
22 | 3. image shape
23 |
24 | Examples:
25 |
26 | ::
27 |
28 | 00001/0001 7 (256,448,3)
29 | 00001/0002 7 (256,448,3)
30 |
31 | - Key examples: "00001/0001"
32 | - GT (gt): Ground-Truth;
33 | - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
34 |
35 | The neighboring frame list for different num_frame:
36 |
37 | ::
38 |
39 | num_frame | frame list
40 | 1 | 4
41 | 3 | 3,4,5
42 | 5 | 2,3,4,5,6
43 | 7 | 1,2,3,4,5,6,7
44 |
45 | Args:
46 | opt (dict): Config for train dataset. It contains the following keys:
47 | dataroot_gt (str): Data root path for gt.
48 | dataroot_lq (str): Data root path for lq.
49 | meta_info_file (str): Path for meta information file.
50 | io_backend (dict): IO backend type and other kwarg.
51 | num_frame (int): Window size for input frames.
 52 |             gt_size (int): Cropped patch size for gt patches.
53 | random_reverse (bool): Random reverse input frames.
54 | use_hflip (bool): Use horizontal flips.
55 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
 56 |             scale (int): Scale factor, which will be added automatically.
57 | """
58 |
59 | def __init__(self, opt):
60 | super(Vimeo90KDataset, self).__init__()
61 | self.opt = opt
62 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
63 |
64 | with open(opt['meta_info_file'], 'r') as fin:
65 | self.keys = [line.split(' ')[0] for line in fin]
66 |
67 | # file client (io backend)
68 | self.file_client = None
69 | self.io_backend_opt = opt['io_backend']
70 | self.is_lmdb = False
71 | if self.io_backend_opt['type'] == 'lmdb':
72 | self.is_lmdb = True
73 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
74 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
75 |
76 | # indices of input images
77 | self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
78 |
79 | # temporal augmentation configs
80 | self.random_reverse = opt['random_reverse']
81 | logger = get_root_logger()
82 | logger.info(f'Random reverse is {self.random_reverse}.')
83 |
84 | def __getitem__(self, index):
85 | if self.file_client is None:
86 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
87 |
88 | # random reverse
89 | if self.random_reverse and random.random() < 0.5:
90 | self.neighbor_list.reverse()
91 |
92 | scale = self.opt['scale']
93 | gt_size = self.opt['gt_size']
94 | key = self.keys[index]
95 | clip, seq = key.split('/') # key example: 00001/0001
96 |
97 | # get the GT frame (im4.png)
98 | if self.is_lmdb:
99 | img_gt_path = f'{key}/im4'
100 | else:
101 | img_gt_path = self.gt_root / clip / seq / 'im4.png'
102 | img_bytes = self.file_client.get(img_gt_path, 'gt')
103 | img_gt = imfrombytes(img_bytes, float32=True)
104 |
105 | # get the neighboring LQ frames
106 | img_lqs = []
107 | for neighbor in self.neighbor_list:
108 | if self.is_lmdb:
109 | img_lq_path = f'{clip}/{seq}/im{neighbor}'
110 | else:
111 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
112 | img_bytes = self.file_client.get(img_lq_path, 'lq')
113 | img_lq = imfrombytes(img_bytes, float32=True)
114 | img_lqs.append(img_lq)
115 |
116 | # randomly crop
117 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
118 |
119 | # augmentation - flip, rotate
120 | img_lqs.append(img_gt)
121 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
122 |
123 | img_results = img2tensor(img_results)
124 | img_lqs = torch.stack(img_results[0:-1], dim=0)
125 | img_gt = img_results[-1]
126 |
127 | # img_lqs: (t, c, h, w)
128 | # img_gt: (c, h, w)
129 | # key: str
130 | return {'lq': img_lqs, 'gt': img_gt, 'key': key}
131 |
132 | def __len__(self):
133 | return len(self.keys)
134 |
135 |
136 | @DATASET_REGISTRY.register()
137 | class Vimeo90KRecurrentDataset(Vimeo90KDataset):
138 |
139 | def __init__(self, opt):
140 | super(Vimeo90KRecurrentDataset, self).__init__(opt)
141 |
142 | self.flip_sequence = opt['flip_sequence']
143 | self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
144 |
145 | def __getitem__(self, index):
146 | if self.file_client is None:
147 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
148 |
149 | # random reverse
150 | if self.random_reverse and random.random() < 0.5:
151 | self.neighbor_list.reverse()
152 |
153 | scale = self.opt['scale']
154 | gt_size = self.opt['gt_size']
155 | key = self.keys[index]
156 | clip, seq = key.split('/') # key example: 00001/0001
157 |
158 | # get the neighboring LQ and GT frames
159 | img_lqs = []
160 | img_gts = []
161 | for neighbor in self.neighbor_list:
162 | if self.is_lmdb:
163 | img_lq_path = f'{clip}/{seq}/im{neighbor}'
164 | img_gt_path = f'{clip}/{seq}/im{neighbor}'
165 | else:
166 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
167 | img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
168 | # LQ
169 | img_bytes = self.file_client.get(img_lq_path, 'lq')
170 | img_lq = imfrombytes(img_bytes, float32=True)
171 | # GT
172 | img_bytes = self.file_client.get(img_gt_path, 'gt')
173 | img_gt = imfrombytes(img_bytes, float32=True)
174 |
175 | img_lqs.append(img_lq)
176 | img_gts.append(img_gt)
177 |
178 | # randomly crop
179 | img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
180 |
181 | # augmentation - flip, rotate
182 | img_lqs.extend(img_gts)
183 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
184 |
185 | img_results = img2tensor(img_results)
186 | img_lqs = torch.stack(img_results[:7], dim=0)
187 | img_gts = torch.stack(img_results[7:], dim=0)
188 |
189 | if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
190 | img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
191 | img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
192 |
193 | # img_lqs: (t, c, h, w)
194 | # img_gt: (c, h, w)
195 | # key: str
196 | return {'lq': img_lqs, 'gt': img_gts, 'key': key}
197 |
198 | def __len__(self):
199 | return len(self.keys)
200 |
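The `neighbor_list` construction in `Vimeo90KDataset.__init__` reproduces the table in the docstring; a quick sketch to verify the mapping:

```python
# Center-aligned window of num_frame indices around frame 4 (the GT frame im4).
for num_frame in (1, 3, 5, 7):
    print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]
```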
--------------------------------------------------------------------------------
/basicsr/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import LOSS_REGISTRY
7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
8 |
9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
10 |
11 | # automatically scan and import loss modules for registry
12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py'
13 | loss_folder = osp.dirname(osp.abspath(__file__))
14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
15 | # import all the loss modules
16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
17 |
18 |
19 | def build_loss(opt):
20 | """Build loss from options.
21 |
22 | Args:
23 | opt (dict): Configuration. It must contain:
 24 |             type (str): Loss type.
25 | """
26 | opt = deepcopy(opt)
27 | loss_type = opt.pop('type')
28 | loss = LOSS_REGISTRY.get(loss_type)(**opt)
29 | logger = get_root_logger()
30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.')
31 | return loss
32 |
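A minimal sketch of `build_loss`; it assumes the standard `L1Loss` registered in `basic_loss.py` with the usual `loss_weight`/`reduction` keyword arguments, and uses dummy tensors:

```python
import torch
from basicsr.losses import build_loss

# 'type' selects the registered loss class; the remaining keys are passed as its kwargs.
cri_pix = build_loss(dict(type='L1Loss', loss_weight=1.0, reduction='mean'))

sr = torch.rand(1, 3, 64, 64)   # network output (dummy tensor)
gt = torch.rand(1, 3, 64, 64)   # ground truth (dummy tensor)
l_pix = cri_pix(sr, gt)
```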
--------------------------------------------------------------------------------
/basicsr/losses/gan_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import autograd as autograd
4 | from torch import nn as nn
5 | from torch.nn import functional as F
6 |
7 | from basicsr.utils.registry import LOSS_REGISTRY
8 |
9 |
10 | @LOSS_REGISTRY.register()
11 | class GANLoss(nn.Module):
12 | """Define GAN loss.
13 |
14 | Args:
15 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
16 | real_label_val (float): The value for real label. Default: 1.0.
17 | fake_label_val (float): The value for fake label. Default: 0.0.
18 | loss_weight (float): Loss weight. Default: 1.0.
19 | Note that loss_weight is only for generators; and it is always 1.0
20 | for discriminators.
21 | """
22 |
23 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
24 | super(GANLoss, self).__init__()
25 | self.gan_type = gan_type
26 | self.loss_weight = loss_weight
27 | self.real_label_val = real_label_val
28 | self.fake_label_val = fake_label_val
29 |
30 | if self.gan_type == 'vanilla':
31 | self.loss = nn.BCEWithLogitsLoss()
32 | elif self.gan_type == 'lsgan':
33 | self.loss = nn.MSELoss()
34 | elif self.gan_type == 'wgan':
35 | self.loss = self._wgan_loss
36 | elif self.gan_type == 'wgan_softplus':
37 | self.loss = self._wgan_softplus_loss
38 | elif self.gan_type == 'hinge':
39 | self.loss = nn.ReLU()
40 | else:
41 | raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
42 |
43 | def _wgan_loss(self, input, target):
44 | """wgan loss.
45 |
46 | Args:
47 | input (Tensor): Input tensor.
48 | target (bool): Target label.
49 |
50 | Returns:
51 | Tensor: wgan loss.
52 | """
53 | return -input.mean() if target else input.mean()
54 |
55 | def _wgan_softplus_loss(self, input, target):
56 | """wgan loss with soft plus. softplus is a smooth approximation to the
57 | ReLU function.
58 |
59 | In StyleGAN2, it is called:
60 | Logistic loss for discriminator;
61 | Non-saturating loss for generator.
62 |
63 | Args:
64 | input (Tensor): Input tensor.
65 | target (bool): Target label.
66 |
67 | Returns:
68 | Tensor: wgan loss.
69 | """
70 | return F.softplus(-input).mean() if target else F.softplus(input).mean()
71 |
72 | def get_target_label(self, input, target_is_real):
73 | """Get target label.
74 |
75 | Args:
76 | input (Tensor): Input tensor.
77 | target_is_real (bool): Whether the target is real or fake.
78 |
79 | Returns:
80 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
81 | return Tensor.
82 | """
83 |
84 | if self.gan_type in ['wgan', 'wgan_softplus']:
85 | return target_is_real
86 | target_val = (self.real_label_val if target_is_real else self.fake_label_val)
87 | return input.new_ones(input.size()) * target_val
88 |
89 | def forward(self, input, target_is_real, is_disc=False):
90 | """
91 | Args:
92 | input (Tensor): The input for the loss module, i.e., the network
93 | prediction.
 94 |             target_is_real (bool): Whether the target is real or fake.
95 | is_disc (bool): Whether the loss for discriminators or not.
96 | Default: False.
97 |
98 | Returns:
99 | Tensor: GAN loss value.
100 | """
101 | target_label = self.get_target_label(input, target_is_real)
102 | if self.gan_type == 'hinge':
103 | if is_disc: # for discriminators in hinge-gan
104 | input = -input if target_is_real else input
105 | loss = self.loss(1 + input).mean()
106 | else: # for generators in hinge-gan
107 | loss = -input.mean()
108 | else: # other gan types
109 | loss = self.loss(input, target_label)
110 |
111 | # loss_weight is always 1.0 for discriminators
112 | return loss if is_disc else loss * self.loss_weight
113 |
114 |
115 | @LOSS_REGISTRY.register()
116 | class MultiScaleGANLoss(GANLoss):
117 | """
118 | MultiScaleGANLoss accepts a list of predictions
119 | """
120 |
121 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
122 | super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
123 |
124 | def forward(self, input, target_is_real, is_disc=False):
125 | """
126 | The input is a list of tensors, or a list of (a list of tensors)
127 | """
128 | if isinstance(input, list):
129 | loss = 0
130 | for pred_i in input:
131 | if isinstance(pred_i, list):
132 | # Only compute GAN loss for the last layer
133 | # in case of multiscale feature matching
134 | pred_i = pred_i[-1]
135 |                     # Safe operation: calling .mean() on a 0-dim tensor is a no-op
136 | loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
137 | loss += loss_tensor
138 | return loss / len(input)
139 | else:
140 | return super().forward(input, target_is_real, is_disc)
141 |
142 |
143 | def r1_penalty(real_pred, real_img):
144 | """R1 regularization for discriminator. The core idea is to
145 | penalize the gradient on real data alone: when the
146 | generator distribution produces the true data distribution
147 | and the discriminator is equal to 0 on the data manifold, the
148 | gradient penalty ensures that the discriminator cannot create
149 | a non-zero gradient orthogonal to the data manifold without
150 | suffering a loss in the GAN game.
151 |
152 | Reference: Eq. 9 in Which training methods for GANs do actually converge.
153 | """
154 | grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
155 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
156 | return grad_penalty
157 |
158 |
159 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
160 | noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
161 | grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
162 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
163 |
164 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
165 |
166 | path_penalty = (path_lengths - path_mean).pow(2).mean()
167 |
168 | return path_penalty, path_lengths.detach().mean(), path_mean.detach()
169 |
170 |
171 | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
172 | """Calculate gradient penalty for wgan-gp.
173 |
174 | Args:
175 | discriminator (nn.Module): Network for the discriminator.
176 | real_data (Tensor): Real input data.
177 | fake_data (Tensor): Fake input data.
178 | weight (Tensor): Weight tensor. Default: None.
179 |
180 | Returns:
181 | Tensor: A tensor for gradient penalty.
182 | """
183 |
184 | batch_size = real_data.size(0)
185 | alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
186 |
187 | # interpolate between real_data and fake_data
188 | interpolates = alpha * real_data + (1. - alpha) * fake_data
189 | interpolates = autograd.Variable(interpolates, requires_grad=True)
190 |
191 | disc_interpolates = discriminator(interpolates)
192 | gradients = autograd.grad(
193 | outputs=disc_interpolates,
194 | inputs=interpolates,
195 | grad_outputs=torch.ones_like(disc_interpolates),
196 | create_graph=True,
197 | retain_graph=True,
198 | only_inputs=True)[0]
199 |
200 | if weight is not None:
201 | gradients = gradients * weight
202 |
203 | gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
204 | if weight is not None:
205 | gradients_penalty /= torch.mean(weight)
206 |
207 | return gradients_penalty
208 |
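A short illustration of how `GANLoss` is typically driven from the generator and discriminator sides; the discriminator scores below are dummy tensors:

```python
import torch
from basicsr.losses.gan_loss import GANLoss

criterion = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)

real_pred = torch.randn(4, 1)  # discriminator output on real images (dummy)
fake_pred = torch.randn(4, 1)  # discriminator output on generated images (dummy)

# Generator update: push fake predictions towards the real label (loss_weight is applied).
l_g_gan = criterion(fake_pred, True, is_disc=False)

# Discriminator update: loss_weight is ignored (always 1.0).
l_d_real = criterion(real_pred, True, is_disc=True)
l_d_fake = criterion(fake_pred, False, is_disc=True)
```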
--------------------------------------------------------------------------------
/basicsr/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | def reduce_loss(loss, reduction):
7 | """Reduce loss as specified.
8 |
9 | Args:
10 | loss (Tensor): Elementwise loss tensor.
11 | reduction (str): Options are 'none', 'mean' and 'sum'.
12 |
13 | Returns:
14 | Tensor: Reduced loss tensor.
15 | """
16 | reduction_enum = F._Reduction.get_enum(reduction)
17 | # none: 0, elementwise_mean:1, sum: 2
18 | if reduction_enum == 0:
19 | return loss
20 | elif reduction_enum == 1:
21 | return loss.mean()
22 | else:
23 | return loss.sum()
24 |
25 |
26 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
27 | """Apply element-wise weight and reduce loss.
28 |
29 | Args:
30 | loss (Tensor): Element-wise loss.
31 | weight (Tensor): Element-wise weights. Default: None.
32 | reduction (str): Same as built-in losses of PyTorch. Options are
33 | 'none', 'mean' and 'sum'. Default: 'mean'.
34 |
35 | Returns:
36 | Tensor: Loss values.
37 | """
38 | # if weight is specified, apply element-wise weight
39 | if weight is not None:
40 | assert weight.dim() == loss.dim()
41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
42 | loss = loss * weight
43 |
44 | # if weight is not specified or reduction is sum, just reduce the loss
45 | if weight is None or reduction == 'sum':
46 | loss = reduce_loss(loss, reduction)
47 | # if reduction is mean, then compute mean over weight region
48 | elif reduction == 'mean':
49 | if weight.size(1) > 1:
50 | weight = weight.sum()
51 | else:
52 | weight = weight.sum() * loss.size(1)
53 | loss = loss.sum() / weight
54 |
55 | return loss
56 |
57 |
58 | def weighted_loss(loss_func):
59 | """Create a weighted version of a given loss function.
60 |
61 | To use this decorator, the loss function must have the signature like
62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
63 | element-wise loss without any reduction. This decorator will add weight
64 | and reduction arguments to the function. The decorated function will have
65 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
66 | **kwargs)`.
67 |
68 | :Example:
69 |
70 | >>> import torch
71 | >>> @weighted_loss
72 | >>> def l1_loss(pred, target):
73 | >>> return (pred - target).abs()
74 |
75 | >>> pred = torch.Tensor([0, 2, 3])
76 | >>> target = torch.Tensor([1, 1, 1])
77 | >>> weight = torch.Tensor([1, 0, 1])
78 |
79 | >>> l1_loss(pred, target)
80 | tensor(1.3333)
81 | >>> l1_loss(pred, target, weight)
82 | tensor(1.5000)
83 | >>> l1_loss(pred, target, reduction='none')
84 | tensor([1., 1., 2.])
85 | >>> l1_loss(pred, target, weight, reduction='sum')
86 | tensor(3.)
87 | """
88 |
89 | @functools.wraps(loss_func)
90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
91 | # get element-wise loss
92 | loss = loss_func(pred, target, **kwargs)
93 | loss = weight_reduce_loss(loss, weight, reduction)
94 | return loss
95 |
96 | return wrapper
97 |
98 |
99 | def get_local_weights(residual, ksize):
100 | """Get local weights for generating the artifact map of LDL.
101 |
102 | It is only called by the `get_refined_artifact_map` function.
103 |
104 | Args:
105 | residual (Tensor): Residual between predicted and ground truth images.
106 | ksize (Int): size of the local window.
107 |
108 | Returns:
109 | Tensor: weight for each pixel to be discriminated as an artifact pixel
110 | """
111 |
112 | pad = (ksize - 1) // 2
113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
114 |
115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
117 |
118 | return pixel_level_weight
119 |
120 |
121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
122 | """Calculate the artifact map of LDL
123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
124 |
125 | Args:
126 | img_gt (Tensor): ground truth images.
127 | img_output (Tensor): output images given by the optimizing model.
128 | img_ema (Tensor): output images given by the ema model.
129 | ksize (Int): size of the local window.
130 |
131 | Returns:
132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel
133 | (calculated based on both local and global observations).
134 | """
135 |
136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
138 |
139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
141 | overall_weight = patch_level_weight * pixel_level_weight
142 |
143 | overall_weight[residual_sr < residual_ema] = 0
144 |
145 | return overall_weight
146 |
--------------------------------------------------------------------------------
/basicsr/metrics/README.md:
--------------------------------------------------------------------------------
1 | # Metrics
2 |
3 | [English](README.md) **|** [简体中文](README_CN.md)
4 |
 5 | - [Conventions](#conventions)
 6 | - [PSNR and SSIM](#psnr-and-ssim)
 7 |
 8 | ## Conventions
 9 |
10 | Different input types lead to different results, so we make the following conventions on the inputs:
11 |
12 | - Numpy type (usually the output of cv2)
13 |   - UINT8: BGR, [0, 255], (h, w, c)
14 |   - float: BGR, [0, 1], (h, w, c). Usually used as an intermediate result
15 | - Tensor type
16 |   - float: RGB, [0, 1], (n, c, h, w)
17 |
18 | Other conventions:
19 |
20 | - Functions ending with `_pt` are the PyTorch implementations
21 | - The PyTorch version supports batch computation
22 | - Color conversion is done in float32; metric computation is done in float64
23 |
24 | ## PSNR and SSIM
25 |
26 | PSNR and SSIM follow the same trend: a higher PSNR generally comes with a higher SSIM.
27 | The various PSNR implementations are consistent with each other. SSIM has many variants; ours is kept consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate))
28 |
29 | The comparison of the different implementations is listed below.
30 | Summary: the PyTorch implementation basically matches the MATLAB one, with slight differences when running on GPU
31 |
32 | - PSNR comparison
33 |
34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
35 | |:---| :---: | :---: | :---: | :---: | :---: |
36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916|
38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663|
40 |
41 | - SSIM comparison
42 |
43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
44 | |:---| :---: | :---: | :---: | :---: | :---: |
45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171|
47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
49 |
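A small sketch that follows the conventions above (the image paths are placeholders); it mirrors what the test script `test_psnr_ssim.py` further below does:

```python
import cv2
from basicsr.metrics import calculate_psnr, calculate_ssim

# Numpy inputs: uint8 BGR images in HWC order, exactly as returned by cv2.imread.
img = cv2.imread('restored.png', cv2.IMREAD_UNCHANGED)
img2 = cv2.imread('gt.png', cv2.IMREAD_UNCHANGED)

psnr = calculate_psnr(img, img2, crop_border=4, input_order='HWC', test_y_channel=False)
ssim = calculate_ssim(img, img2, crop_border=4, input_order='HWC', test_y_channel=False)
print(f'PSNR: {psnr:.6f} dB, SSIM: {ssim:.6f}')
```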
--------------------------------------------------------------------------------
/basicsr/metrics/README_CN.md:
--------------------------------------------------------------------------------
1 | # Metrics
2 |
3 | [English](README.md) **|** [简体中文](README_CN.md)
4 |
5 | - [约定](#约定)
6 | - [PSNR 和 SSIM](#psnr-和-ssim)
7 |
8 | ## 约定
9 |
10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
11 |
12 | - Numpy 类型 (一般是 cv2 的结果)
13 | - UINT8: BGR, [0, 255], (h, w, c)
14 | - float: BGR, [0, 1], (h, w, c). 一般作为中间结果
15 | - Tensor 类型
16 | - float: RGB, [0, 1], (n, c, h, w)
17 |
18 | 其他约定:
19 |
20 | - 以 `_pt` 结尾的是 PyTorch 结果
21 | - PyTorch version 支持 batch 计算
22 | - 颜色转换在 float32 上做;metric计算在 float64 上做
23 |
24 | ## PSNR 和 SSIM
25 |
26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))
28 |
29 | 下面列了各个实现的结果比对.
30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异
31 |
32 | - PSNR 比对
33 |
34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
35 | |:---| :---: | :---: | :---: | :---: | :---: |
36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916|
38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663|
40 |
41 | - SSIM 比对
42 |
43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
44 | |:---| :---: | :---: | :---: | :---: | :---: |
45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171|
47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
49 |
--------------------------------------------------------------------------------
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | from basicsr.utils.registry import METRIC_REGISTRY
4 | from .niqe import calculate_niqe
5 | from .psnr_ssim import calculate_psnr, calculate_ssim
6 |
7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
8 |
9 |
10 | def calculate_metric(data, opt):
11 | """Calculate metric from data and options.
12 |
13 | Args:
14 | opt (dict): Configuration. It must contain:
15 |             type (str): Metric type.
16 | """
17 | opt = deepcopy(opt)
18 | metric_type = opt.pop('type')
19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
20 | return metric
21 |
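A hedged sketch of `calculate_metric`; it assumes `'calculate_psnr'` is the registered metric name (as used in the option YAML files) and uses dummy HWC uint8 numpy images:

```python
import numpy as np
from basicsr.metrics import calculate_metric

# Dummy images standing in for a restored output and its ground truth.
sr_img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
gt_img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

# data holds the inputs of the metric function; opt holds 'type' plus its kwargs.
data = dict(img=sr_img, img2=gt_img)
opt = dict(type='calculate_psnr', crop_border=4, test_y_channel=False)
psnr = calculate_metric(data, opt)
```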
--------------------------------------------------------------------------------
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy import linalg
5 | from tqdm import tqdm
6 |
7 | from basicsr.archs.inception import InceptionV3
8 |
9 |
10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
12 | # does resize the input.
13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
14 | inception = nn.DataParallel(inception).eval().to(device)
15 | return inception
16 |
17 |
18 | @torch.no_grad()
19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
20 | """Extract inception features.
21 |
22 | Args:
23 | data_generator (generator): A data generator.
24 | inception (nn.Module): Inception model.
25 | len_generator (int): Length of the data_generator to show the
26 | progressbar. Default: None.
27 | device (str): Device. Default: cuda.
28 |
29 | Returns:
30 | Tensor: Extracted features.
31 | """
32 | if len_generator is not None:
33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
34 | else:
35 | pbar = None
36 | features = []
37 |
38 | for data in data_generator:
39 | if pbar:
40 | pbar.update(1)
41 | data = data.to(device)
42 | feature = inception(data)[0].view(data.shape[0], -1)
43 | features.append(feature.to('cpu'))
44 | if pbar:
45 | pbar.close()
46 | features = torch.cat(features, 0)
47 | return features
48 |
49 |
50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
51 | """Numpy implementation of the Frechet Distance.
52 |
53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
54 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
55 | Stable version by Dougal J. Sutherland.
56 |
57 | Args:
58 | mu1 (np.array): The sample mean over activations.
59 | sigma1 (np.array): The covariance matrix over activations for generated samples.
 60 |         mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
 61 |         sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.
62 |
63 | Returns:
64 | float: The Frechet Distance.
65 | """
66 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
67 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
68 |
69 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
70 |
71 | # Product might be almost singular
72 | if not np.isfinite(cov_sqrt).all():
 73 |         print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
74 | offset = np.eye(sigma1.shape[0]) * eps
75 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
76 |
77 | # Numerical error might give slight imaginary component
78 | if np.iscomplexobj(cov_sqrt):
79 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
80 | m = np.max(np.abs(cov_sqrt.imag))
81 | raise ValueError(f'Imaginary component {m}')
82 | cov_sqrt = cov_sqrt.real
83 |
84 | mean_diff = mu1 - mu2
85 | mean_norm = mean_diff @ mean_diff
86 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
87 | fid = mean_norm + trace
88 |
89 | return fid
90 |
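A shape-level sketch of `calculate_fid`; the feature matrices here are random stand-ins with a small toy dimension, whereas real usage feeds 2048-dimensional Inception features extracted with `extract_inception_features` over a full dataset:

```python
import numpy as np
from basicsr.metrics.fid import calculate_fid

# Toy feature matrices of shape (num_samples, feature_dim); purely illustrative.
feats1 = np.random.randn(500, 64)
feats2 = np.random.randn(500, 64)

mu1, sigma1 = feats1.mean(axis=0), np.cov(feats1, rowvar=False)
mu2, sigma2 = feats2.mean(axis=0), np.cov(feats2, rowvar=False)

fid = calculate_fid(mu1, sigma1, mu2, sigma2)
```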
--------------------------------------------------------------------------------
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 | """Reorder images to 'HWC' order.
8 |
9 | If the input_order is (h, w), return (h, w, 1);
10 | If the input_order is (c, h, w), return (h, w, c);
11 | If the input_order is (h, w, c), return as it is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ['HWC', 'CHW']:
24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
25 | if len(img.shape) == 2:
26 | img = img[..., None]
27 | if input_order == 'CHW':
28 | img = img.transpose(1, 2, 0)
29 | return img
30 |
31 |
32 | def to_y_channel(img):
33 | """Change to Y channel of YCbCr.
34 |
35 | Args:
36 | img (ndarray): Images with range [0, 255].
37 |
38 | Returns:
39 | (ndarray): Images with range [0, 255] (float type) without round.
40 | """
41 | img = img.astype(np.float32) / 255.
42 | if img.ndim == 3 and img.shape[2] == 3:
43 | img = bgr2ycbcr(img, y_only=True)
44 | img = img[..., None]
45 | return img * 255.
46 |
--------------------------------------------------------------------------------
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Aitical/SCNet/c0f8678f2f50e1f97e00c3e018e904a273f0f39a/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/metrics/test_metrics/test_psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 |
4 | from basicsr.metrics import calculate_psnr, calculate_ssim
5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
6 | from basicsr.utils import img2tensor
7 |
8 |
9 | def test(img_path, img_path2, crop_border, test_y_channel=False):
10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)
12 |
13 | # --------------------- Numpy ---------------------
14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
17 |
18 | # --------------------- PyTorch (CPU) ---------------------
19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
21 |
22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
25 |
26 | # --------------------- PyTorch (GPU) ---------------------
27 | img = img.cuda()
28 | img2 = img2.cuda()
29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
32 |
33 | psnr_pth = calculate_psnr_pt(
34 | torch.repeat_interleave(img, 2, dim=0),
35 | torch.repeat_interleave(img2, 2, dim=0),
36 | crop_border=crop_border,
37 | test_y_channel=test_y_channel)
38 | ssim_pth = calculate_ssim_pt(
39 | torch.repeat_interleave(img, 2, dim=0),
40 | torch.repeat_interleave(img2, 2, dim=0),
41 | crop_border=crop_border,
42 | test_y_channel=test_y_channel)
43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
45 |
46 |
47 | if __name__ == '__main__':
48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)
50 |
51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
53 |
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import MODEL_REGISTRY
7 |
8 | __all__ = ['build_model']
9 |
10 | # automatically scan and import model modules for registry
11 | # scan all the files under the 'models' folder and collect files ending with '_model.py'
12 | model_folder = osp.dirname(osp.abspath(__file__))
13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
14 | # import all the model modules
15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
16 |
17 |
18 | def build_model(opt):
19 | """Build model from options.
20 |
21 | Args:
22 | opt (dict): Configuration. It must contain:
23 | model_type (str): Model type.
24 | """
25 | opt = deepcopy(opt)
26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt)
27 | logger = get_root_logger()
28 | logger.info(f'Model [{model.__class__.__name__}] is created.')
29 | return model
30 |
--------------------------------------------------------------------------------
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class MultiStepRestartLR(_LRScheduler):
7 | """ MultiStep with restarts learning rate scheme.
8 |
9 | Args:
10 | optimizer (torch.nn.optimizer): Torch optimizer.
11 | milestones (list): Iterations that will decrease learning rate.
12 | gamma (float): Decrease ratio. Default: 0.1.
13 | restarts (list): Restart iterations. Default: [0].
14 | restart_weights (list): Restart weights at each restart iteration.
15 | Default: [1].
16 | last_epoch (int): Used in _LRScheduler. Default: -1.
17 | """
18 |
19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20 | self.milestones = Counter(milestones)
21 | self.gamma = gamma
22 | self.restarts = restarts
23 | self.restart_weights = restart_weights
24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26 |
27 | def get_lr(self):
28 | if self.last_epoch in self.restarts:
29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31 | if self.last_epoch not in self.milestones:
32 | return [group['lr'] for group in self.optimizer.param_groups]
33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34 |
35 |
36 | def get_position_from_periods(iteration, cumulative_period):
37 | """Get the position from a period list.
38 |
39 | It will return the index of the right-closest number in the period list.
40 | For example, the cumulative_period = [100, 200, 300, 400],
41 | if iteration == 50, return 0;
42 | if iteration == 210, return 2;
43 | if iteration == 300, return 2.
44 |
45 | Args:
46 | iteration (int): Current iteration.
47 | cumulative_period (list[int]): Cumulative period list.
48 |
49 | Returns:
50 | int: The position of the right-closest number in the period list.
51 | """
52 | for i, period in enumerate(cumulative_period):
53 | if iteration <= period:
54 | return i
55 |
56 |
57 | class CosineAnnealingRestartLR(_LRScheduler):
58 | """ Cosine annealing with restarts learning rate scheme.
59 |
60 | An example of config:
61 | periods = [10, 10, 10, 10]
62 | restart_weights = [1, 0.5, 0.5, 0.5]
63 | eta_min=1e-7
64 |
65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
66 | scheduler will restart with the weights in restart_weights.
67 |
68 | Args:
69 | optimizer (torch.nn.optimizer): Torch optimizer.
 70 |         periods (list): Period for each cosine annealing cycle.
71 | restart_weights (list): Restart weights at each restart iteration.
72 | Default: [1].
73 | eta_min (float): The minimum lr. Default: 0.
74 | last_epoch (int): Used in _LRScheduler. Default: -1.
75 | """
76 |
77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78 | self.periods = periods
79 | self.restart_weights = restart_weights
80 | self.eta_min = eta_min
81 | assert (len(self.periods) == len(
82 | self.restart_weights)), 'periods and restart_weights should have the same length.'
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85 |
86 | def get_lr(self):
87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88 | current_weight = self.restart_weights[idx]
89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90 | current_period = self.periods[idx]
91 |
92 | return [
93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95 | for base_lr in self.base_lrs
96 | ]
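A toy run of `CosineAnnealingRestartLR` using the example configuration from the docstring; the model and optimizer below are dummies:

```python
import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for it in range(40):
    optimizer.step()
    scheduler.step()
    # The lr follows a cosine curve within each 10-iteration period and restarts
    # at iterations 10/20/30, scaled by the corresponding restart weight.
```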
--------------------------------------------------------------------------------
/basicsr/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Aitical/SCNet/c0f8678f2f50e1f97e00c3e018e904a273f0f39a/basicsr/ops/__init__.py
--------------------------------------------------------------------------------
/basicsr/ops/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
2 | modulated_deform_conv)
3 |
4 | __all__ = [
5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6 | 'modulated_deform_conv'
7 | ]
8 |
--------------------------------------------------------------------------------
/basicsr/ops/dcn/src/deform_conv_ext.cpp:
--------------------------------------------------------------------------------
1 | // modify from
2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3 |
 4 | #include <torch/extension.h>
 5 | #include <ATen/DivRtn.h>
 6 |
 7 | #include <cmath>
 8 | #include <vector>
9 |
10 | #define WITH_CUDA // always use cuda
11 | #ifdef WITH_CUDA
12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
13 | at::Tensor offset, at::Tensor output,
14 | at::Tensor columns, at::Tensor ones, int kW,
15 | int kH, int dW, int dH, int padW, int padH,
16 | int dilationW, int dilationH, int group,
17 | int deformable_group, int im2col_step);
18 |
19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
20 | at::Tensor gradOutput, at::Tensor gradInput,
21 | at::Tensor gradOffset, at::Tensor weight,
22 | at::Tensor columns, int kW, int kH, int dW,
23 | int dH, int padW, int padH, int dilationW,
24 | int dilationH, int group,
25 | int deformable_group, int im2col_step);
26 |
27 | int deform_conv_backward_parameters_cuda(
28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
29 | at::Tensor gradWeight, // at::Tensor gradBias,
30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
31 | int padW, int padH, int dilationW, int dilationH, int group,
32 | int deformable_group, float scale, int im2col_step);
33 |
34 | void modulated_deform_conv_cuda_forward(
35 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w,
38 | const int pad_h, const int pad_w, const int dilation_h,
39 | const int dilation_w, const int group, const int deformable_group,
40 | const bool with_bias);
41 |
42 | void modulated_deform_conv_cuda_backward(
43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
44 | at::Tensor offset, at::Tensor mask, at::Tensor columns,
45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
48 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
49 | const bool with_bias);
50 | #endif
51 |
52 | int deform_conv_forward(at::Tensor input, at::Tensor weight,
53 | at::Tensor offset, at::Tensor output,
54 | at::Tensor columns, at::Tensor ones, int kW,
55 | int kH, int dW, int dH, int padW, int padH,
56 | int dilationW, int dilationH, int group,
57 | int deformable_group, int im2col_step) {
58 | if (input.device().is_cuda()) {
59 | #ifdef WITH_CUDA
60 | return deform_conv_forward_cuda(input, weight, offset, output, columns,
61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
62 | deformable_group, im2col_step);
63 | #else
64 | AT_ERROR("deform conv is not compiled with GPU support");
65 | #endif
66 | }
67 | AT_ERROR("deform conv is not implemented on CPU");
68 | }
69 |
70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
71 | at::Tensor gradOutput, at::Tensor gradInput,
72 | at::Tensor gradOffset, at::Tensor weight,
73 | at::Tensor columns, int kW, int kH, int dW,
74 | int dH, int padW, int padH, int dilationW,
75 | int dilationH, int group,
76 | int deformable_group, int im2col_step) {
77 | if (input.device().is_cuda()) {
78 | #ifdef WITH_CUDA
79 | return deform_conv_backward_input_cuda(input, offset, gradOutput,
80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
81 | dilationW, dilationH, group, deformable_group, im2col_step);
82 | #else
83 | AT_ERROR("deform conv is not compiled with GPU support");
84 | #endif
85 | }
86 | AT_ERROR("deform conv is not implemented on CPU");
87 | }
88 |
89 | int deform_conv_backward_parameters(
90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
91 | at::Tensor gradWeight, // at::Tensor gradBias,
92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
93 | int padW, int padH, int dilationW, int dilationH, int group,
94 | int deformable_group, float scale, int im2col_step) {
95 | if (input.device().is_cuda()) {
96 | #ifdef WITH_CUDA
97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
99 | dilationH, group, deformable_group, scale, im2col_step);
100 | #else
101 | AT_ERROR("deform conv is not compiled with GPU support");
102 | #endif
103 | }
104 | AT_ERROR("deform conv is not implemented on CPU");
105 | }
106 |
107 | void modulated_deform_conv_forward(
108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w,
111 | const int pad_h, const int pad_w, const int dilation_h,
112 | const int dilation_w, const int group, const int deformable_group,
113 | const bool with_bias) {
114 | if (input.device().is_cuda()) {
115 | #ifdef WITH_CUDA
116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h,
118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
119 | deformable_group, with_bias);
120 | #else
121 | AT_ERROR("modulated deform conv is not compiled with GPU support");
122 | #endif
123 | }
124 | AT_ERROR("modulated deform conv is not implemented on CPU");
125 | }
126 |
127 | void modulated_deform_conv_backward(
128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
129 | at::Tensor offset, at::Tensor mask, at::Tensor columns,
130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
131 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
134 | const bool with_bias) {
135 | if (input.device().is_cuda()) {
136 | #ifdef WITH_CUDA
137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
141 | with_bias);
142 | #else
143 | AT_ERROR("modulated deform conv is not compiled with GPU support");
144 | #endif
145 | }
146 | AT_ERROR("modulated deform conv is not implemented on CPU");
147 | }
148 |
149 |
150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
151 | m.def("deform_conv_forward", &deform_conv_forward,
152 | "deform forward");
153 | m.def("deform_conv_backward_input", &deform_conv_backward_input,
154 | "deform_conv_backward_input");
155 | m.def("deform_conv_backward_parameters",
156 | &deform_conv_backward_parameters,
157 | "deform_conv_backward_parameters");
158 | m.def("modulated_deform_conv_forward",
159 | &modulated_deform_conv_forward,
160 | "modulated deform conv forward");
161 | m.def("modulated_deform_conv_backward",
162 | &modulated_deform_conv_backward,
163 | "modulated deform conv backward");
164 | }
165 |
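
(Usage note, not part of the file above.) These bindings are reached through the Python wrapper in `basicsr/ops/dcn/deform_conv.py`; below is a minimal sketch of JIT-compiling the extension, mirroring the `BASICSR_JIT` pattern used by `fused_act.py` later in this repository. The exact source list and path are assumptions based on the `src/` layout in the tree.

```python
import os
from torch.utils.cpp_extension import load

# assumed location of the C++/CUDA sources shown above
module_path = 'basicsr/ops/dcn/src'
deform_conv_ext = load(
    'deform_conv',
    sources=[
        os.path.join(module_path, 'deform_conv_ext.cpp'),
        os.path.join(module_path, 'deform_conv_cuda.cpp'),
        os.path.join(module_path, 'deform_conv_cuda_kernel.cu'),
    ],
)
```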
--------------------------------------------------------------------------------
/basicsr/ops/fused_act/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 |
3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
4 |
--------------------------------------------------------------------------------
/basicsr/ops/fused_act/fused_act.py:
--------------------------------------------------------------------------------
1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Function
7 |
8 | BASICSR_JIT = os.getenv('BASICSR_JIT')
9 | if BASICSR_JIT == 'True':
10 | from torch.utils.cpp_extension import load
11 | module_path = os.path.dirname(__file__)
12 | fused_act_ext = load(
13 | 'fused',
14 | sources=[
15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
17 | ],
18 | )
19 | else:
20 | try:
21 | from . import fused_act_ext
22 | except ImportError:
23 | pass
24 | # avoid annoying print output
25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
26 | # '1. compile with BASICSR_EXT=True. or\n '
27 | # '2. set BASICSR_JIT=True during running')
28 |
29 |
30 | class FusedLeakyReLUFunctionBackward(Function):
31 |
32 | @staticmethod
33 | def forward(ctx, grad_output, out, negative_slope, scale):
34 | ctx.save_for_backward(out)
35 | ctx.negative_slope = negative_slope
36 | ctx.scale = scale
37 |
38 | empty = grad_output.new_empty(0)
39 |
40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
41 |
42 | dim = [0]
43 |
44 | if grad_input.ndim > 2:
45 | dim += list(range(2, grad_input.ndim))
46 |
47 | grad_bias = grad_input.sum(dim).detach()
48 |
49 | return grad_input, grad_bias
50 |
51 | @staticmethod
52 | def backward(ctx, gradgrad_input, gradgrad_bias):
53 | out, = ctx.saved_tensors
54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
55 | ctx.scale)
56 |
57 | return gradgrad_out, None, None, None
58 |
59 |
60 | class FusedLeakyReLUFunction(Function):
61 |
62 | @staticmethod
63 | def forward(ctx, input, bias, negative_slope, scale):
64 | empty = input.new_empty(0)
65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
66 | ctx.save_for_backward(out)
67 | ctx.negative_slope = negative_slope
68 | ctx.scale = scale
69 |
70 | return out
71 |
72 | @staticmethod
73 | def backward(ctx, grad_output):
74 | out, = ctx.saved_tensors
75 |
76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
77 |
78 | return grad_input, grad_bias, None, None
79 |
80 |
81 | class FusedLeakyReLU(nn.Module):
82 |
83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
84 | super().__init__()
85 |
86 | self.bias = nn.Parameter(torch.zeros(channel))
87 | self.negative_slope = negative_slope
88 | self.scale = scale
89 |
90 | def forward(self, input):
91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
92 |
93 |
94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
96 |
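
(Usage note, not part of the file above.) A minimal sketch of the module and functional interfaces, assuming the `fused_act_ext` CUDA extension has been compiled; shapes and values are illustrative.

```python
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(channel=64).cuda()
x = torch.randn(8, 64, 32, 32, device='cuda')
y = act(x)  # adds the learned per-channel bias, applies LeakyReLU(0.2), scales by sqrt(2)

# functional form with an explicit bias tensor
bias = torch.zeros(64, device='cuda')
y2 = fused_leaky_relu(x, bias, negative_slope=0.2, scale=2**0.5)
```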
--------------------------------------------------------------------------------
/basicsr/ops/fused_act/src/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
2 | #include <torch/extension.h>
3 |
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input,
6 | const torch::Tensor& bias,
7 | const torch::Tensor& refer,
8 | int act, int grad, float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13 |
14 | torch::Tensor fused_bias_act(const torch::Tensor& input,
15 | const torch::Tensor& bias,
16 | const torch::Tensor& refer,
17 | int act, int grad, float alpha, float scale) {
18 | CHECK_CUDA(input);
19 | CHECK_CUDA(bias);
20 |
21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
22 | }
23 |
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
26 | }
27 |
--------------------------------------------------------------------------------
/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
3 | //
4 | // This work is made available under the Nvidia Source Code License-NC.
5 | // To view a copy of this license, visit
6 | // https://nvlabs.github.io/stylegan2/license.html
7 |
8 | #include <torch/types.h>
9 |
10 | #include <ATen/ATen.h>
11 | #include <ATen/AccumulateType.h>
12 | #include <ATen/cuda/CUDAContext.h>
13 | #include <ATen/cuda/CUDAApplyUtils.cuh>
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 |
19 | template <typename scalar_t>
20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
23 |
24 | scalar_t zero = 0.0;
25 |
26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
27 | scalar_t x = p_x[xi];
28 |
29 | if (use_bias) {
30 | x += p_b[(xi / step_b) % size_b];
31 | }
32 |
33 | scalar_t ref = use_ref ? p_ref[xi] : zero;
34 |
35 | scalar_t y;
36 |
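// (note) act * 10 + grad selects the branch: act 1 = linear, act 3 = leaky ReLU;
// grad 0 = forward pass, grad 1 = first derivative (sign taken from p_ref),
// grad 2 = second derivative, which is zero for these piecewise-linear activations.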
37 | switch (act * 10 + grad) {
38 | default:
39 | case 10: y = x; break;
40 | case 11: y = x; break;
41 | case 12: y = 0.0; break;
42 |
43 | case 30: y = (x > 0.0) ? x : x * alpha; break;
44 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
45 | case 32: y = 0.0; break;
46 | }
47 |
48 | out[xi] = y * scale;
49 | }
50 | }
51 |
52 |
53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
54 | int act, int grad, float alpha, float scale) {
55 | int curDevice = -1;
56 | cudaGetDevice(&curDevice);
57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
58 |
59 | auto x = input.contiguous();
60 | auto b = bias.contiguous();
61 | auto ref = refer.contiguous();
62 |
63 | int use_bias = b.numel() ? 1 : 0;
64 | int use_ref = ref.numel() ? 1 : 0;
65 |
66 | int size_x = x.numel();
67 | int size_b = b.numel();
68 | int step_b = 1;
69 |
70 | for (int i = 1 + 1; i < x.dim(); i++) {
71 | step_b *= x.size(i);
72 | }
73 |
74 | int loop_x = 4;
75 | int block_size = 4 * 32;
76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
77 |
78 | auto y = torch::empty_like(x);
79 |
80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
81 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
82 | y.data_ptr<scalar_t>(),
83 | x.data_ptr<scalar_t>(),
84 | b.data_ptr<scalar_t>(),
85 | ref.data_ptr<scalar_t>(),
86 | act,
87 | grad,
88 | alpha,
89 | scale,
90 | loop_x,
91 | size_x,
92 | step_b,
93 | size_b,
94 | use_bias,
95 | use_ref
96 | );
97 | });
98 |
99 | return y;
100 | }
101 |
--------------------------------------------------------------------------------
/basicsr/ops/upfirdn2d/__init__.py:
--------------------------------------------------------------------------------
1 | from .upfirdn2d import upfirdn2d
2 |
3 | __all__ = ['upfirdn2d']
4 |
--------------------------------------------------------------------------------
/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
2 | #include <torch/extension.h>
3 |
4 |
5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
6 | int up_x, int up_y, int down_x, int down_y,
7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
12 |
13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
14 | int up_x, int up_y, int down_x, int down_y,
15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
16 | CHECK_CUDA(input);
17 | CHECK_CUDA(kernel);
18 |
19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
20 | }
21 |
22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
24 | }
25 |
--------------------------------------------------------------------------------
/basicsr/ops/upfirdn2d/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
2 |
3 | import os
4 | import torch
5 | from torch.autograd import Function
6 | from torch.nn import functional as F
7 |
8 | BASICSR_JIT = os.getenv('BASICSR_JIT')
9 | if BASICSR_JIT == 'True':
10 | from torch.utils.cpp_extension import load
11 | module_path = os.path.dirname(__file__)
12 | upfirdn2d_ext = load(
13 | 'upfirdn2d',
14 | sources=[
15 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
16 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
17 | ],
18 | )
19 | else:
20 | try:
21 | from . import upfirdn2d_ext
22 | except ImportError:
23 | pass
24 | # avoid annoying print output
25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
26 | # '1. compile with BASICSR_EXT=True. or\n '
27 | # '2. set BASICSR_JIT=True during running')
28 |
29 |
30 | class UpFirDn2dBackward(Function):
31 |
32 | @staticmethod
33 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
34 |
35 | up_x, up_y = up
36 | down_x, down_y = down
37 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
38 |
39 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
40 |
41 | grad_input = upfirdn2d_ext.upfirdn2d(
42 | grad_output,
43 | grad_kernel,
44 | down_x,
45 | down_y,
46 | up_x,
47 | up_y,
48 | g_pad_x0,
49 | g_pad_x1,
50 | g_pad_y0,
51 | g_pad_y1,
52 | )
53 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
54 |
55 | ctx.save_for_backward(kernel)
56 |
57 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
58 |
59 | ctx.up_x = up_x
60 | ctx.up_y = up_y
61 | ctx.down_x = down_x
62 | ctx.down_y = down_y
63 | ctx.pad_x0 = pad_x0
64 | ctx.pad_x1 = pad_x1
65 | ctx.pad_y0 = pad_y0
66 | ctx.pad_y1 = pad_y1
67 | ctx.in_size = in_size
68 | ctx.out_size = out_size
69 |
70 | return grad_input
71 |
72 | @staticmethod
73 | def backward(ctx, gradgrad_input):
74 | kernel, = ctx.saved_tensors
75 |
76 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
77 |
78 | gradgrad_out = upfirdn2d_ext.upfirdn2d(
79 | gradgrad_input,
80 | kernel,
81 | ctx.up_x,
82 | ctx.up_y,
83 | ctx.down_x,
84 | ctx.down_y,
85 | ctx.pad_x0,
86 | ctx.pad_x1,
87 | ctx.pad_y0,
88 | ctx.pad_y1,
89 | )
90 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
91 | # ctx.out_size[1], ctx.in_size[3])
92 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
93 |
94 | return gradgrad_out, None, None, None, None, None, None, None, None
95 |
96 |
97 | class UpFirDn2d(Function):
98 |
99 | @staticmethod
100 | def forward(ctx, input, kernel, up, down, pad):
101 | up_x, up_y = up
102 | down_x, down_y = down
103 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
104 |
105 | kernel_h, kernel_w = kernel.shape
106 | _, channel, in_h, in_w = input.shape
107 | ctx.in_size = input.shape
108 |
109 | input = input.reshape(-1, in_h, in_w, 1)
110 |
111 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
112 |
113 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
114 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
115 | ctx.out_size = (out_h, out_w)
116 |
117 | ctx.up = (up_x, up_y)
118 | ctx.down = (down_x, down_y)
119 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
120 |
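# padding used by the gradient pass: UpFirDn2dBackward re-runs upfirdn2d with up/down
# swapped and the flipped kernel, and these pads map the gradient back onto the input grid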
121 | g_pad_x0 = kernel_w - pad_x0 - 1
122 | g_pad_y0 = kernel_h - pad_y0 - 1
123 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
124 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
125 |
126 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
127 |
128 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
129 | # out = out.view(major, out_h, out_w, minor)
130 | out = out.view(-1, channel, out_h, out_w)
131 |
132 | return out
133 |
134 | @staticmethod
135 | def backward(ctx, grad_output):
136 | kernel, grad_kernel = ctx.saved_tensors
137 |
138 | grad_input = UpFirDn2dBackward.apply(
139 | grad_output,
140 | kernel,
141 | grad_kernel,
142 | ctx.up,
143 | ctx.down,
144 | ctx.pad,
145 | ctx.g_pad,
146 | ctx.in_size,
147 | ctx.out_size,
148 | )
149 |
150 | return grad_input, None, None, None, None
151 |
152 |
153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
154 | if input.device.type == 'cpu':
155 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
156 | else:
157 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
158 |
159 | return out
160 |
161 |
162 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
163 | _, channel, in_h, in_w = input.shape
164 | input = input.reshape(-1, in_h, in_w, 1)
165 |
166 | _, in_h, in_w, minor = input.shape
167 | kernel_h, kernel_w = kernel.shape
168 |
169 | out = input.view(-1, in_h, 1, in_w, 1, minor)
170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
172 |
173 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
174 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
175 |
176 | out = out.permute(0, 3, 1, 2)
177 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
178 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
179 | out = F.conv2d(out, w)
180 | out = out.reshape(
181 | -1,
182 | minor,
183 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
184 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
185 | )
186 | out = out.permute(0, 2, 3, 1)
187 | out = out[:, ::down_y, ::down_x, :]
188 |
189 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
190 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
191 |
192 | return out.view(-1, channel, out_h, out_w)
193 |
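
(Usage note, not part of the file above.) A minimal sketch of the public `upfirdn2d` helper using the CPU fallback (`upfirdn2d_native`); the binomial kernel, gain, and padding follow the usual StyleGAN2 2x-upsample convention and are purely illustrative.

```python
import torch
from basicsr.ops.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[:, None] * k[None, :]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 16, 16)  # CPU tensor -> the native implementation is used
y = upfirdn2d(x, kernel * (2 ** 2), up=2, down=1, pad=(2, 1))  # gain 4 for 2x upsampling
print(y.shape)  # torch.Size([1, 3, 32, 32])
```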
--------------------------------------------------------------------------------
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import build_dataloader, build_dataset
6 | from basicsr.models import build_model
7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
8 | from basicsr.utils.options import dict2str, parse_options
9 |
10 |
11 | def test_pipeline(root_path):
12 | # parse options, set distributed setting, set random seed
13 | opt, _ = parse_options(root_path, is_train=False)
14 |
15 | torch.backends.cudnn.benchmark = True
16 | # torch.backends.cudnn.deterministic = True
17 |
18 | # mkdir and initialize loggers
19 | make_exp_dirs(opt)
20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
22 | logger.info(get_env_info())
23 | logger.info(dict2str(opt))
24 |
25 | # create test dataset and dataloader
26 | test_loaders = []
27 | for _, dataset_opt in sorted(opt['datasets'].items()):
28 | test_set = build_dataset(dataset_opt)
29 | test_loader = build_dataloader(
30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
32 | test_loaders.append(test_loader)
33 |
34 | # create model
35 | model = build_model(opt)
36 |
37 | for test_loader in test_loaders:
38 | test_set_name = test_loader.dataset.opt['name']
39 | logger.info(f'Testing {test_set_name}...')
40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
41 |
42 |
43 | if __name__ == '__main__':
44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
45 | test_pipeline(root_path)
46 |
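
(Usage note, not part of the file above.) In the usual BasicSR layout this entry point is driven by an option YAML, e.g. `python basicsr/test.py -opt path/to/test_option.yml` (the `-opt` flag is consumed by `parse_options`); alternatively, `test_pipeline(root_path)` can be called from Python as the `__main__` guard above does.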
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
2 | from .diffjpeg import DiffJPEG
3 | from .file_client import FileClient
4 | from .img_process_util import USMSharp, usm_sharp
5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
8 | from .options import yaml_load
9 |
10 | __all__ = [
11 | # color_util.py
12 | 'bgr2ycbcr',
13 | 'rgb2ycbcr',
14 | 'rgb2ycbcr_pt',
15 | 'ycbcr2bgr',
16 | 'ycbcr2rgb',
17 | # file_client.py
18 | 'FileClient',
19 | # img_util.py
20 | 'img2tensor',
21 | 'tensor2img',
22 | 'imfrombytes',
23 | 'imwrite',
24 | 'crop_border',
25 | # logger.py
26 | 'MessageLogger',
27 | 'AvgTimer',
28 | 'init_tb_logger',
29 | 'init_wandb_logger',
30 | 'get_root_logger',
31 | 'get_env_info',
32 | # misc.py
33 | 'set_random_seed',
34 | 'get_time_str',
35 | 'mkdir_and_rename',
36 | 'make_exp_dirs',
37 | 'scandir',
38 | 'check_resume',
39 | 'sizeof_fmt',
40 | # diffjpeg
41 | 'DiffJPEG',
42 | # img_process_util
43 | 'USMSharp',
44 | 'usm_sharp',
45 | # options
46 | 'yaml_load'
47 | ]
48 |
--------------------------------------------------------------------------------
/basicsr/utils/color_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def rgb2ycbcr(img, y_only=False):
6 | """Convert a RGB image to YCbCr image.
7 |
8 | This function produces the same results as Matlab's `rgb2ycbcr` function.
9 | It implements the ITU-R BT.601 conversion for standard-definition
10 | television. See more details in
11 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
12 |
13 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
14 | In OpenCV, it implements a JPEG conversion. See more details in
15 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
16 |
17 | Args:
18 | img (ndarray): The input image. It accepts:
19 | 1. np.uint8 type with range [0, 255];
20 | 2. np.float32 type with range [0, 1].
21 | y_only (bool): Whether to only return Y channel. Default: False.
22 |
23 | Returns:
24 | ndarray: The converted YCbCr image. The output image has the same type
25 | and range as input image.
26 | """
27 | img_type = img.dtype
28 | img = _convert_input_type_range(img)
29 | if y_only:
30 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
31 | else:
32 | out_img = np.matmul(
33 | img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
34 | out_img = _convert_output_type_range(out_img, img_type)
35 | return out_img
36 |
37 |
38 | def bgr2ycbcr(img, y_only=False):
39 | """Convert a BGR image to YCbCr image.
40 |
41 | The bgr version of rgb2ycbcr.
42 | It implements the ITU-R BT.601 conversion for standard-definition
43 | television. See more details in
44 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
45 |
46 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
47 | In OpenCV, it implements a JPEG conversion. See more details in
48 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
49 |
50 | Args:
51 | img (ndarray): The input image. It accepts:
52 | 1. np.uint8 type with range [0, 255];
53 | 2. np.float32 type with range [0, 1].
54 | y_only (bool): Whether to only return Y channel. Default: False.
55 |
56 | Returns:
57 | ndarray: The converted YCbCr image. The output image has the same type
58 | and range as input image.
59 | """
60 | img_type = img.dtype
61 | img = _convert_input_type_range(img)
62 | if y_only:
63 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
64 | else:
65 | out_img = np.matmul(
66 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
67 | out_img = _convert_output_type_range(out_img, img_type)
68 | return out_img
69 |
70 |
71 | def ycbcr2rgb(img):
72 | """Convert a YCbCr image to RGB image.
73 |
74 | This function produces the same results as Matlab's ycbcr2rgb function.
75 | It implements the ITU-R BT.601 conversion for standard-definition
76 | television. See more details in
77 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
78 |
79 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
80 | In OpenCV, it implements a JPEG conversion. See more details in
81 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
82 |
83 | Args:
84 | img (ndarray): The input image. It accepts:
85 | 1. np.uint8 type with range [0, 255];
86 | 2. np.float32 type with range [0, 1].
87 |
88 | Returns:
89 | ndarray: The converted RGB image. The output image has the same type
90 | and range as input image.
91 | """
92 | img_type = img.dtype
93 | img = _convert_input_type_range(img) * 255
94 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
95 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
96 | out_img = _convert_output_type_range(out_img, img_type)
97 | return out_img
98 |
99 |
100 | def ycbcr2bgr(img):
101 | """Convert a YCbCr image to BGR image.
102 |
103 | The bgr version of ycbcr2rgb.
104 | It implements the ITU-R BT.601 conversion for standard-definition
105 | television. See more details in
106 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
107 |
108 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
109 | In OpenCV, it implements a JPEG conversion. See more details in
110 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
111 |
112 | Args:
113 | img (ndarray): The input image. It accepts:
114 | 1. np.uint8 type with range [0, 255];
115 | 2. np.float32 type with range [0, 1].
116 |
117 | Returns:
118 | ndarray: The converted BGR image. The output image has the same type
119 | and range as input image.
120 | """
121 | img_type = img.dtype
122 | img = _convert_input_type_range(img) * 255
123 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
124 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
125 | out_img = _convert_output_type_range(out_img, img_type)
126 | return out_img
127 |
128 |
129 | def _convert_input_type_range(img):
130 | """Convert the type and range of the input image.
131 |
132 | It converts the input image to np.float32 type and range of [0, 1].
133 | It is mainly used for pre-processing the input image in colorspace
134 | conversion functions such as rgb2ycbcr and ycbcr2rgb.
135 |
136 | Args:
137 | img (ndarray): The input image. It accepts:
138 | 1. np.uint8 type with range [0, 255];
139 | 2. np.float32 type with range [0, 1].
140 |
141 | Returns:
142 | (ndarray): The converted image with type of np.float32 and range of
143 | [0, 1].
144 | """
145 | img_type = img.dtype
146 | img = img.astype(np.float32)
147 | if img_type == np.float32:
148 | pass
149 | elif img_type == np.uint8:
150 | img /= 255.
151 | else:
152 | raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
153 | return img
154 |
155 |
156 | def _convert_output_type_range(img, dst_type):
157 | """Convert the type and range of the image according to dst_type.
158 |
159 | It converts the image to desired type and range. If `dst_type` is np.uint8,
160 | images will be converted to np.uint8 type with range [0, 255]. If
161 | `dst_type` is np.float32, it converts the image to np.float32 type with
162 | range [0, 1].
163 | It is mainly used for post-processing images in colorspace conversion
164 | functions such as rgb2ycbcr and ycbcr2rgb.
165 |
166 | Args:
167 | img (ndarray): The image to be converted with np.float32 type and
168 | range [0, 255].
169 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
170 | converts the image to np.uint8 type with range [0, 255]. If
171 | dst_type is np.float32, it converts the image to np.float32 type
172 | with range [0, 1].
173 |
174 | Returns:
175 | (ndarray): The converted image with desired type and range.
176 | """
177 | if dst_type not in (np.uint8, np.float32):
178 | raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
179 | if dst_type == np.uint8:
180 | img = img.round()
181 | else:
182 | img /= 255.
183 | return img.astype(dst_type)
184 |
185 |
186 | def rgb2ycbcr_pt(img, y_only=False):
187 | """Convert RGB images to YCbCr images (PyTorch version).
188 |
189 | It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
190 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
191 |
192 | Args:
193 | img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
194 | y_only (bool): Whether to only return Y channel. Default: False.
195 |
196 | Returns:
197 | (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
198 | """
199 | if y_only:
200 | weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
201 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
202 | else:
203 | weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
204 | bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
205 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
206 |
207 | out_img = out_img / 255.
208 | return out_img
209 |
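
(Usage note, not part of the file above.) A minimal sketch of the typical use of these converters when computing metrics on the Y channel; the random arrays are illustrative.

```python
import numpy as np
from basicsr.utils.color_util import rgb2ycbcr, bgr2ycbcr

img_rgb = np.random.rand(64, 64, 3).astype(np.float32)        # RGB, float32, [0, 1]
y = rgb2ycbcr(img_rgb, y_only=True)                            # float32 Y channel, [0, 1]

img_bgr = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)   # BGR, uint8, [0, 255]
y8 = bgr2ycbcr(img_bgr, y_only=True)                           # uint8 Y channel, [0, 255]
```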
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45 | # specify master port
46 | if port is not None:
47 | os.environ['MASTER_PORT'] = str(port)
48 | elif 'MASTER_PORT' in os.environ:
49 | pass # use MASTER_PORT in the environment variable
50 | else:
51 | # 29500 is torch.distributed default port
52 | os.environ['MASTER_PORT'] = '29500'
53 | os.environ['MASTER_ADDR'] = addr
54 | os.environ['WORLD_SIZE'] = str(ntasks)
55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56 | os.environ['RANK'] = str(proc_id)
57 | dist.init_process_group(backend=backend)
58 |
59 |
60 | def get_dist_info():
61 | if dist.is_available():
62 | initialized = dist.is_initialized()
63 | else:
64 | initialized = False
65 | if initialized:
66 | rank = dist.get_rank()
67 | world_size = dist.get_world_size()
68 | else:
69 | rank = 0
70 | world_size = 1
71 | return rank, world_size
72 |
73 |
74 | def master_only(func):
75 |
76 | @functools.wraps(func)
77 | def wrapper(*args, **kwargs):
78 | rank, _ = get_dist_info()
79 | if rank == 0:
80 | return func(*args, **kwargs)
81 |
82 | return wrapper
83 |
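
(Usage note, not part of the file above.) A small sketch of `get_dist_info` and `master_only`: outside a `torch.distributed` run they fall back to rank 0 / world size 1, so the decorated function simply executes.

```python
from basicsr.utils.dist_util import get_dist_info, master_only

rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized

@master_only
def log_message(msg):
    # runs only on rank 0; other ranks get None back
    print(msg)

log_message('hello from the master process')
```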
--------------------------------------------------------------------------------
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import requests
4 | from torch.hub import download_url_to_file, get_dir
5 | from tqdm import tqdm
6 | from urllib.parse import urlparse
7 |
8 | from .misc import sizeof_fmt
9 |
10 |
11 | def download_file_from_google_drive(file_id, save_path):
12 | """Download files from google drive.
13 |
14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive
15 |
16 | Args:
17 | file_id (str): File id.
18 | save_path (str): Save path.
19 | """
20 |
21 | session = requests.Session()
22 | URL = 'https://docs.google.com/uc?export=download'
23 | params = {'id': file_id}
24 |
25 | response = session.get(URL, params=params, stream=True)
26 | token = get_confirm_token(response)
27 | if token:
28 | params['confirm'] = token
29 | response = session.get(URL, params=params, stream=True)
30 |
31 | # get file size
32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
33 | if 'Content-Range' in response_file_size.headers:
34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
35 | else:
36 | file_size = None
37 |
38 | save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 |
48 | def save_response_content(response, destination, file_size=None, chunk_size=32768):
49 | if file_size is not None:
50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
51 |
52 | readable_file_size = sizeof_fmt(file_size)
53 | else:
54 | pbar = None
55 |
56 | with open(destination, 'wb') as f:
57 | downloaded_size = 0
58 | for chunk in response.iter_content(chunk_size):
59 | downloaded_size += chunk_size
60 | if pbar is not None:
61 | pbar.update(1)
62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
63 | if chunk: # filter out keep-alive new chunks
64 | f.write(chunk)
65 | if pbar is not None:
66 | pbar.close()
67 |
68 |
69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
70 | """Load file form http url, will download models if necessary.
71 |
72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
73 |
74 | Args:
75 | url (str): URL to be downloaded.
76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
77 | Default: None.
78 | progress (bool): Whether to show the download progress. Default: True.
79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
80 |
81 | Returns:
82 | str: The path to the downloaded file.
83 | """
84 | if model_dir is None: # use the pytorch hub_dir
85 | hub_dir = get_dir()
86 | model_dir = os.path.join(hub_dir, 'checkpoints')
87 |
88 | os.makedirs(model_dir, exist_ok=True)
89 |
90 | parts = urlparse(url)
91 | filename = os.path.basename(parts.path)
92 | if file_name is not None:
93 | filename = file_name
94 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
95 | if not os.path.exists(cached_file):
96 | print(f'Downloading: "{url}" to {cached_file}\n')
97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
98 | return cached_file
99 |
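
(Usage note, not part of the file above.) A minimal sketch of fetching a checkpoint; the URL and target directory are placeholders, not actual release assets.

```python
from basicsr.utils.download_util import load_file_from_url

ckpt_path = load_file_from_url(
    url='https://example.com/pretrained/SCNet-T-x4.pth',  # placeholder URL
    model_dir='experiments/pretrained_models',            # assumed target directory
    progress=True)
print(ckpt_path)  # absolute path to the cached file
```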
--------------------------------------------------------------------------------
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError('Please install memcached to enable MemcachedBackend.')
40 |
41 | self.server_list_cfg = server_list_cfg
42 | self.client_cfg = client_cfg
43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44 | # mc.pyvector serves as a pointer to a memory cache
45 | self._mc_buffer = mc.pyvector()
46 |
47 | def get(self, filepath):
48 | filepath = str(filepath)
49 | import mc
50 | self._client.Get(filepath, self._mc_buffer)
51 | value_buf = mc.ConvertBuffer(self._mc_buffer)
52 | return value_buf
53 |
54 | def get_text(self, filepath):
55 | raise NotImplementedError
56 |
57 |
58 | class HardDiskBackend(BaseStorageBackend):
59 | """Raw hard disks storage backend."""
60 |
61 | def get(self, filepath):
62 | filepath = str(filepath)
63 | with open(filepath, 'rb') as f:
64 | value_buf = f.read()
65 | return value_buf
66 |
67 | def get_text(self, filepath):
68 | filepath = str(filepath)
69 | with open(filepath, 'r') as f:
70 | value_buf = f.read()
71 | return value_buf
72 |
73 |
74 | class LmdbBackend(BaseStorageBackend):
75 | """Lmdb storage backend.
76 |
77 | Args:
78 | db_paths (str | list[str]): Lmdb database paths.
79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80 | readonly (bool, optional): Lmdb environment parameter. If True,
81 | disallow any write operations. Default: True.
82 | lock (bool, optional): Lmdb environment parameter. If False, when
83 | concurrent access occurs, do not lock the database. Default: False.
84 | readahead (bool, optional): Lmdb environment parameter. If False,
85 | disable the OS filesystem readahead mechanism, which may improve
86 | random read performance when a database is larger than RAM.
87 | Default: False.
88 |
89 | Attributes:
90 | db_paths (list): Lmdb database path.
91 | _client (list): A list of several lmdb envs.
92 | """
93 |
94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95 | try:
96 | import lmdb
97 | except ImportError:
98 | raise ImportError('Please install lmdb to enable LmdbBackend.')
99 |
100 | if isinstance(client_keys, str):
101 | client_keys = [client_keys]
102 |
103 | if isinstance(db_paths, list):
104 | self.db_paths = [str(v) for v in db_paths]
105 | elif isinstance(db_paths, str):
106 | self.db_paths = [str(db_paths)]
107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
109 |
110 | self._client = {}
111 | for client, path in zip(client_keys, self.db_paths):
112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113 |
114 | def get(self, filepath, client_key):
115 | """Get values according to the filepath from one lmdb named client_key.
116 |
117 | Args:
118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119 | client_key (str): Used for distinguishing different lmdb envs.
120 | """
121 | filepath = str(filepath)
122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
123 | client = self._client[client_key]
124 | with client.begin(write=False) as txn:
125 | value_buf = txn.get(filepath.encode('ascii'))
126 | return value_buf
127 |
128 | def get_text(self, filepath):
129 | raise NotImplementedError
130 |
131 |
132 | class FileClient(object):
133 | """A general file client to access files in different backend.
134 |
135 | The client loads a file or text in a specified backend from its path
136 | and return it as a binary file. it can also register other backend
137 | accessor with a given name and backend class.
138 |
139 | Attributes:
140 | backend (str): The storage backend type. Options are "disk",
141 | "memcached" and "lmdb".
142 | client (:obj:`BaseStorageBackend`): The backend object.
143 | """
144 |
145 | _backends = {
146 | 'disk': HardDiskBackend,
147 | 'memcached': MemcachedBackend,
148 | 'lmdb': LmdbBackend,
149 | }
150 |
151 | def __init__(self, backend='disk', **kwargs):
152 | if backend not in self._backends:
153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154 | f' are {list(self._backends.keys())}')
155 | self.backend = backend
156 | self.client = self._backends[backend](**kwargs)
157 |
158 | def get(self, filepath, client_key='default'):
159 | # client_key is used only for lmdb, where different fileclients have
160 | # different lmdb environments.
161 | if self.backend == 'lmdb':
162 | return self.client.get(filepath, client_key)
163 | else:
164 | return self.client.get(filepath)
165 |
166 | def get_text(self, filepath):
167 | return self.client.get_text(filepath)
168 |
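
(Usage note, not part of the file above.) A short sketch of the two most common backends; the file and lmdb paths below are hypothetical.

```python
from basicsr.utils.file_client import FileClient

# disk backend: returns the raw bytes of the file
client = FileClient('disk')
img_bytes = client.get('datasets/Set5/LR/baby.png')  # hypothetical path

# lmdb backend: one env per client key; the key passed to get() is the lmdb record key
lmdb_client = FileClient(
    'lmdb',
    db_paths=['datasets/DIV2K_train_HR_sub.lmdb'],  # hypothetical database
    client_keys=['gt'])
gt_bytes = lmdb_client.get('0001_s001', client_key='gt')
```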
--------------------------------------------------------------------------------
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
25 | assert cat_flow.shape[concat_axis] % 2 == 0
26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
27 | flow = dequantize_flow(dx, dy, *args, **kwargs)
28 | else:
29 | with open(flow_path, 'rb') as f:
30 | try:
31 | header = f.read(4).decode('utf-8')
32 | except Exception:
33 | raise IOError(f'Invalid flow file: {flow_path}')
34 | else:
35 | if header != 'PIEH':
36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
37 |
38 | w = np.fromfile(f, np.int32, 1).squeeze()
39 | h = np.fromfile(f, np.int32, 1).squeeze()
40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
41 |
42 | return flow.astype(np.float32)
43 |
44 |
45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
46 | """Write optical flow to file.
47 |
48 | If the flow is not quantized, it will be saved as a .flo file losslessly,
49 | otherwise as a jpeg image which is lossy but much smaller. (dx and dy
50 | will be concatenated along `concat_axis` into a single image if quantize is True.)
51 |
52 | Args:
53 | flow (ndarray): (h, w, 2) array of optical flow.
54 | filename (str): Output filepath.
55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
56 | images. If set to True, remaining args will be passed to
57 | :func:`quantize_flow`.
58 | concat_axis (int): The axis that dx and dy are concatenated,
59 | can be either 0 or 1. Ignored if quantize is False.
60 | """
61 | if not quantize:
62 | with open(filename, 'wb') as f:
63 | f.write('PIEH'.encode('utf-8'))
64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
65 | flow = flow.astype(np.float32)
66 | flow.tofile(f)
67 | f.flush()
68 | else:
69 | assert concat_axis in [0, 1]
70 | dx, dy = quantize_flow(flow, *args, **kwargs)
71 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
72 | os.makedirs(os.path.dirname(filename), exist_ok=True)
73 | cv2.imwrite(filename, dxdy)
74 |
75 |
76 | def quantize_flow(flow, max_val=0.02, norm=True):
77 | """Quantize flow to [0, 255].
78 |
79 | After this step, the size of the flow will be much smaller and it can be
80 | dumped as jpeg images.
81 |
82 | Args:
83 | flow (ndarray): (h, w, 2) array of optical flow.
84 | max_val (float): Maximum value of flow, values beyond
85 | [-max_val, max_val] will be truncated.
86 | norm (bool): Whether to divide flow values by image width/height.
87 |
88 | Returns:
89 | tuple[ndarray]: Quantized dx and dy.
90 | """
91 | h, w, _ = flow.shape
92 | dx = flow[..., 0]
93 | dy = flow[..., 1]
94 | if norm:
95 | dx = dx / w # avoid inplace operations
96 | dy = dy / h
97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
99 | return tuple(flow_comps)
100 |
101 |
102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
103 | """Recover from quantized flow.
104 |
105 | Args:
106 | dx (ndarray): Quantized dx.
107 | dy (ndarray): Quantized dy.
108 | max_val (float): Maximum value used when quantizing.
109 | denorm (bool): Whether to multiply flow values with width/height.
110 |
111 | Returns:
112 | ndarray: Dequantized flow.
113 | """
114 | assert dx.shape == dy.shape
115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
116 |
117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
118 |
119 | if denorm:
120 | dx *= dx.shape[1]
121 | dy *= dx.shape[0]
122 | flow = np.dstack((dx, dy))
123 | return flow
124 |
125 |
126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
127 | """Quantize an array of (-inf, inf) to [0, levels-1].
128 |
129 | Args:
130 | arr (ndarray): Input array.
131 | min_val (scalar): Minimum value to be clipped.
132 | max_val (scalar): Maximum value to be clipped.
133 | levels (int): Quantization levels.
134 | dtype (np.type): The type of the quantized array.
135 |
136 | Returns:
137 | ndarray: Quantized array.
138 | """
139 | if not (isinstance(levels, int) and levels > 1):
140 | raise ValueError(f'levels must be a positive integer, but got {levels}')
141 | if min_val >= max_val:
142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
143 |
144 | arr = np.clip(arr, min_val, max_val) - min_val
145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
146 |
147 | return quantized_arr
148 |
149 |
150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
151 | """Dequantize an array.
152 |
153 | Args:
154 | arr (ndarray): Input array.
155 | min_val (scalar): Minimum value to be clipped.
156 | max_val (scalar): Maximum value to be clipped.
157 | levels (int): Quantization levels.
158 | dtype (np.type): The type of the dequantized array.
159 |
160 | Returns:
161 | ndarray: Dequantized array.
162 | """
163 | if not (isinstance(levels, int) and levels > 1):
164 | raise ValueError(f'levels must be a positive integer, but got {levels}')
165 | if min_val >= max_val:
166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
167 |
168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
169 |
170 | return dequantized_arr
171 |
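
(Usage note, not part of the file above.) A round-trip sketch showing the cost of quantization; the random flow field is illustrative.

```python
import numpy as np
from basicsr.utils.flow_util import quantize_flow, dequantize_flow

flow = np.random.uniform(-0.01, 0.01, size=(64, 64, 2)).astype(np.float32)
dx, dy = quantize_flow(flow, max_val=0.02, norm=False)          # two uint8 maps
flow_rec = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
print(np.abs(flow - flow_rec).max())  # bounded by about half a quantization step (~8e-5)
```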
--------------------------------------------------------------------------------
/basicsr/utils/img_process_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | def filter2D(img, kernel):
8 | """PyTorch version of cv2.filter2D
9 |
10 | Args:
11 | img (Tensor): (b, c, h, w)
12 | kernel (Tensor): (b, k, k)
13 | """
14 | k = kernel.size(-1)
15 | b, c, h, w = img.size()
16 | if k % 2 == 1:
17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
18 | else:
19 | raise ValueError('Wrong kernel size')
20 |
21 | ph, pw = img.size()[-2:]
22 |
23 | if kernel.size(0) == 1:
24 | # apply the same kernel to all batch images
25 | img = img.view(b * c, 1, ph, pw)
26 | kernel = kernel.view(1, 1, k, k)
27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
28 | else:
29 | img = img.view(1, b * c, ph, pw)
30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
32 |
33 |
34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10):
35 | """USM sharpening.
36 |
37 | Input image: I; Blurry image: B.
38 | 1. sharp = I + weight * (I - B)
39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0
40 | 3. Blur the mask to obtain a soft mask.
41 | 4. Out = soft_mask * sharp + (1 - soft_mask) * I
42 |
43 |
44 | Args:
45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
46 | weight (float): Sharp weight. Default: 0.5.
47 | radius (float): Kernel size of Gaussian blur. Default: 50.
48 | threshold (int): Residual threshold (on the 0-255 scale) for the mask. Default: 10.
49 | """
50 | if radius % 2 == 0:
51 | radius += 1
52 | blur = cv2.GaussianBlur(img, (radius, radius), 0)
53 | residual = img - blur
54 | mask = np.abs(residual) * 255 > threshold
55 | mask = mask.astype('float32')
56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
57 |
58 | sharp = img + weight * residual
59 | sharp = np.clip(sharp, 0, 1)
60 | return soft_mask * sharp + (1 - soft_mask) * img
61 |
62 |
63 | class USMSharp(torch.nn.Module):
64 |
65 | def __init__(self, radius=50, sigma=0):
66 | super(USMSharp, self).__init__()
67 | if radius % 2 == 0:
68 | radius += 1
69 | self.radius = radius
70 | kernel = cv2.getGaussianKernel(radius, sigma)
71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
72 | self.register_buffer('kernel', kernel)
73 |
74 | def forward(self, img, weight=0.5, threshold=10):
75 | blur = filter2D(img, self.kernel)
76 | residual = img - blur
77 |
78 | mask = torch.abs(residual) * 255 > threshold
79 | mask = mask.float()
80 | soft_mask = filter2D(mask, self.kernel)
81 | sharp = img + weight * residual
82 | sharp = torch.clip(sharp, 0, 1)
83 | return soft_mask * sharp + (1 - soft_mask) * img
84 |
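
(Usage note, not part of the file above.) A minimal sketch of the tensor-based sharpener; the random batch stands in for float images in [0, 1].

```python
import torch
from basicsr.utils.img_process_util import USMSharp

usm = USMSharp()                    # radius 50 is bumped to 51 to stay odd
imgs = torch.rand(4, 3, 128, 128)   # float images in [0, 1]
sharpened = usm(imgs, weight=0.5, threshold=10)
```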
--------------------------------------------------------------------------------
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | if img.dtype == 'float64':
25 | img = img.astype('float32')
26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27 | img = torch.from_numpy(img.transpose(2, 0, 1))
28 | if float32:
29 | img = img.float()
30 | return img
31 |
32 | if isinstance(imgs, list):
33 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
34 | else:
35 | return _totensor(imgs, bgr2rgb, float32)
36 |
37 |
38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39 | """Convert torch Tensors into image numpy arrays.
40 |
41 | After clamping to [min, max], values will be normalized to [0, 1].
42 |
43 | Args:
44 | tensor (Tensor or list[Tensor]): Accept shapes:
45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46 | 2) 3D Tensor of shape (3/1 x H x W);
47 | 3) 2D Tensor of shape (H x W).
48 | Tensor channel should be in RGB order.
49 | rgb2bgr (bool): Whether to change rgb to bgr.
50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
51 | to uint8 type with range [0, 255]; otherwise, float type with
52 | range [0, 1]. Default: ``np.uint8``.
53 | min_max (tuple[int]): min and max values for clamp.
54 |
55 | Returns:
56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57 | shape (H x W). The channel order is BGR.
58 | """
59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61 |
62 | if torch.is_tensor(tensor):
63 | tensor = [tensor]
64 | result = []
65 | for _tensor in tensor:
66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68 |
69 | n_dim = _tensor.dim()
70 | if n_dim == 4:
71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72 | img_np = img_np.transpose(1, 2, 0)
73 | if rgb2bgr:
74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75 | elif n_dim == 3:
76 | img_np = _tensor.numpy()
77 | img_np = img_np.transpose(1, 2, 0)
78 | if img_np.shape[2] == 1: # gray image
79 | img_np = np.squeeze(img_np, axis=2)
80 | else:
81 | if rgb2bgr:
82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83 | elif n_dim == 2:
84 | img_np = _tensor.numpy()
85 | else:
86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87 | if out_type == np.uint8:
88 | # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
89 | img_np = (img_np * 255.0).round()
90 | img_np = img_np.astype(out_type)
91 | result.append(img_np)
92 | if len(result) == 1:
93 | result = result[0]
94 | return result
95 |
96 |
97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98 | """This implementation is slightly faster than tensor2img.
99 | It now only supports torch tensor with shape (1, c, h, w).
100 |
101 | Args:
102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104 | min_max (tuple[int]): min and max values for clamp.
105 | """
106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108 | output = output.type(torch.uint8).cpu().numpy()
109 | if rgb2bgr:
110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111 | return output
112 |
113 |
114 | def imfrombytes(content, flag='color', float32=False):
115 | """Read an image from bytes.
116 |
117 | Args:
118 | content (bytes): Image bytes got from files or other streams.
119 | flag (str): Flags specifying the color type of a loaded image,
120 | candidates are `color`, `grayscale` and `unchanged`.
121 | float32 (bool): Whether to change to float32. If True, will also norm
122 | to [0, 1]. Default: False.
123 |
124 | Returns:
125 | ndarray: Loaded image array.
126 | """
127 | img_np = np.frombuffer(content, np.uint8)
128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129 | img = cv2.imdecode(img_np, imread_flags[flag])
130 | if float32:
131 | img = img.astype(np.float32) / 255.
132 | return img
133 |
134 |
135 | def imwrite(img, file_path, params=None, auto_mkdir=True):
136 | """Write image to file.
137 |
138 | Args:
139 | img (ndarray): Image array to be written.
140 | file_path (str): Image file path.
141 | params (None or list): Same as opencv's :func:`imwrite` interface.
142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143 | whether to create it automatically.
144 |
145 | Returns:
146 | bool: Successful or not.
147 | """
148 | if auto_mkdir:
149 | dir_name = os.path.abspath(os.path.dirname(file_path))
150 | os.makedirs(dir_name, exist_ok=True)
151 | ok = cv2.imwrite(file_path, img, params)
152 | if not ok:
153 | raise IOError('Failed in writing images.')
154 |
155 |
156 | def crop_border(imgs, crop_border):
157 | """Crop borders of images.
158 |
159 | Args:
160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161 | crop_border (int): Crop border for each end of height and width.
162 |
163 | Returns:
164 | list[ndarray]: Cropped images.
165 | """
166 | if crop_border == 0:
167 | return imgs
168 | else:
169 | if isinstance(imgs, list):
170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171 | else:
172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
173 |
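
(Usage note, not part of the file above.) A small round-trip sketch between the OpenCV-style numpy layout and the tensor layout used by the models; the random image is illustrative.

```python
import numpy as np
from basicsr.utils.img_util import img2tensor, tensor2img

img = np.random.rand(32, 32, 3).astype(np.float32)  # HWC, BGR, [0, 1]
t = img2tensor(img, bgr2rgb=True, float32=True)      # (3, 32, 32), RGB, float32
img_back = tensor2img(t, rgb2bgr=True)               # HWC, BGR, uint8, [0, 255]
```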
--------------------------------------------------------------------------------
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 |
8 |
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 |
20 | Contents of lmdb. The file structure is:
21 |
22 | ::
23 |
24 | example.lmdb
25 | ├── data.mdb
26 | ├── lock.mdb
27 | ├── meta_info.txt
28 |
29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
30 | https://lmdb.readthedocs.io/en/release/ for more details.
31 |
32 | The meta_info.txt is a specified txt file to record the meta information
33 | of our datasets. It will be automatically created when preparing
34 | datasets by our provided dataset tools.
35 |     Each line in the txt file records 1) image name (with extension),
36 |     2) image shape, and 3) compression level, separated by a white space.
37 |
38 | For example, the meta information could be:
39 | `000_00000000.png (720,1280,3) 1`, which means:
40 | 1) image name (with extension): 000_00000000.png;
41 | 2) image shape: (720,1280,3);
42 | 3) compression level: 1
43 |
44 | We use the image name without extension as the lmdb key.
45 |
46 | If `multiprocessing_read` is True, it will read all the images to memory
47 | using multiprocessing. Thus, your server needs to have enough memory.
48 |
49 | Args:
50 | data_path (str): Data path for reading images.
51 | lmdb_path (str): Lmdb save path.
52 |         img_path_list (list[str]): Image path list.
53 |         keys (list[str]): Used for lmdb keys.
54 | batch (int): After processing batch images, lmdb commits.
55 | Default: 5000.
56 | compress_level (int): Compress level when encoding images. Default: 1.
57 |         multiprocessing_read (bool): Whether to use multiprocessing to read
58 |             all the images to memory. Default: False.
59 |         n_thread (int): Process number for the multiprocessing read. Default: 40.
60 | map_size (int | None): Map size for lmdb env. If None, use the
61 | estimated size from images. Default: None
62 | """
63 |
64 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
65 | f'but got {len(img_path_list)} and {len(keys)}')
66 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
67 |     print(f'Total images: {len(img_path_list)}')
68 | if not lmdb_path.endswith('.lmdb'):
69 | raise ValueError("lmdb_path must end with '.lmdb'.")
70 | if osp.exists(lmdb_path):
71 | print(f'Folder {lmdb_path} already exists. Exit.')
72 | sys.exit(1)
73 |
74 | if multiprocessing_read:
75 | # read all the images to memory (multiprocessing)
76 | dataset = {} # use dict to keep the order for multiprocessing
77 | shapes = {}
78 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
79 | pbar = tqdm(total=len(img_path_list), unit='image')
80 |
81 | def callback(arg):
82 | """get the image data and update pbar."""
83 | key, dataset[key], shapes[key] = arg
84 | pbar.update(1)
85 | pbar.set_description(f'Read {key}')
86 |
87 | pool = Pool(n_thread)
88 | for path, key in zip(img_path_list, keys):
89 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
90 | pool.close()
91 | pool.join()
92 | pbar.close()
93 | print(f'Finish reading {len(img_path_list)} images.')
94 |
95 | # create lmdb environment
96 | if map_size is None:
97 | # obtain data size for one image
98 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
99 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
100 | data_size_per_img = img_byte.nbytes
101 | print('Data size per image is: ', data_size_per_img)
102 | data_size = data_size_per_img * len(img_path_list)
103 | map_size = data_size * 10
104 |
105 | env = lmdb.open(lmdb_path, map_size=map_size)
106 |
107 | # write data to lmdb
108 | pbar = tqdm(total=len(img_path_list), unit='chunk')
109 | txn = env.begin(write=True)
110 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
111 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
112 | pbar.update(1)
113 | pbar.set_description(f'Write {key}')
114 | key_byte = key.encode('ascii')
115 | if multiprocessing_read:
116 | img_byte = dataset[key]
117 | h, w, c = shapes[key]
118 | else:
119 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
120 | h, w, c = img_shape
121 |
122 | txn.put(key_byte, img_byte)
123 | # write meta information
124 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
125 | if idx % batch == 0:
126 | txn.commit()
127 | txn = env.begin(write=True)
128 | pbar.close()
129 | txn.commit()
130 | env.close()
131 | txt_file.close()
132 | print('\nFinish writing lmdb.')
133 |
134 |
135 | def read_img_worker(path, key, compress_level):
136 | """Read image worker.
137 |
138 | Args:
139 | path (str): Image path.
140 | key (str): Image key.
141 | compress_level (int): Compress level when encoding images.
142 |
143 | Returns:
144 | str: Image key.
145 | byte: Image byte.
146 | tuple[int]: Image shape.
147 | """
148 |
149 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
150 | if img.ndim == 2:
151 | h, w = img.shape
152 | c = 1
153 | else:
154 | h, w, c = img.shape
155 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
156 | return (key, img_byte, (h, w, c))
157 |
158 |
159 | class LmdbMaker():
160 | """LMDB Maker.
161 |
162 | Args:
163 | lmdb_path (str): Lmdb save path.
164 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
165 | batch (int): After processing batch images, lmdb commits.
166 | Default: 5000.
167 | compress_level (int): Compress level when encoding images. Default: 1.
168 | """
169 |
170 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
171 | if not lmdb_path.endswith('.lmdb'):
172 | raise ValueError("lmdb_path must end with '.lmdb'.")
173 | if osp.exists(lmdb_path):
174 | print(f'Folder {lmdb_path} already exists. Exit.')
175 | sys.exit(1)
176 |
177 | self.lmdb_path = lmdb_path
178 | self.batch = batch
179 | self.compress_level = compress_level
180 | self.env = lmdb.open(lmdb_path, map_size=map_size)
181 | self.txn = self.env.begin(write=True)
182 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
183 | self.counter = 0
184 |
185 | def put(self, img_byte, key, img_shape):
186 | self.counter += 1
187 | key_byte = key.encode('ascii')
188 | self.txn.put(key_byte, img_byte)
189 | # write meta information
190 | h, w, c = img_shape
191 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
192 | if self.counter % self.batch == 0:
193 | self.txn.commit()
194 | self.txn = self.env.begin(write=True)
195 |
196 | def close(self):
197 | self.txn.commit()
198 | self.env.close()
199 | self.txt_file.close()
200 |
--------------------------------------------------------------------------------
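A hedged sketch of preparing an LMDB with `make_lmdb_from_imgs` above; the dataset folder and key convention are hypothetical, and `scandir` comes from `basicsr/utils/misc.py` shown later in this dump:

```python
from basicsr.utils.lmdb_util import make_lmdb_from_imgs
from basicsr.utils.misc import scandir

# Hypothetical folder of HR sub-images; lmdb keys are file names without extension.
data_path = 'datasets/DIV2K/DIV2K_train_HR_sub'
img_path_list = sorted(scandir(data_path, suffix='png', recursive=False))
keys = [img_path.split('.png')[0] for img_path in img_path_list]

make_lmdb_from_imgs(
    data_path,
    'datasets/DIV2K/DIV2K_train_HR_sub.lmdb',  # must end with '.lmdb'
    img_path_list,
    keys,
    multiprocessing_read=True)
```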
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class AvgTimer():
11 |
12 | def __init__(self, window=200):
13 | self.window = window # average window
14 | self.current_time = 0
15 | self.total_time = 0
16 | self.count = 0
17 | self.avg_time = 0
18 | self.start()
19 |
20 | def start(self):
21 | self.start_time = self.tic = time.time()
22 |
23 | def record(self):
24 | self.count += 1
25 | self.toc = time.time()
26 | self.current_time = self.toc - self.tic
27 | self.total_time += self.current_time
28 | # calculate average time
29 | self.avg_time = self.total_time / self.count
30 |
31 | # reset
32 | if self.count > self.window:
33 | self.count = 0
34 | self.total_time = 0
35 |
36 | self.tic = time.time()
37 |
38 | def get_current_time(self):
39 | return self.current_time
40 |
41 | def get_avg_time(self):
42 | return self.avg_time
43 |
44 |
45 | class MessageLogger():
46 | """Message logger for printing.
47 |
48 | Args:
49 | opt (dict): Config. It contains the following keys:
50 | name (str): Exp name.
51 |         logger (dict): Contains 'print_freq' (int) for logger interval.
52 | train (dict): Contains 'total_iter' (int) for total iters.
53 | use_tb_logger (bool): Use tensorboard logger.
54 | start_iter (int): Start iter. Default: 1.
55 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
56 | """
57 |
58 | def __init__(self, opt, start_iter=1, tb_logger=None):
59 | self.exp_name = opt['name']
60 | self.interval = opt['logger']['print_freq']
61 | self.start_iter = start_iter
62 | self.max_iters = opt['train']['total_iter']
63 | self.use_tb_logger = opt['logger']['use_tb_logger']
64 | self.tb_logger = tb_logger
65 | self.start_time = time.time()
66 | self.logger = get_root_logger()
67 |
68 | def reset_start_time(self):
69 | self.start_time = time.time()
70 |
71 | @master_only
72 | def __call__(self, log_vars):
73 | """Format logging message.
74 |
75 | Args:
76 | log_vars (dict): It contains the following keys:
77 | epoch (int): Epoch number.
78 | iter (int): Current iter.
79 | lrs (list): List for learning rates.
80 |
81 | time (float): Iter time.
82 | data_time (float): Data time for each iter.
83 | """
84 | # epoch, iter, learning rates
85 | epoch = log_vars.pop('epoch')
86 | current_iter = log_vars.pop('iter')
87 | lrs = log_vars.pop('lrs')
88 |
89 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
90 | for v in lrs:
91 | message += f'{v:.3e},'
92 | message += ')] '
93 |
94 | # time and estimated time
95 | if 'time' in log_vars.keys():
96 | iter_time = log_vars.pop('time')
97 | data_time = log_vars.pop('data_time')
98 |
99 | total_time = time.time() - self.start_time
100 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
101 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
102 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
103 | message += f'[eta: {eta_str}, '
104 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
105 |
106 | # other items, especially losses
107 | for k, v in log_vars.items():
108 | message += f'{k}: {v:.4e} '
109 | # tensorboard logger
110 | if self.use_tb_logger and 'debug' not in self.exp_name:
111 | if k.startswith('l_'):
112 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
113 | else:
114 | self.tb_logger.add_scalar(k, v, current_iter)
115 | self.logger.info(message)
116 |
117 |
118 | @master_only
119 | def init_tb_logger(log_dir):
120 | from torch.utils.tensorboard import SummaryWriter
121 | tb_logger = SummaryWriter(log_dir=log_dir)
122 | return tb_logger
123 |
124 |
125 | @master_only
126 | def init_wandb_logger(opt):
127 | """We now only use wandb to sync tensorboard log."""
128 | import wandb
129 | logger = get_root_logger()
130 |
131 | project = opt['logger']['wandb']['project']
132 | resume_id = opt['logger']['wandb'].get('resume_id')
133 | if resume_id:
134 | wandb_id = resume_id
135 | resume = 'allow'
136 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
137 | else:
138 | wandb_id = wandb.util.generate_id()
139 | resume = 'never'
140 |
141 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
142 |
143 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
144 |
145 |
146 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
147 | """Get the root logger.
148 |
149 | The logger will be initialized if it has not been initialized. By default a
150 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
151 | also be added.
152 |
153 | Args:
154 | logger_name (str): root logger name. Default: 'basicsr'.
155 | log_file (str | None): The log filename. If specified, a FileHandler
156 | will be added to the root logger.
157 | log_level (int): The root logger level. Note that only the process of
158 | rank 0 is affected, while other processes will set the level to
159 | "Error" and be silent most of the time.
160 |
161 | Returns:
162 | logging.Logger: The root logger.
163 | """
164 | logger = logging.getLogger(logger_name)
165 | # if the logger has been initialized, just return it
166 | if logger_name in initialized_logger:
167 | return logger
168 |
169 | format_str = '%(asctime)s %(levelname)s: %(message)s'
170 | stream_handler = logging.StreamHandler()
171 | stream_handler.setFormatter(logging.Formatter(format_str))
172 | logger.addHandler(stream_handler)
173 | logger.propagate = False
174 | rank, _ = get_dist_info()
175 | if rank != 0:
176 | logger.setLevel('ERROR')
177 | elif log_file is not None:
178 | logger.setLevel(log_level)
179 | # add file handler
180 | file_handler = logging.FileHandler(log_file, 'w')
181 | file_handler.setFormatter(logging.Formatter(format_str))
182 | file_handler.setLevel(log_level)
183 | logger.addHandler(file_handler)
184 | initialized_logger[logger_name] = True
185 | return logger
186 |
187 |
188 | def get_env_info():
189 | """Get environment information.
190 |
191 |     Currently, it only logs software versions.
192 | """
193 | import torch
194 | import torchvision
195 |
196 | from basicsr.version import __version__
197 | msg = r"""
198 | ____ _ _____ ____
199 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
200 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
201 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
202 | /_____/ \__,_//____//_/ \___//____//_/ |_|
203 | ______ __ __ __ __
204 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
205 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
206 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
207 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
208 | """
209 | msg += ('\nVersion Information: '
210 | f'\n\tBasicSR: {__version__}'
211 | f'\n\tPyTorch: {torch.__version__}'
212 | f'\n\tTorchVision: {torchvision.__version__}')
213 | return msg
214 |
--------------------------------------------------------------------------------
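A minimal sketch of initializing the root logger defined above; the experiment folder is hypothetical:

```python
import logging
import os

from basicsr.utils.logger import get_root_logger

# Hypothetical experiment folder; the FileHandler needs its directory to exist.
os.makedirs('experiments/demo', exist_ok=True)

# Creates (or returns) the 'basicsr' logger; on rank 0 a FileHandler is also attached.
logger = get_root_logger(log_level=logging.INFO, log_file='experiments/demo/train.log')
logger.info('Logger initialized.')
```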
/basicsr/utils/matlab_functions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def cubic(x):
7 | """cubic function used for calculate_weights_indices."""
8 | absx = torch.abs(x)
9 | absx2 = absx**2
10 | absx3 = absx**3
11 | return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
13 | (absx <= 2)).type_as(absx))
14 |
15 |
16 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
17 | """Calculate weights and indices, used for imresize function.
18 |
19 | Args:
20 | in_length (int): Input length.
21 | out_length (int): Output length.
22 | scale (float): Scale factor.
23 | kernel_width (int): Kernel width.
24 |         antialiasing (bool): Whether to apply anti-aliasing when downsampling.
25 | """
26 |
27 | if (scale < 1) and antialiasing:
28 | # Use a modified kernel (larger kernel width) to simultaneously
29 | # interpolate and antialias
30 | kernel_width = kernel_width / scale
31 |
32 | # Output-space coordinates
33 | x = torch.linspace(1, out_length, out_length)
34 |
35 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
36 | # in output space maps to 0.5 in input space, and 0.5 + scale in output
37 | # space maps to 1.5 in input space.
38 | u = x / scale + 0.5 * (1 - 1 / scale)
39 |
40 | # What is the left-most pixel that can be involved in the computation?
41 | left = torch.floor(u - kernel_width / 2)
42 |
43 | # What is the maximum number of pixels that can be involved in the
44 | # computation? Note: it's OK to use an extra pixel here; if the
45 | # corresponding weights are all zero, it will be eliminated at the end
46 | # of this function.
47 | p = math.ceil(kernel_width) + 2
48 |
49 | # The indices of the input pixels involved in computing the k-th output
50 | # pixel are in row k of the indices matrix.
51 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
52 | out_length, p)
53 |
54 | # The weights used to compute the k-th output pixel are in row k of the
55 | # weights matrix.
56 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
57 |
58 | # apply cubic kernel
59 | if (scale < 1) and antialiasing:
60 | weights = scale * cubic(distance_to_center * scale)
61 | else:
62 | weights = cubic(distance_to_center)
63 |
64 | # Normalize the weights matrix so that each row sums to 1.
65 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
66 | weights = weights / weights_sum.expand(out_length, p)
67 |
68 | # If a column in weights is all zero, get rid of it. only consider the
69 | # first and last column.
70 | weights_zero_tmp = torch.sum((weights == 0), 0)
71 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
72 | indices = indices.narrow(1, 1, p - 2)
73 | weights = weights.narrow(1, 1, p - 2)
74 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
75 | indices = indices.narrow(1, 0, p - 2)
76 | weights = weights.narrow(1, 0, p - 2)
77 | weights = weights.contiguous()
78 | indices = indices.contiguous()
79 | sym_len_s = -indices.min() + 1
80 | sym_len_e = indices.max() - in_length
81 | indices = indices + sym_len_s - 1
82 | return weights, indices, int(sym_len_s), int(sym_len_e)
83 |
84 |
85 | @torch.no_grad()
86 | def imresize(img, scale, antialiasing=True):
87 |     """imresize function that matches MATLAB's imresize.
88 |
89 | It now only supports bicubic.
90 | The same scale applies for both height and width.
91 |
92 | Args:
93 | img (Tensor | Numpy array):
94 | Tensor: Input image with shape (c, h, w), [0, 1] range.
95 | Numpy: Input image with shape (h, w, c), [0, 1] range.
96 | scale (float): Scale factor. The same scale applies for both height
97 | and width.
98 |         antialiasing (bool): Whether to apply anti-aliasing when downsampling.
99 | Default: True.
100 |
101 | Returns:
102 | Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
103 | """
104 | squeeze_flag = False
105 | if type(img).__module__ == np.__name__: # numpy type
106 | numpy_type = True
107 | if img.ndim == 2:
108 | img = img[:, :, None]
109 | squeeze_flag = True
110 | img = torch.from_numpy(img.transpose(2, 0, 1)).float()
111 | else:
112 | numpy_type = False
113 | if img.ndim == 2:
114 | img = img.unsqueeze(0)
115 | squeeze_flag = True
116 |
117 | in_c, in_h, in_w = img.size()
118 | out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
119 | kernel_width = 4
120 | kernel = 'cubic'
121 |
122 | # get weights and indices
123 | weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
124 | antialiasing)
125 | weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
126 | antialiasing)
127 | # process H dimension
128 | # symmetric copying
129 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
130 | img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
131 |
132 | sym_patch = img[:, :sym_len_hs, :]
133 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
135 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
136 |
137 | sym_patch = img[:, -sym_len_he:, :]
138 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
139 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
140 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
141 |
142 | out_1 = torch.FloatTensor(in_c, out_h, in_w)
143 | kernel_width = weights_h.size(1)
144 | for i in range(out_h):
145 | idx = int(indices_h[i][0])
146 | for j in range(in_c):
147 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
148 |
149 | # process W dimension
150 | # symmetric copying
151 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
152 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
153 |
154 | sym_patch = out_1[:, :, :sym_len_ws]
155 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
156 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
157 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
158 |
159 | sym_patch = out_1[:, :, -sym_len_we:]
160 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
161 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
162 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
163 |
164 | out_2 = torch.FloatTensor(in_c, out_h, out_w)
165 | kernel_width = weights_w.size(1)
166 | for i in range(out_w):
167 | idx = int(indices_w[i][0])
168 | for j in range(in_c):
169 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
170 |
171 | if squeeze_flag:
172 | out_2 = out_2.squeeze(0)
173 | if numpy_type:
174 | out_2 = out_2.numpy()
175 | if not squeeze_flag:
176 | out_2 = out_2.transpose(1, 2, 0)
177 |
178 | return out_2
179 |
--------------------------------------------------------------------------------
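A short sketch of the MATLAB-style `imresize` above on a NumPy image; random data is used purely for illustration:

```python
import numpy as np

from basicsr.utils.matlab_functions import imresize

# (h, w, c) float image in [0, 1]; bicubic downscaling by 1/4 with anti-aliasing.
img = np.random.rand(128, 128, 3).astype(np.float32)
lr = imresize(img, scale=0.25, antialiasing=True)
print(lr.shape)  # (32, 32, 3)
```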
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 |
10 |
11 | def set_random_seed(seed):
12 | """Set random seeds."""
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | def get_time_str():
21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
22 |
23 |
24 | def mkdir_and_rename(path):
25 |     """Make directories. If path exists, rename it with a timestamp and create a new one.
26 |
27 | Args:
28 | path (str): Folder path.
29 | """
30 | if osp.exists(path):
31 | new_name = path + '_archived_' + get_time_str()
32 | print(f'Path already exists. Rename it to {new_name}', flush=True)
33 | os.rename(path, new_name)
34 | os.makedirs(path, exist_ok=True)
35 |
36 |
37 | @master_only
38 | def make_exp_dirs(opt):
39 | """Make dirs for experiments."""
40 | path_opt = opt['path'].copy()
41 | if opt['is_train']:
42 | mkdir_and_rename(path_opt.pop('experiments_root'))
43 | else:
44 | mkdir_and_rename(path_opt.pop('results_root'))
45 | for key, path in path_opt.items():
46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
47 | continue
48 | else:
49 | os.makedirs(path, exist_ok=True)
50 |
51 |
52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
53 |     """Scan a directory to find the files of interest.
54 |
55 | Args:
56 | dir_path (str): Path of the directory.
57 | suffix (str | tuple(str), optional): File suffix that we are
58 | interested in. Default: None.
59 | recursive (bool, optional): If set to True, recursively scan the
60 | directory. Default: False.
61 | full_path (bool, optional): If set to True, include the dir_path.
62 | Default: False.
63 |
64 | Returns:
65 | A generator for all the interested files with relative paths.
66 | """
67 |
68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
69 | raise TypeError('"suffix" must be a string or tuple of strings')
70 |
71 | root = dir_path
72 |
73 | def _scandir(dir_path, suffix, recursive):
74 | for entry in os.scandir(dir_path):
75 | if not entry.name.startswith('.') and entry.is_file():
76 | if full_path:
77 | return_path = entry.path
78 | else:
79 | return_path = osp.relpath(entry.path, root)
80 |
81 | if suffix is None:
82 | yield return_path
83 | elif return_path.endswith(suffix):
84 | yield return_path
85 | else:
86 | if recursive:
87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
88 | else:
89 | continue
90 |
91 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
92 |
93 |
94 | def check_resume(opt, resume_iter):
95 | """Check resume states and pretrain_network paths.
96 |
97 | Args:
98 | opt (dict): Options.
99 | resume_iter (int): Resume iteration.
100 | """
101 | if opt['path']['resume_state']:
102 | # get all the networks
103 | networks = [key for key in opt.keys() if key.startswith('network_')]
104 | flag_pretrain = False
105 | for network in networks:
106 | if opt['path'].get(f'pretrain_{network}') is not None:
107 | flag_pretrain = True
108 | if flag_pretrain:
109 | print('pretrain_network path will be ignored during resuming.')
110 | # set pretrained model paths
111 | for network in networks:
112 | name = f'pretrain_{network}'
113 | basename = network.replace('network_', '')
114 | if opt['path'].get('ignore_resume_networks') is None or (network
115 | not in opt['path']['ignore_resume_networks']):
116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
117 | print(f"Set {name} to {opt['path'][name]}")
118 |
119 | # change param_key to params in resume
120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
121 | for param_key in param_keys:
122 | if opt['path'][param_key] == 'params_ema':
123 | opt['path'][param_key] = 'params'
124 | print(f'Set {param_key} to params')
125 |
126 |
127 | def sizeof_fmt(size, suffix='B'):
128 | """Get human readable file size.
129 |
130 | Args:
131 | size (int): File size.
132 | suffix (str): Suffix. Default: 'B'.
133 |
134 | Return:
135 | str: Formatted file size.
136 | """
137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
138 | if abs(size) < 1024.0:
139 | return f'{size:3.1f} {unit}{suffix}'
140 | size /= 1024.0
141 | return f'{size:3.1f} Y{suffix}'
142 |
--------------------------------------------------------------------------------
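A brief sketch of `scandir` and `sizeof_fmt` above; the folder path only mirrors the benchmark layout used in the option files and may not exist locally:

```python
import os

from basicsr.utils.misc import scandir, sizeof_fmt

dataset_root = 'datasets/benchmark/Set5/HR'  # hypothetical local path

# Recursively list all PNG files and print their human-readable sizes.
for rel_path in scandir(dataset_root, suffix='.png', recursive=True):
    size = os.path.getsize(os.path.join(dataset_root, rel_path))
    print(rel_path, sizeof_fmt(size))
```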
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import torch
5 | import yaml
6 | from collections import OrderedDict
7 | from os import path as osp
8 |
9 | from basicsr.utils import set_random_seed
10 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
11 |
12 |
13 | def ordered_yaml():
14 | """Support OrderedDict for yaml.
15 |
16 | Returns:
17 | tuple: yaml Loader and Dumper.
18 | """
19 | try:
20 | from yaml import CDumper as Dumper
21 | from yaml import CLoader as Loader
22 | except ImportError:
23 | from yaml import Dumper, Loader
24 |
25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
26 |
27 | def dict_representer(dumper, data):
28 | return dumper.represent_dict(data.items())
29 |
30 | def dict_constructor(loader, node):
31 | return OrderedDict(loader.construct_pairs(node))
32 |
33 | Dumper.add_representer(OrderedDict, dict_representer)
34 | Loader.add_constructor(_mapping_tag, dict_constructor)
35 | return Loader, Dumper
36 |
37 |
38 | def yaml_load(f):
39 | """Load yaml file or string.
40 |
41 | Args:
42 | f (str): File path or a python string.
43 |
44 | Returns:
45 | dict: Loaded dict.
46 | """
47 | if os.path.isfile(f):
48 | with open(f, 'r') as f:
49 | return yaml.load(f, Loader=ordered_yaml()[0])
50 | else:
51 | return yaml.load(f, Loader=ordered_yaml()[0])
52 |
53 |
54 | def dict2str(opt, indent_level=1):
55 | """dict to string for printing options.
56 |
57 | Args:
58 | opt (dict): Option dict.
59 | indent_level (int): Indent level. Default: 1.
60 |
61 | Return:
62 | (str): Option string for printing.
63 | """
64 | msg = '\n'
65 | for k, v in opt.items():
66 | if isinstance(v, dict):
67 | msg += ' ' * (indent_level * 2) + k + ':['
68 | msg += dict2str(v, indent_level + 1)
69 | msg += ' ' * (indent_level * 2) + ']\n'
70 | else:
71 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
72 | return msg
73 |
74 |
75 | def _postprocess_yml_value(value):
76 | # None
77 | if value == '~' or value.lower() == 'none':
78 | return None
79 | # bool
80 | if value.lower() == 'true':
81 | return True
82 | elif value.lower() == 'false':
83 | return False
84 | # !!float number
85 | if value.startswith('!!float'):
86 | return float(value.replace('!!float', ''))
87 | # number
88 | if value.isdigit():
89 | return int(value)
90 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
91 | return float(value)
92 | # list
93 | if value.startswith('['):
94 | return eval(value)
95 | # str
96 | return value
97 |
98 |
99 | def parse_options(root_path, is_train=True):
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
102 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
103 | parser.add_argument('--auto_resume', action='store_true')
104 | parser.add_argument('--debug', action='store_true')
105 | parser.add_argument('--local_rank', type=int, default=0)
106 | parser.add_argument(
107 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
108 | args = parser.parse_args()
109 |
110 | # parse yml to dict
111 | opt = yaml_load(args.opt)
112 |
113 | # distributed settings
114 | if args.launcher == 'none':
115 | opt['dist'] = False
116 | print('Disable distributed.', flush=True)
117 | else:
118 | opt['dist'] = True
119 | if args.launcher == 'slurm' and 'dist_params' in opt:
120 | init_dist(args.launcher, **opt['dist_params'])
121 | else:
122 | init_dist(args.launcher)
123 | opt['rank'], opt['world_size'] = get_dist_info()
124 |
125 | # random seed
126 | seed = opt.get('manual_seed')
127 | if seed is None:
128 | seed = random.randint(1, 10000)
129 | opt['manual_seed'] = seed
130 | set_random_seed(seed + opt['rank'])
131 |
132 | # force to update yml options
133 | if args.force_yml is not None:
134 | for entry in args.force_yml:
135 | # now do not support creating new keys
136 | keys, value = entry.split('=')
137 | keys, value = keys.strip(), value.strip()
138 | value = _postprocess_yml_value(value)
139 | eval_str = 'opt'
140 | for key in keys.split(':'):
141 | eval_str += f'["{key}"]'
142 | eval_str += '=value'
143 | # using exec function
144 | exec(eval_str)
145 |
146 | opt['auto_resume'] = args.auto_resume
147 | opt['is_train'] = is_train
148 |
149 | # debug setting
150 | if args.debug and not opt['name'].startswith('debug'):
151 | opt['name'] = 'debug_' + opt['name']
152 |
153 | if opt['num_gpu'] == 'auto':
154 | opt['num_gpu'] = torch.cuda.device_count()
155 |
156 | # datasets
157 | for phase, dataset in opt['datasets'].items():
158 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2
159 | phase = phase.split('_')[0]
160 | dataset['phase'] = phase
161 | if 'scale' in opt:
162 | dataset['scale'] = opt['scale']
163 | if dataset.get('dataroot_gt') is not None:
164 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
165 | if dataset.get('dataroot_lq') is not None:
166 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
167 |
168 | # paths
169 | for key, val in opt['path'].items():
170 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
171 | opt['path'][key] = osp.expanduser(val)
172 |
173 | if is_train:
174 | experiments_root = opt['path'].get('experiments_root')
175 | if experiments_root is None:
176 | experiments_root = osp.join(root_path, 'experiments')
177 | experiments_root = osp.join(experiments_root, opt['name'])
178 |
179 | opt['path']['experiments_root'] = experiments_root
180 | opt['path']['models'] = osp.join(experiments_root, 'models')
181 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
182 | opt['path']['log'] = experiments_root
183 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
184 |
185 | # change some options for debug mode
186 | if 'debug' in opt['name']:
187 | if 'val' in opt:
188 | opt['val']['val_freq'] = 8
189 | opt['logger']['print_freq'] = 1
190 | opt['logger']['save_checkpoint_freq'] = 8
191 | else: # test
192 | results_root = opt['path'].get('results_root')
193 | if results_root is None:
194 | results_root = osp.join(root_path, 'results')
195 | results_root = osp.join(results_root, opt['name'])
196 |
197 | opt['path']['results_root'] = results_root
198 | opt['path']['log'] = results_root
199 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
200 |
201 | return opt, args
202 |
203 |
204 | @master_only
205 | def copy_opt_file(opt_file, experiments_root):
206 | # copy the yml file to the experiment root
207 | import sys
208 | import time
209 | from shutil import copyfile
210 | cmd = ' '.join(sys.argv)
211 | filename = osp.join(experiments_root, osp.basename(opt_file))
212 | copyfile(opt_file, filename)
213 |
214 | with open(filename, 'r+') as f:
215 | lines = f.readlines()
216 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
217 | f.seek(0)
218 | f.writelines(lines)
219 |
--------------------------------------------------------------------------------
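A small sketch showing how `yaml_load` and `dict2str` above can be used to inspect one of the SCNet option files listed later in this dump:

```python
from basicsr.utils.options import dict2str, yaml_load

# Parse a test option file into an OrderedDict and pretty-print it.
opt = yaml_load('options/test/SCNet/SCNet-T-x4.yml')
print(dict2str(opt))
print(opt['network_g']['type'])  # SCNet
```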
/basicsr/utils/plot_util.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def read_data_from_tensorboard(log_path, tag):
5 | """Get raw data (steps and values) from tensorboard events.
6 |
7 | Args:
8 | log_path (str): Path to the tensorboard log.
9 | tag (str): tag to be read.
10 | """
11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
12 |
13 | # tensorboard event
14 | event_acc = EventAccumulator(log_path)
15 | event_acc.Reload()
16 | scalar_list = event_acc.Tags()['scalars']
17 | print('tag list: ', scalar_list)
18 | steps = [int(s.step) for s in event_acc.Scalars(tag)]
19 | values = [s.value for s in event_acc.Scalars(tag)]
20 | return steps, values
21 |
22 |
23 | def read_data_from_txt_2v(path, pattern, step_one=False):
24 | """Read data from txt with 2 returned values (usually [step, value]).
25 |
26 | Args:
27 | path (str): path to the txt file.
28 | pattern (str): re (regular expression) pattern.
29 | step_one (bool): add 1 to steps. Default: False.
30 | """
31 | with open(path) as f:
32 | lines = f.readlines()
33 | lines = [line.strip() for line in lines]
34 | steps = []
35 | values = []
36 |
37 | pattern = re.compile(pattern)
38 | for line in lines:
39 | match = pattern.match(line)
40 | if match:
41 | steps.append(int(match.group(1)))
42 | values.append(float(match.group(2)))
43 | if step_one:
44 | steps = [v + 1 for v in steps]
45 | return steps, values
46 |
47 |
48 | def read_data_from_txt_1v(path, pattern):
49 |     """Read data from txt with 1 returned value.
50 |
51 | Args:
52 | path (str): path to the txt file.
53 | pattern (str): re (regular expression) pattern.
54 | """
55 | with open(path) as f:
56 | lines = f.readlines()
57 | lines = [line.strip() for line in lines]
58 | data = []
59 |
60 | pattern = re.compile(pattern)
61 | for line in lines:
62 | match = pattern.match(line)
63 | if match:
64 | data.append(float(match.group(1)))
65 | return data
66 |
67 |
68 | def smooth_data(values, smooth_weight):
69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
70 |
71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501
72 |
73 | Args:
74 | values (list): A list of values to be smoothed.
75 | smooth_weight (float): Smooth weight.
76 | """
77 | values_sm = []
78 | last_sm_value = values[0]
79 | for value in values:
80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
81 | values_sm.append(value_sm)
82 | last_sm_value = value_sm
83 | return values_sm
84 |
--------------------------------------------------------------------------------
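A hedged sketch of reading metric values from a text log with `read_data_from_txt_2v` and smoothing them; the log path, the line format, and the regular expression are all hypothetical:

```python
from basicsr.utils.plot_util import read_data_from_txt_2v, smooth_data

# Hypothetical log lines of the form: "iter 5000 psnr 28.1234"
steps, values = read_data_from_txt_2v(
    'experiments/demo/psnr_log.txt',  # hypothetical log file
    pattern=r'iter (\d+) psnr ([\d.]+)')
smoothed = smooth_data(values, smooth_weight=0.6)
print(steps[:3], smoothed[:3])
```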
/basicsr/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2 |
3 |
4 | class Registry():
5 | """
6 | The registry that provides name -> object mapping, to support third-party
7 | users' custom modules.
8 |
9 | To create a registry (e.g. a backbone registry):
10 |
11 | .. code-block:: python
12 |
13 | BACKBONE_REGISTRY = Registry('BACKBONE')
14 |
15 | To register an object:
16 |
17 | .. code-block:: python
18 |
19 | @BACKBONE_REGISTRY.register()
20 | class MyBackbone():
21 | ...
22 |
23 | Or:
24 |
25 | .. code-block:: python
26 |
27 | BACKBONE_REGISTRY.register(MyBackbone)
28 | """
29 |
30 | def __init__(self, name):
31 | """
32 | Args:
33 | name (str): the name of this registry
34 | """
35 | self._name = name
36 | self._obj_map = {}
37 |
38 | def _do_register(self, name, obj, suffix=None):
39 | if isinstance(suffix, str):
40 | name = name + '_' + suffix
41 |
42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
43 | f"in '{self._name}' registry!")
44 | self._obj_map[name] = obj
45 |
46 | def register(self, obj=None, suffix=None):
47 | """
48 |         Register the given object under the name `obj.__name__`.
49 | Can be used as either a decorator or not.
50 | See docstring of this class for usage.
51 | """
52 | if obj is None:
53 | # used as a decorator
54 | def deco(func_or_class):
55 | name = func_or_class.__name__
56 | self._do_register(name, func_or_class, suffix)
57 | return func_or_class
58 |
59 | return deco
60 |
61 | # used as a function call
62 | name = obj.__name__
63 | self._do_register(name, obj, suffix)
64 |
65 | def get(self, name, suffix='basicsr'):
66 | ret = self._obj_map.get(name)
67 | if ret is None:
68 | ret = self._obj_map.get(name + '_' + suffix)
69 | print(f'Name {name} is not found, use name: {name}_{suffix}!')
70 | if ret is None:
71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
72 | return ret
73 |
74 | def __contains__(self, name):
75 | return name in self._obj_map
76 |
77 | def __iter__(self):
78 | return iter(self._obj_map.items())
79 |
80 | def keys(self):
81 | return self._obj_map.keys()
82 |
83 |
84 | DATASET_REGISTRY = Registry('dataset')
85 | ARCH_REGISTRY = Registry('arch')
86 | MODEL_REGISTRY = Registry('model')
87 | LOSS_REGISTRY = Registry('loss')
88 | METRIC_REGISTRY = Registry('metric')
89 |
--------------------------------------------------------------------------------
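A minimal sketch of the registry mechanism above, mirroring the decorator usage from the `Registry` docstring; `ToyArch` is a hypothetical architecture, not part of this repository:

```python
import torch.nn as nn

from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class ToyArch(nn.Module):
    """Hypothetical architecture, registered under the name 'ToyArch'."""

    def __init__(self, num_feat=64):
        super().__init__()
        self.conv = nn.Conv2d(3, num_feat, 1)

    def forward(self, x):
        return self.conv(x)


# Look the class up by name, the way registry-based builders typically do.
arch_cls = ARCH_REGISTRY.get('ToyArch')
model = arch_cls(num_feat=32)
```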
/options/test/SCNet/SCNet-T-x4-PS.yml:
--------------------------------------------------------------------------------
1 | # Modified SRResNet w/o BN from:
2 | # Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
3 |
4 | # general settings
5 | name: SCNet-T-x4_D64B16_PS
6 | model_type: SRModel
7 | scale: 4
8 | num_gpu: 1 # set num_gpu: 0 for cpu mode
9 | manual_seed: 0
10 |
11 | # dataset and data loader settings
12 | datasets:
13 | test_1: # the 1st test dataset
14 | name: Set5
15 | type: PairedImageDataset
16 | dataroot_gt: datasets/benchmark/Set5/HR
17 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
18 | filename_tmpl: '{}x4'
19 | io_backend:
20 | type: disk
21 | test_2: # the 2nd test dataset
22 | name: Set14
23 | type: PairedImageDataset
24 | dataroot_gt: datasets/benchmark/Set14/HR
25 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
26 | filename_tmpl: '{}x4'
27 | io_backend:
28 | type: disk
29 |
30 | test_3:
31 | name: B100
32 | type: PairedImageDataset
33 | dataroot_gt: datasets/benchmark/B100/HR
34 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35 | filename_tmpl: '{}x4'
36 | io_backend:
37 | type: disk
38 | test_4:
39 | name: Urban100
40 | type: PairedImageDataset
41 | dataroot_gt: datasets/benchmark/Urban100/HR
42 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
43 | filename_tmpl: '{}x4'
44 | io_backend:
45 | type: disk
46 | test_5:
47 | name: Manga109
48 | type: PairedImageDataset
49 | dataroot_gt: datasets/benchmark/Manga109/HR
50 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
51 | filename_tmpl: '{}x4'
52 | io_backend:
53 | type: disk
54 |
55 | # network structures
56 | network_g:
57 | type: SCNet
58 | num_in_ch: 3
59 | num_out_ch: 3
60 | num_feat: 64
61 | num_block: 16
62 | upscale: 4
63 | use_pixelshuffle: true
64 |
65 | # Upsampling with pixelshuffle operation brings better performance
66 | # Ablations can be found at Section 4.3 and Table 7
67 |
68 | # path
69 | path:
70 | pretrain_network_g: model_zoo/SCNet/SCNet-T_x4_D64B16_PS.pth
71 | strict_load_g: true
72 | resume_state: ~
73 |
74 | # validation settings
75 | val:
76 | save_img: true
77 | suffix: ~ # add suffix to saved images, if None, use exp name
78 |
79 | metrics:
80 | psnr: # metric name, can be arbitrary
81 | type: calculate_psnr
82 | crop_border: 4
83 | test_y_channel: true
84 | ssim:
85 | type: calculate_ssim
86 | crop_border: 4
87 | test_y_channel: true
88 |
--------------------------------------------------------------------------------
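A hedged sketch of building the generator described by the `network_g` block above; the import path and constructor keywords are assumptions based on `basicsr/archs/SCNet_arch.py` and the config keys (registry-style builders pass every key except `type` as a keyword argument):

```python
import torch

from basicsr.archs.SCNet_arch import SCNet  # class name assumed from `type: SCNet`

# Keyword arguments mirror the network_g block above.
net = SCNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16,
            upscale=4, use_pixelshuffle=True)
with torch.no_grad():
    sr = net(torch.randn(1, 3, 48, 48))
print(sr.shape)  # expected: torch.Size([1, 3, 192, 192])
```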
/options/test/SCNet/SCNet-T-x4.yml:
--------------------------------------------------------------------------------
1 | # Modified SRResNet w/o BN from:
2 | # Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
3 |
4 | # general settings
5 | name: SCNet-T-x4_D64B16
6 | model_type: SRModel
7 | scale: 4
8 | num_gpu: 1 # set num_gpu: 0 for cpu mode
9 | manual_seed: 0
10 |
11 | # dataset and data loader settings
12 | datasets:
13 | test_1: # the 1st test dataset
14 | name: Set5
15 | type: PairedImageDataset
16 | dataroot_gt: datasets/benchmark/Set5/HR
17 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
18 | filename_tmpl: '{}x4'
19 | io_backend:
20 | type: disk
21 | test_2: # the 2nd test dataset
22 | name: Set14
23 | type: PairedImageDataset
24 | dataroot_gt: datasets/benchmark/Set14/HR
25 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
26 | filename_tmpl: '{}x4'
27 | io_backend:
28 | type: disk
29 |
30 | test_3:
31 | name: B100
32 | type: PairedImageDataset
33 | dataroot_gt: datasets/benchmark/B100/HR
34 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
35 | filename_tmpl: '{}x4'
36 | io_backend:
37 | type: disk
38 | test_4:
39 | name: Urban100
40 | type: PairedImageDataset
41 | dataroot_gt: datasets/benchmark/Urban100/HR
42 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
43 | filename_tmpl: '{}x4'
44 | io_backend:
45 | type: disk
46 | test_5:
47 | name: Manga109
48 | type: PairedImageDataset
49 | dataroot_gt: datasets/benchmark/Manga109/HR
50 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
51 | filename_tmpl: '{}x4'
52 | io_backend:
53 | type: disk
54 |
55 | # network structures
56 | network_g:
57 | type: SCNet
58 | num_in_ch: 3
59 | num_out_ch: 3
60 | num_feat: 64
61 | num_block: 16
62 | upscale: 4
63 | # path
64 | path:
65 | pretrain_network_g: model_zoo/SCNet/SCNet-T-x4.pth
66 | strict_load_g: true
67 | resume_state: ~
68 |
69 | # validation settings
70 | val:
71 | save_img: true
72 | suffix: ~ # add suffix to saved images, if None, use exp name
73 |
74 | metrics:
75 | psnr: # metric name, can be arbitrary
76 | type: calculate_psnr
77 | crop_border: 4
78 | test_y_channel: true
79 | ssim:
80 | type: calculate_ssim
81 | crop_border: 4
82 | test_y_channel: true
83 |
--------------------------------------------------------------------------------
/options/train/SCNet/SCNet-T-x4.yml:
--------------------------------------------------------------------------------
1 | # Modified from the SRResNet (w/o BN) config in BasicSR
2 |
3 | # general settings
4 | name: SCNet-T-x4
5 | model_type: SRModel
6 | scale: 4
7 | num_gpu: 2 # set num_gpu: 0 for cpu mode
8 | manual_seed: 0
9 |
10 | # dataset and data loader settings
11 | datasets:
12 | train:
13 | name: DF2K
14 | type: PairedImageDataset
15 | dataroot_gt: Path to your data
16 | dataroot_lq: Path to your data
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: lmdb
20 |
21 | gt_size: 256
22 | use_hflip: true
23 | use_rot: true
24 |
25 | # data loader
26 | use_shuffle: true
27 | num_worker_per_gpu: 4
28 | batch_size_per_gpu: 8
29 | dataset_enlarge_ratio: 1
30 | pin_memory: true
31 |
32 | val:
33 | name: Set14
34 | type: PairedImageDataset
35 | dataroot_gt: datasets/benchmark/Set14/HR
36 | dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
37 | filename_tmpl: '{}x4'
38 | io_backend:
39 | type: disk
40 |
41 | # network structures
42 | network_g:
43 | type: SCNet
44 | num_in_ch: 3
45 | num_out_ch: 3
46 | num_feat: 64
47 | num_block: 16
48 | upscale: 4
49 |
50 | # path
51 | path:
52 | pretrain_network_g: ~
53 | strict_load_g: false
54 | resume_state: ~
55 | # training settings
56 | train:
57 | ema_decay: 0.999
58 | optim_g:
59 | type: Adam
60 | lr: !!float 2e-4
61 | weight_decay: 0
62 | betas: [0.9, 0.99]
63 |
64 | scheduler:
65 | type: MultiStepLR
66 | milestones: [200000]
67 | gamma: 0.5
68 |
69 | total_iter: 300000
70 | warmup_iter: -1 # no warm up
71 |
72 | # losses
73 | pixel_opt:
74 | type: L1Loss
75 | loss_weight: 1.0
76 | reduction: mean
77 |
78 | # validation settings
79 | val:
80 | val_freq: !!float 5e3
81 | save_img: false
82 |
83 | metrics:
84 | psnr: # metric name, can be arbitrary
85 | type: calculate_psnr
86 | crop_border: 4
87 | test_y_channel: true
88 | ssim: # metric name, can be arbitrary
89 | type: calculate_ssim
90 | crop_border: 4
91 | test_y_channel: true
92 | # logging settings
93 | logger:
94 | print_freq: 200
95 | save_checkpoint_freq: !!float 5e3
96 | use_tb_logger: true
97 | wandb:
98 | project: ~
99 | resume_id: ~
100 |
101 | # dist training settings
102 | dist_params:
103 | backend: nccl
104 | port: 29500
105 |
--------------------------------------------------------------------------------
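A tiny sketch of the learning-rate schedule implied by the training config above (MultiStepLR starting at 2e-4, halved once at 200k iterations, 300k iterations in total):

```python
def lr_at_iter(it, base_lr=2e-4, milestones=(200000,), gamma=0.5):
    """Piecewise-constant learning rate as configured above (MultiStepLR)."""
    lr = base_lr
    for milestone in milestones:
        if it >= milestone:
            lr *= gamma
    return lr


print(lr_at_iter(100_000))  # 0.0002
print(lr_at_iter(250_000))  # 0.0001
```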