├── README.md
├── configs
│   └── restorers
│       └── dan
│           └── dan_setting2.py
├── mmedit
│   └── models
│       ├── backbones
│       │   └── sr_backbones
│       │       └── dan_net.py
│       ├── common
│       │   └── dan_preprocess.py
│       └── restorers
│           └── dan.py
└── tools
    └── data
        └── super-resolution
            └── dan_datasets
                ├── pca_matrix
                │   ├── pca_aniso_matrix_x2.pth
                │   └── pca_aniso_matrix_x4.pth
                ├── preprocess_dan_datasets.py
                └── preprocess_div2k_dataset.py
/README.md:
--------------------------------------------------------------------------------
# DAN-Basd-on-Openmmlab
DAN: [Unfolding the Alternating Optimization for Blind Super Resolution](https://arxiv.org/abs/2010.02631)

We reproduce DAN via [mmediting](https://github.com/open-mmlab/mmediting), based on the authors' [open-sourced code](https://github.com/greatlog/DAN).

## Requirements

- PyTorch >= 1.3
- mmediting >= 0.9

## Datasets
We use [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) as our training datasets.
For evaluation under Setting 2, we use the [DIV2KRK](http://www.wisdom.weizmann.ac.il/~vision/kernelgan/DIV2KRK_public.zip) dataset.

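Assuming the default paths in `configs/restorers/dan/dan_setting2.py`, the dataset root is expected to look roughly like this (the directory names below are the config defaults, not requirements; adjust both to your setup):

```text
${dataset_workspace}/dataset
├── DF2K_train_HR_sub   # DIV2K + Flickr2K HR sub-images (training)
└── DIV2KRK
    ├── lr_x4           # LR inputs for Setting 2 evaluation
    └── gt              # HR ground truth
```
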
## Usages
To run this repo, copy its files into your MMEditing workspace and then train and test with the standard MMEditing commands.

1. Copy the files into the MMEditing workspace.
```shell
cd DAN-Basd-on-Openmmlab/
mv ./mmedit/models/restorers/dan.py ${mmediting_workspace}/mmedit/models/restorers/
mv ./mmedit/models/backbones/sr_backbones/dan_net.py ${mmediting_workspace}/mmedit/models/backbones/sr_backbones/
mv ./mmedit/models/common/dan_preprocess.py ${mmediting_workspace}/mmedit/models/common/
mv ./configs/restorers/dan ${mmediting_workspace}/configs/restorers/
mv ./tools/data/super-resolution/dan_datasets ${mmediting_workspace}/tools/data/super-resolution/
```
2. Modify the configuration file as follows (the sketch after this step shows where these keys live in `dan_setting2.py`):

```python
pca_matrix_path='${mmediting_workspace}/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth' # your pca_matrix path
# Training
gt_folder='${dataset_workspace}/dataset/DF2K_train_HR_sub' # your train data path
# Testing
lq_folder='${dataset_workspace}/dataset/DIV2KRK/lr_x4' # your test data LR path
gt_folder='${dataset_workspace}/dataset/DIV2KRK/gt' # your test data HR path
```
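These variables correspond to nested keys in `configs/restorers/dan/dan_setting2.py`. A trimmed sketch of where they live (all other keys omitted; see the full config in this repo):

```python
# Trimmed sketch of configs/restorers/dan/dan_setting2.py -- only the
# path-related keys are shown; everything else is omitted here.
pca_path = 'tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth'

model = dict(type='DAN',
             generator=dict(type='DAN', pca_matrix_path=pca_path))
train_cfg = dict(pca_matrix_path=pca_path)  # the degradation pipeline loads it too
data = dict(
    train=dict(dataset=dict(gt_folder='dataset/DF2K_train_HR_sub')),
    val=dict(lq_folder='dataset/DIV2KRK/lr_x4', gt_folder='dataset/DIV2KRK/gt'),
    test=dict(lq_folder='dataset/DIV2KRK/lr_x4', gt_folder='dataset/DIV2KRK/gt'))
```
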
3. Register the new modules in the corresponding `__init__.py` files (a consolidated sketch follows this list):

- modify `mmedit/models/backbones/sr_backbones/__init__.py`:
```python
from .dan_net import DAN
# add DAN into the __all__ list.
```
- modify `mmedit/models/common/__init__.py`:
```python
from .dan_preprocess import SRMDPreprocessing
# add SRMDPreprocessing into the __all__ list.
```
- modify `mmedit/models/restorers/__init__.py`:
```python
from .dan import DAN
# add DAN into the __all__ list.
```
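For reference, the backbone registry file might end up looking like this after the edit (a sketch; the pre-existing imports and `__all__` entries depend on your MMEditing version):

```python
# mmedit/models/backbones/sr_backbones/__init__.py (sketch)
from .dan_net import DAN
# ... keep the imports that were already here ...

__all__ = [
    # ... existing backbone names ...
    'DAN',
]
```

The `common` and `restorers` registries are edited the same way, with `SRMDPreprocessing` and the `DAN` restorer respectively.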

4. Training/Test

Before training or testing, download and preprocess the datasets and set the paths in the configuration file.

- Train

```shell
# Single GPU
python tools/train.py configs/restorers/dan/dan_setting2.py --work_dir ${YOUR_WORK_DIR}

# Multiple GPUs
./tools/dist_train.sh configs/restorers/dan/dan_setting2.py ${GPU_NUM} --work_dir ${YOUR_WORK_DIR}
```

- Test (a filled-in example follows the template)
```shell
# Single GPU
python tools/test.py configs/restorers/dan/dan_setting2.py ${CHECKPOINT_FILE} [--metrics ${METRICS}] [--out ${RESULT_FILE}]

# Multiple GPUs
./tools/dist_test.sh configs/restorers/dan/dan_setting2.py ${CHECKPOINT_FILE} ${GPU_NUM} [--metrics ${METRICS}] [--out ${RESULT_FILE}]
```
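Following the template above, a concrete single-GPU evaluation on DIV2KRK might look like this (the checkpoint name is hypothetical; `iter_400000.pth` matches `total_iters` in `dan_setting2.py`):

```shell
python tools/test.py configs/restorers/dan/dan_setting2.py \
    work_dirs/dan_setting2_lr6e25-6_320/iter_400000.pth \
    --metrics PSNR SSIM --out results.pkl
```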

## Result

### DIV2KRK
The extraction password for all download links is 'ta2o'.

| Method | scale | Datasets | PSNR | SSIM | Download |
| :----: | :----: | :----: | :----: | :----: | :----: |
| DAN-RGB (paper) | x4 | DIV2KRK | 26.09 | 0.7312 | - |
| DAN-Y (paper) | x4 | DIV2KRK | 27.55 | 0.7582 | - |
| DAN-RGB (Ours) | x4 | DIV2KRK | 27.41 | 0.7666 | [model](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) / [test_pkl](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) |
| DAN-Y (Ours) | x4 | DIV2KRK | 28.88 | 0.7915 | [model](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) / [test_pkl](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) |

------------

--------------------------------------------------------------------------------
/configs/restorers/dan/dan_setting2.py:
--------------------------------------------------------------------------------
  1 | exp_name = 'dan_setting2_lr6e25-6_320'
  2 | 
  3 | scale = 4
  4 | # model settings
  5 | model = dict(
  6 |     type='DAN',
  7 |     generator=dict(
  8 |         type='DAN',
  9 |         nf=64,
 10 |         nb=40,
 11 |         input_para=10,
 12 |         loop=4,
 13 |         kernel_size=21,
 14 |         # Your pca matrix path
 15 |         pca_matrix_path='pca_aniso_matrix_x4.pth'),
 16 |     pixel_loss=dict(type='MSELoss', loss_weight=1.0, reduction='mean'))
 17 | # model training and testing settings
 18 | 
 19 | train_cfg = dict(
 20 |     # Your pca matrix path
 21 |     pca_matrix_path='pca_aniso_matrix_x4.pth',
 22 |     scale=scale,
 23 |     degradation=dict(
 24 |         random_kernel=True,
 25 |         ksize=21,
 26 |         code_length=10,
 27 |         sig_min=0.6,
 28 |         sig_max=5.0,
 29 |         rate_iso=0,
 30 |         random_disturb=True))
 31 | test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale)
 32 | 
 33 | # dataset settings
 34 | train_dataset_type = 'SRFolderGTDataset'
 35 | val_dataset_type = 'SRFolderDataset'
 36 | train_pipeline = [
 37 |     dict(
 38 |         type='LoadImageFromFile',
 39 |         io_backend='disk',
 40 |         key='gt',
 41 |         flag='unchanged'),
 42 |     dict(type='RescaleToZeroOne', keys=['gt']),
 43 |     dict(
 44 |         type='Normalize',
 45 |         keys=['gt'],
 46 |         mean=[0, 0, 0],
 47 |         std=[1, 1, 1],
 48 |         to_rgb=True),
 49 |     dict(type='Crop', keys=['gt'], crop_size=(256, 256)),
 50 |     dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='horizontal'),
 51 |     dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='vertical'),
 52 |     dict(type='RandomTransposeHW', keys=['gt'], transpose_ratio=0.5),
 53 |     dict(type='Collect', keys=['gt'], meta_keys=['gt_path']),
 54 |     dict(type='ImageToTensor', keys=['gt'])
 55 | ]
 56 | test_pipeline = [
 57 |     dict(type='LoadImageFromFile', io_backend='disk', key='lq', flag='color'),
 58 |     dict(type='LoadImageFromFile', io_backend='disk', key='gt', flag='color'),
 59 |     dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
 60 |     dict(
 61 |         type='Normalize',
 62 |         keys=['lq', 'gt'],
 63 |         mean=[0, 0, 0],
 64 |         std=[1, 1, 1],
 65 |         to_rgb=True),
 66 |     dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path']),
 67 |     dict(type='ImageToTensor', keys=['lq', 'gt'])
 68 | ]
 69 | 
 70 | data = dict(
 71 |     workers_per_gpu=4,
 72 |     train_dataloader=dict(samples_per_gpu=8, drop_last=True),
 73 |     val_dataloader=dict(samples_per_gpu=1),
 74 |     test_dataloader=dict(samples_per_gpu=1),
 75 |     train=dict(
 76 |         type='RepeatDataset',
 77 |         times=1000,
 78 |         dataset=dict(
 79 |             type=train_dataset_type,
 80 |             gt_folder='dataset/DF2K_train_HR_sub',  # Your training data path
 81 |             pipeline=train_pipeline,
 82 |             scale=scale)),
 83 |     val=dict(
 84 |         type=val_dataset_type,
 85 |         lq_folder='dataset/DIV2KRK/lr_x4',  # Your validation data LR path
 86 |         gt_folder='dataset/DIV2KRK/gt',  # Your validation data HR path
 87 |         pipeline=test_pipeline,
 88 |         scale=scale,
 89 |         filename_tmpl='{}'),
 90 |     test=dict(
 91 |         type=val_dataset_type,
 92 |         lq_folder='dataset/DIV2KRK/lr_x4',  # Your validation data LR path
 93 |         gt_folder='dataset/DIV2KRK/gt',  # Your validation data HR path
 94 |         pipeline=test_pipeline,
 95 |         scale=scale,
 96 |         filename_tmpl='{}'))
 97 | 
 98 | # optimizer
 99 | optimizers = dict(generator=dict(type='Adam', lr=6.25e-6, betas=(0.9, 0.999)))
100 | 
101 | # learning policy
102 | total_iters = 400000
103 | lr_config = dict(
104 |     policy='Step', by_epoch=False, step=[100000, 200000, 300000], gamma=0.5)
105 | 
106 | checkpoint_config = dict(interval=1000, save_optimizer=True, by_epoch=False)
107 | evaluation = dict(interval=1000, save_image=False, gpu_collect=True)
108 | log_config = dict(
109 |     interval=100,
110 |     hooks=[
111 |         dict(type='TextLoggerHook', by_epoch=False),
112 |         dict(type='TensorboardLoggerHook'),
113 |         # dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit-sr'))
114 |     ])
115 | visual_config = None
116 | 
117 | # runtime settings
118 | dist_params = dict(backend='nccl')
119 | log_level = 'INFO'
120 | work_dir = f'./work_dirs/{exp_name}'
121 | load_from = 'mmediting/danx4_l1_256/iter_40000.pth'  # your checkpoint path; set to None to train from scratch
122 | resume_from = None
123 | workflow = [('train', 1)]
124 | 
--------------------------------------------------------------------------------
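
A quick way to check that the path edits in this config resolve is to load it with mmcv (a sketch; assumes mmcv is installed, which MMEditing requires):

```python
# Load the config and print the keys that must point at real files/folders.
from mmcv import Config

cfg = Config.fromfile('configs/restorers/dan/dan_setting2.py')
print(cfg.model.generator.pca_matrix_path)
print(cfg.train_cfg.pca_matrix_path)
print(cfg.data.train.dataset.gt_folder)
print(cfg.data.test.lq_folder, cfg.data.test.gt_folder)
```
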
/mmedit/models/backbones/sr_backbones/dan_net.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import torch
  3 | import torch.nn as nn
  4 | import torch.nn.functional as F
  5 | 
  6 | from mmcv.runner import load_checkpoint
  7 | from mmedit.models.registry import BACKBONES
  8 | from mmedit.utils import get_root_logger
  9 | 
 10 | 
 11 | class CALayer(nn.Module):
 12 |     def __init__(self, nf, reduction=16):
 13 |         super(CALayer, self).__init__()
 14 |         self.body = nn.Sequential(
 15 |             nn.Conv2d(nf, nf // reduction, 1, 1, 0),
 16 |             nn.LeakyReLU(0.2),
 17 |             nn.Conv2d(nf // reduction, nf, 1, 1, 0),
 18 |             nn.Sigmoid(),
 19 |         )
 20 |         self.avg = nn.AdaptiveAvgPool2d(1)
 21 | 
 22 |     def forward(self, x):
 23 |         y = self.avg(x)
 24 |         y = self.body(y)
 25 |         return torch.mul(x, y)
 26 | 
 27 | 
 28 | class CRB_Layer(nn.Module):
 29 |     def __init__(self, nf1, nf2):
 30 |         super(CRB_Layer, self).__init__()
 31 | 
 32 |         body = [
 33 |             nn.Conv2d(nf1 + nf2, nf1 + nf2, 3, 1, 1),
 34 |             nn.LeakyReLU(0.2, True),
 35 |             nn.Conv2d(nf1 + nf2, nf1, 3, 1, 1),
 36 |             CALayer(nf1),
 37 |         ]
 38 | 
 39 |         self.body = nn.Sequential(*body)
 40 | 
 41 |     def forward(self, x):
 42 |         f1, f2 = x
 43 |         f1 = self.body(torch.cat(x, 1)) + f1
 44 |         return [f1, f2]
 45 | 
 46 | 
 47 | class Estimator(nn.Module):
 48 |     def __init__(self, in_nc=3, nf=64, num_blocks=5, scale=4, kernel_size=4):
 49 |         super(Estimator, self).__init__()
 50 | 
 51 |         self.ksize = kernel_size
 52 | 
 53 |         self.head_LR = nn.Conv2d(in_nc, nf // 2, 1, 1, 0)
 54 |         self.head_HR = nn.Conv2d(in_nc, nf // 2, 9, scale, 4)
 55 | 
 56 |         body = [CRB_Layer(nf // 2, nf // 2) for _ in range(num_blocks)]
 57 |         self.body = nn.Sequential(*body)
 58 | 
 59 |         self.out = nn.Conv2d(nf // 2, 10, 3, 1, 1)
 60 |         self.globalPooling = nn.AdaptiveAvgPool2d((1, 1))
 61 | 
 62 |     def forward(self, GT, LR):  # GT is actually the current SR estimate (see DAN.forward)
 63 | 
 64 |         lrf = self.head_LR(LR)
 65 |         hrf = self.head_HR(GT)
 66 | 
 67 |         f = [lrf, hrf]
 68 |         f, _ = self.body(f)
 69 |         f = self.out(f)
 70 |         f = self.globalPooling(f)
 71 |         f = f.view(f.size()[:2])
 72 | 
 73 |         return f
 74 | 
 75 | 
 76 | class Restorer(nn.Module):
 77 |     def __init__(
 78 |         self, in_nc=3, out_nc=3, nf=64, nb=8, scale=4, input_para=10, min=0.0, max=1.0
 79 |     ):
 80 |         super(Restorer, self).__init__()
 81 |         self.min = min
 82 |         self.max = max
 83 |         self.para = input_para
 84 |         self.num_blocks = nb
 85 | 
 86 |         self.head = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1)
 87 | 
 88 |         body = [CRB_Layer(nf, input_para) for _ in range(nb)]
 89 |         self.body = nn.Sequential(*body)
 90 | 
 91 |         self.fusion = nn.Conv2d(nf, nf, 3, 1, 1)
 92 | 
 93 |         if scale == 4:  # x4
 94 |             self.upscale = nn.Sequential(
 95 |                 nn.Conv2d(
 96 |                     in_channels=nf,
 97 |                     out_channels=nf * scale,
 98 |                     kernel_size=3,
 99 |                     stride=1,
100 |                     padding=1,
101 |                     bias=True,
102 |                 ),
103 |                 nn.PixelShuffle(scale // 2),
104 |                 nn.Conv2d(
105 |                     in_channels=nf,
106 |                     out_channels=nf * scale,
107 |                     kernel_size=3,
108 |                     stride=1,
109 |                     padding=1,
110 |                     bias=True,
111 |                 ),
112 |                 nn.PixelShuffle(scale // 2),
113 |                 nn.Conv2d(nf, 3, 3, 1, 1),
114 |             )
115 |         else:  # x2, x3
116 |             self.upscale = nn.Sequential(
117 |                 nn.Conv2d(
118 |                     in_channels=nf,
119 |                     out_channels=nf * scale ** 2,
120 |                     kernel_size=3,
121 |                     stride=1,
122 |                     padding=1,
123 |                     bias=True,
124 |                 ),
125 |                 nn.PixelShuffle(scale),
126 |                 nn.Conv2d(nf, 3, 3, 1, 1),
127 |             )
128 | 
129 |     def forward(self, input, ker_code):
130 |         B, C, H, W = input.size()  # I_LR batch
131 |         B_h, C_h = ker_code.size()  # Batch, Len=10
132 |         ker_code_exp = ker_code.view((B_h, C_h, 1, 1)).expand(
133 |             (B_h, C_h, H, W)
134 |         )  # kernel_map stretch
135 | 
136 |         f = self.head(input)
137 |         inputs = [f, ker_code_exp]
138 |         f, _ = self.body(inputs)
139 |         f = self.fusion(f)
140 |         out = self.upscale(f)
141 | 
142 |         return out  # torch.clamp(out, min=self.min, max=self.max)
143 | 
144 | 
145 | @BACKBONES.register_module()
146 | class DAN(nn.Module):
147 |     def __init__(
148 |         self,
149 |         nf=64,
150 |         nb=16,
151 |         upscale=4,
152 |         input_para=10,
153 |         kernel_size=21,
154 |         loop=8,
155 |         pca_matrix_path=None,
156 |     ):
157 |         super(DAN, self).__init__()
158 | 
159 |         self.ksize = kernel_size
160 |         self.loop = loop
161 |         self.scale = upscale
162 | 
163 |         self.Restorer = Restorer(nf=nf, nb=nb, scale=self.scale, input_para=input_para)
164 |         self.Estimator = Estimator(kernel_size=kernel_size, scale=self.scale)
165 | 
166 |         self.register_buffer("encoder", torch.load(pca_matrix_path)[None])
167 | 
168 |         kernel = torch.zeros(1, self.ksize, self.ksize)
169 |         kernel[:, self.ksize // 2, self.ksize // 2] = 1
170 | 
171 |         self.register_buffer("init_kernel", kernel)
172 |         init_ker_map = self.init_kernel.view(1, 1, self.ksize ** 2).matmul(
173 |             self.encoder
174 |         )[:, 0]
175 |         self.register_buffer("init_ker_map", init_ker_map)
176 | 
177 |     def forward(self, lr):
178 | 
179 |         srs = []
180 |         ker_maps = []
181 | 
182 |         B, C, H, W = lr.shape
183 |         ker_map = self.init_ker_map.repeat([B, 1])
184 | 
185 |         for i in range(self.loop):
186 | 
187 |             sr = self.Restorer(lr, ker_map.detach())  # restore SR using the current kernel estimate
188 |             ker_map = self.Estimator(sr.detach(), lr)  # re-estimate the kernel from the SR result
189 | 
190 |             srs.append(sr)
191 |             ker_maps.append(ker_map)
192 |         return [srs, ker_maps]
193 | 
194 |     def init_weights(self, pretrained=None, strict=True):
195 |         """Init weights for models.
196 |         Args:
197 |             pretrained (str, optional): Path for pretrained weights. If given
198 |                 None, pretrained weights will not be loaded. Defaults to None.
199 |             strict (bool, optional): Whether to strictly load the pretrained model.
200 |                 Defaults to True.
201 |         """
202 |         if isinstance(pretrained, str):
203 |             logger = get_root_logger()
204 |             load_checkpoint(self, pretrained, strict=strict, logger=logger)
205 |         elif pretrained is None:
206 |             pass  # use default initialization
207 |         else:
208 |             raise TypeError('"pretrained" must be a str or None. '
209 |                             f'But received {type(pretrained)}.')
210 | 
--------------------------------------------------------------------------------
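
A note on the backbone's contract: `DAN.forward` runs `loop` alternations and returns `[srs, ker_maps]`, the per-step SR images and kernel codes; consumers take the last entry of each. A minimal usage sketch (assumes the PCA matrix file exists at the given path):

```python
# Sketch: driving the DAN backbone directly. Shapes assume x4 and a 64x64 LR input.
# Assumes: from mmedit.models.backbones.sr_backbones.dan_net import DAN
import torch

net = DAN(nf=64, nb=40, upscale=4, input_para=10, kernel_size=21, loop=4,
          pca_matrix_path='pca_aniso_matrix_x4.pth')  # placeholder path
lr = torch.rand(1, 3, 64, 64)
srs, ker_maps = net(lr)   # two lists, one entry per alternation step
sr = srs[-1]              # final SR estimate: (1, 3, 256, 256)
ker_map = ker_maps[-1]    # final kernel code: (1, 10)
```
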
/mmedit/models/common/dan_preprocess.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | 
  3 | import numpy as np
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | from scipy.io import loadmat
  8 | from torch.autograd import Variable
  9 | from torchvision.utils import make_grid
 10 | 
 11 | 
 12 | def img2tensor(img):
 13 |     """
 14 |     BGR to RGB, HWC to CHW, numpy to tensor.
 15 |     Input: img(H, W, C), [0,255], np.uint8 (default)
 16 |     Output: 3D(C,H,W), RGB order, float tensor
 17 |     """
 18 |     img = img.astype(np.float32) / 255.0
 19 |     img = img[:, :, [2, 1, 0]]
 20 |     img = torch.from_numpy(np.ascontiguousarray(np.transpose(img,
 21 |                                                              (2, 0,
 22 |                                                               1)))).float()
 23 |     return img
 24 | 
 25 | 
 26 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
 27 |     """
 28 |     Converts a torch Tensor into an image Numpy array
 29 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
 30 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
 31 |     """
 32 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
 33 |     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]
 34 |                                       )  # to range [0,1]
 35 |     n_dim = tensor.dim()
 36 |     if n_dim == 4:
 37 |         n_img = len(tensor)
 38 |         img_np = make_grid(
 39 |             tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
 40 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 41 |     elif n_dim == 3:
 42 |         img_np = tensor.numpy()
 43 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 44 |     elif n_dim == 2:
 45 |         img_np = tensor.numpy()
 46 |     else:
 47 |         raise TypeError('Only support 4D, 3D and 2D tensor.')
 48 |     if out_type == np.uint8:
 49 |         img_np = (img_np * 255.0).round()
 50 |         # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
 51 |     return img_np.astype(out_type)
 52 | 
 53 | 
 54 | def cubic(x):
 55 |     absx = torch.abs(x)
 56 |     absx2 = absx**2
 57 |     absx3 = absx**3
 58 | 
 59 |     weight = (1.5 * absx3 - 2.5 * absx2 + 1) * (
 60 |         (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
 61 |                                       2) * (((absx > 1) *
 62 |                                              (absx <= 2)).type_as(absx))
 63 |     return weight
 64 | 
 65 | 
 66 | def calculate_weights_indices(in_length, out_length, scale, kernel,
 67 |                               kernel_width, antialiasing):
 68 |     if (scale < 1) and (antialiasing):
 69 |         # Use a modified kernel to simultaneously interpolate
 70 |         kernel_width = kernel_width / scale
 71 | 
 72 |     # Output-space coordinates
 73 |     x = torch.linspace(1, out_length, out_length)
 74 | 
 75 |     # Input-space coordinates. Calculate the inverse mapping such that 0.5
 76 |     # in output space maps to 0.5 in input space, and 0.5+scale in output
 77 |     # space maps to 1.5 in input space.
 78 |     u = x / scale + 0.5 * (1 - 1 / scale)
 79 | 
 80 |     # What is the left-most pixel that can be involved in the computation?
 81 |     left = torch.floor(u - kernel_width / 2)
 82 | 
 83 |     # What is the maximum number of pixels that can be involved in the
 84 |     # computation?  Note: it's OK to use an extra pixel here; if the
 85 |     # corresponding weights are all zero, it will be eliminated at the end
 86 |     # of this function.
 87 |     P = math.ceil(kernel_width) + 2
 88 | 
 89 |     # The indices of the input pixels involved in computing the k-th output
 90 |     # pixel are in row k of the indices matrix.
 91 |     indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
 92 |         0, P - 1, P).view(1, P).expand(out_length, P)
 93 | 
 94 |     # The weights used to compute the k-th output pixel are in row k of the
 95 |     # weights matrix.
 96 |     distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
 97 |     # apply cubic kernel
 98 |     if (scale < 1) and (antialiasing):
 99 |         weights = scale * cubic(distance_to_center * scale)
100 |     else:
101 |         weights = cubic(distance_to_center)
102 |     # Normalize the weights matrix so that each row sums to 1.
103 |     weights_sum = torch.sum(weights, 1).view(out_length, 1)
104 |     weights = weights / weights_sum.expand(out_length, P)
105 | 
106 |     # If a column in weights is all zero, get rid of it.
107 |     weights_zero_tmp = torch.sum((weights == 0), 0)
108 |     if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
109 |         indices = indices.narrow(1, 1, P - 2)
110 |         weights = weights.narrow(1, 1, P - 2)
111 |     if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
112 |         indices = indices.narrow(1, 0, P - 2)
113 |         weights = weights.narrow(1, 0, P - 2)
114 |     weights = weights.contiguous()
115 |     indices = indices.contiguous()
116 |     sym_len_s = -indices.min() + 1
117 |     sym_len_e = indices.max() - in_length
118 |     indices = indices + sym_len_s - 1
119 |     return weights, indices, int(sym_len_s), int(sym_len_e)
120 | 
121 | 
122 | def imresize(img, scale, antialiasing=True):
123 |     # Now the scale should be the same for H and W
124 |     # input: img: CHW RGB [0,1]
125 |     # output: CHW RGB [0,1] w/o round
126 |     is_numpy = False
127 |     if isinstance(img, np.ndarray):
128 |         img = torch.from_numpy(img.transpose(2, 0, 1))
129 |         is_numpy = True
130 |     device = img.device
131 | 
132 |     is_batch = True
133 |     if len(img.shape) == 3:  # C, H, W
134 |         img = img[None]
135 |         is_batch = False
136 | 
137 |     B, in_C, in_H, in_W = img.size()
138 |     img = img.view(-1, in_H, in_W)
139 |     _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
140 |     kernel_width = 4
141 |     kernel = 'cubic'
142 | 
143 |     # Return the desired dimension order for performing the resize.  The
144 |     # strategy is to perform the resize first along the dimension with the
145 |     # smallest scale factor.
146 |     # Now we do not support this.
147 | 
148 |     # get weights and indices
149 |     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
150 |         in_H, out_H, scale, kernel, kernel_width, antialiasing)
151 |     weights_H, indices_H = weights_H.to(device), indices_H.to(device)
152 |     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
153 |         in_W, out_W, scale, kernel, kernel_width, antialiasing)
154 |     weights_W, indices_W = weights_W.to(device), indices_W.to(device)
155 |     # process H dimension
156 |     # symmetric copying
157 |     img_aug = torch.FloatTensor(B * in_C, in_H + sym_len_Hs + sym_len_He,
158 |                                 in_W).to(device)
159 |     img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
160 | 
161 |     sym_patch = img[:, :sym_len_Hs, :]
162 |     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device)
163 |     sym_patch_inv = sym_patch.index_select(1, inv_idx)
164 |     img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
165 | 
166 |     sym_patch = img[:, -sym_len_He:, :]
167 |     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device)
168 |     sym_patch_inv = sym_patch.index_select(1, inv_idx)
169 |     img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
170 | 
171 |     out_1 = torch.FloatTensor(B * in_C, out_H, in_W).to(device)
172 |     kernel_width = weights_H.size(1)
173 |     for i in range(out_H):
174 |         idx = int(indices_H[i][0])
175 |         out_1[:, i, :] = (img_aug[:, idx:idx + kernel_width, :].transpose(
176 |             1, 2).matmul(weights_H[i][None, :, None].repeat(B * in_C, 1,
177 |                                                             1))).squeeze()
178 | 
179 |     # process W dimension
180 |     # symmetric copying
181 |     out_1_aug = torch.FloatTensor(B * in_C, out_H,
182 |                                   in_W + sym_len_Ws + sym_len_We).to(device)
183 |     out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
184 | 
185 |     sym_patch = out_1[:, :, :sym_len_Ws]
186 |     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device)
187 |     sym_patch_inv = sym_patch.index_select(2, inv_idx)
188 |     out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
189 | 
190 |     sym_patch = out_1[:, :, -sym_len_We:]
191 |     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device)
192 |     sym_patch_inv = sym_patch.index_select(2, inv_idx)
193 |     out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
194 | 
195 |     out_2 = torch.FloatTensor(B * in_C, out_H, out_W).to(device)
196 |     kernel_width = weights_W.size(1)
197 |     for i in range(out_W):
198 |         idx = int(indices_W[i][0])
199 |         out_2[:, :, i] = (out_1_aug[:, :, idx:idx + kernel_width].matmul(
200 |             weights_W[i][None, :, None].repeat(B * in_C, 1, 1))).squeeze()
201 | 
202 |     out_2 = out_2.contiguous().view(B, in_C, out_H, out_W)
203 |     if not is_batch:
204 |         out_2 = out_2[0]
205 |     return out_2.cpu().numpy().transpose(1, 2, 0) if is_numpy else out_2
206 | 
207 | 
208 | def DUF_downsample(x, scale=4):
209 |     """Downsamping with Gaussian kernel used in the DUF official code
210 |     Args:
211 |         x (Tensor, [B, T, C, H, W]): frames to be downsampled.
212 |         scale (int): downsampling factor: 2 | 3 | 4.
213 |     """
214 | 
215 |     assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
216 | 
217 |     def gkern(kernlen=13, nsig=1.6):
218 |         import scipy.ndimage.filters as fi
219 | 
220 |         inp = np.zeros((kernlen, kernlen))
221 |         # set element at the middle to one, a dirac delta
222 |         inp[kernlen // 2, kernlen // 2] = 1
223 |         # gaussian-smooth the dirac, resulting in a gaussian filter mask
224 |         return fi.gaussian_filter(inp, nsig)
225 | 
226 |     B, T, C, H, W = x.size()
227 |     x = x.view(-1, 1, H, W)
228 |     # 6 is the pad of the gaussian filter
229 |     pad_w, pad_h = 6 + scale * 2, 6 + scale * 2
230 |     r_h, r_w = 0, 0
231 |     if scale == 3:
232 |         r_h = 3 - (H % 3)
233 |         r_w = 3 - (W % 3)
234 |     x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
235 | 
236 |     gaussian_filter = (
237 |         torch.from_numpy(gkern(13, 0.4 *
238 |                                scale)).type_as(x).unsqueeze(0).unsqueeze(0))
239 |     x = F.conv2d(x, gaussian_filter, stride=scale)
240 |     x = x[:, :, 2:-2, 2:-2]
241 |     x = x.view(B, T, C, x.size(2), x.size(3))
242 |     return x
243 | 
244 | 
245 | def PCA(data, k=2):
246 |     X = torch.from_numpy(data)
247 |     X_mean = torch.mean(X, 0)
248 |     X = X - X_mean.expand_as(X)
249 |     U, S, V = torch.svd(torch.t(X))
250 |     return U[:, :k]  # PCA matrix
251 | 
252 | 
253 | def random_batch_kernel(
254 |     batch,
255 |     kernel_size=21,
256 |     sig_min=0.2,
257 |     sig_max=4.0,
258 |     rate_iso=1.0,
259 |     tensor=True,
260 |     random_disturb=False,
261 | ):
262 |     if rate_iso == 1:
263 | 
264 |         sigma = np.random.uniform(sig_min, sig_max, (batch, 1, 1))
265 |         ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
266 |         xx, yy = np.meshgrid(ax, ax)
267 |         xx = xx[None].repeat(batch, 0)
268 |         yy = yy[None].repeat(batch, 0)
269 |         kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
270 |         kernel = kernel / np.sum(kernel, (1, 2), keepdims=True)
271 |         return torch.FloatTensor(kernel) if tensor else kernel
272 | 
273 |     else:
274 | 
275 |         sigma_x = np.random.uniform(sig_min, sig_max, (batch, 1, 1))
276 |         sigma_y = np.random.uniform(sig_min, sig_max, (batch, 1, 1))
277 | 
278 |         D = np.zeros((batch, 2, 2))
279 |         D[:, 0, 0] = sigma_x.squeeze()**2
280 |         D[:, 1, 1] = sigma_y.squeeze()**2
281 | 
282 |         radians = np.random.uniform(-np.pi, np.pi, (batch))
283 |         mask_iso = np.random.uniform(0, 1, (batch)) < rate_iso
284 |         radians[mask_iso] = 0
285 |         sigma_y[mask_iso] = sigma_x[mask_iso]
286 | 
287 |         U = np.zeros((batch, 2, 2))
288 |         U[:, 0, 0] = np.cos(radians)
289 |         U[:, 0, 1] = -np.sin(radians)
290 |         U[:, 1, 0] = np.sin(radians)
291 |         U[:, 1, 1] = np.cos(radians)
292 |         sigma = np.matmul(U, np.matmul(D, U.transpose(0, 2, 1)))
293 |         ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
294 |         xx, yy = np.meshgrid(ax, ax)
295 |         xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
296 |                         yy.reshape(kernel_size * kernel_size,
297 |                                    1))).reshape(kernel_size, kernel_size, 2)
298 |         xy = xy[None].repeat(batch, 0)
299 |         inverse_sigma = np.linalg.inv(sigma)[:, None, None]
300 |         kernel = np.exp(-0.5 * np.matmul(
301 |             np.matmul(xy[:, :, :, None], inverse_sigma), xy[:, :, :, :, None]))
302 |         kernel = kernel.reshape(batch, kernel_size, kernel_size)
303 |         if random_disturb:
304 |             kernel = kernel + np.random.uniform(
305 |                 0, 0.25, (batch, kernel_size, kernel_size)) * kernel
306 |         kernel = kernel / np.sum(kernel, (1, 2), keepdims=True)
307 | 
308 |         return torch.FloatTensor(kernel) if tensor else kernel
309 | 
310 | 
311 | def stable_batch_kernel(batch, kernel_size=21, sig=2.6, tensor=True):
312 |     sigma = sig
313 |     ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
314 |     xx, yy = np.meshgrid(ax, ax)
315 |     xx = xx[None].repeat(batch, 0)
316 |     yy = yy[None].repeat(batch, 0)
317 |     kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
318 |     kernel = kernel / np.sum(kernel, (1, 2), keepdims=True)
319 |     return torch.FloatTensor(kernel) if tensor else kernel
320 | 
321 | 
322 | def b_Bicubic(variable, scale):
323 |     B, C, H, W = variable.size()
324 |     # H_new = int(H / scale)
325 |     # W_new = int(W / scale)
326 |     tensor_v = variable.view((B, C, H, W))
327 |     re_tensor = imresize(tensor_v, 1 / scale)
328 |     return re_tensor
329 | 
330 | 
331 | def random_batch_noise(batch, high, rate_cln=1.0):
332 |     noise_level = np.random.uniform(size=(batch, 1)) * high
333 |     noise_mask = np.random.uniform(size=(batch, 1))
334 |     noise_mask[noise_mask < rate_cln] = 0
335 |     noise_mask[noise_mask >= rate_cln] = 1
336 |     return noise_level * noise_mask
337 | 
338 | 
339 | def b_GaussianNoising(tensor,
340 |                       sigma,
341 |                       mean=0.0,
342 |                       noise_size=None,
343 |                       min=0.0,
344 |                       max=1.0):
345 |     if noise_size is None:
346 |         size = tensor.size()
347 |     else:
348 |         size = noise_size
349 |     noise = torch.mul(
350 |         torch.FloatTensor(np.random.normal(loc=mean, scale=1.0, size=size)),
351 |         sigma.view(sigma.size() + (1, 1)),
352 |     ).to(tensor.device)
353 |     return torch.clamp(noise + tensor, min=min, max=max)
354 | 
355 | 
356 | class BatchSRKernel(object):
357 | 
358 |     def __init__(
359 |         self,
360 |         kernel_size=21,
361 |         sig=2.6,
362 |         sig_min=0.2,
363 |         sig_max=4.0,
364 |         rate_iso=1.0,
365 |         random_disturb=False,
366 |     ):
367 |         self.kernel_size = kernel_size
368 |         self.sig = sig
369 |         self.sig_min = sig_min
370 |         self.sig_max = sig_max
371 |         self.rate = rate_iso
372 |         self.random_disturb = random_disturb
373 | 
374 |     def __call__(self, random, batch, tensor=False):
375 |         if random:  # random kernel
376 |             return random_batch_kernel(
377 |                 batch,
378 |                 kernel_size=self.kernel_size,
379 |                 sig_min=self.sig_min,
380 |                 sig_max=self.sig_max,
381 |                 rate_iso=self.rate,
382 |                 tensor=tensor,
383 |                 random_disturb=self.random_disturb,
384 |             )
385 |         else:  # stable kernel
386 |             return stable_batch_kernel(
387 |                 batch,
388 |                 kernel_size=self.kernel_size,
389 |                 sig=self.sig,
390 |                 tensor=tensor)
391 | 
392 | 
393 | class BatchBlurKernel(object):
394 | 
395 |     def __init__(self, kernels_path):
396 |         kernels = loadmat(kernels_path)['kernels']
397 |         self.num_kernels = kernels.shape[0]
398 |         self.kernels = kernels
399 | 
400 |     def __call__(self, random, batch, tensor=False):
401 |         index = np.random.randint(0, self.num_kernels, batch)
402 |         kernels = self.kernels[index]
403 |         return torch.FloatTensor(kernels).contiguous() if tensor else kernels
404 | 
405 | 
406 | class PCAEncoder(nn.Module):
407 | 
408 |     def __init__(self, weight):
409 |         super().__init__()
410 |         self.register_buffer('weight', weight)
411 |         self.size = self.weight.size()
412 | 
413 |     def forward(self, batch_kernel):
414 |         B, H, W = batch_kernel.size()  # [B, l, l]
415 |         return torch.bmm(
416 |             batch_kernel.view((B, 1, H * W)),
417 |             self.weight.expand((B, ) + self.size)).view((B, -1))
418 | 
419 | 
420 | class BatchBlur(object):
421 | 
422 |     def __init__(self, kernel_size=15):
423 |         self.kernel_size = kernel_size
424 |         if kernel_size % 2 == 1:
425 |             self.pad = (kernel_size // 2, kernel_size // 2, kernel_size // 2,
426 |                         kernel_size // 2)
427 |         else:
428 |             self.pad = (kernel_size // 2, kernel_size // 2 - 1,
429 |                         kernel_size // 2, kernel_size // 2 - 1)
430 |         # self.pad = nn.ZeroPad2d(l // 2)
431 | 
432 |     def __call__(self, input, kernel):
433 |         B, C, H, W = input.size()
434 |         pad = F.pad(input, self.pad, mode='reflect')
435 |         H_p, W_p = pad.size()[-2:]
436 | 
437 |         if len(kernel.size()) == 2:
438 |             input_CBHW = pad.view((C * B, 1, H_p, W_p))
439 |             kernel_var = kernel.contiguous().view(
440 |                 (1, 1, self.kernel_size, self.kernel_size))
441 |             return F.conv2d(
442 |                 input_CBHW, kernel_var, padding=0).view((B, C, H, W))
443 |         else:
444 |             input_CBHW = pad.view((1, C * B, H_p, W_p))
445 |             kernel_var = (
446 |                 kernel.contiguous().view(
447 |                     (B, 1, self.kernel_size,
448 |                      self.kernel_size)).repeat(1, C, 1, 1).view(
449 |                          (B * C, 1, self.kernel_size, self.kernel_size)))
450 |             return F.conv2d(
451 |                 input_CBHW, kernel_var, groups=B * C).view((B, C, H, W))
452 | 
453 | 
454 | class SRMDPreprocessing(object):
455 | 
456 |     def __init__(self,
457 |                  scale,
458 |                  pca_matrix,
459 |                  ksize=21,
460 |                  code_length=10,
461 |                  random_kernel=True,
462 |                  noise=False,
463 |                  cuda=False,
464 |                  random_disturb=False,
465 |                  sig=0,
466 |                  sig_min=0,
467 |                  sig_max=0,
468 |                  rate_iso=1.0,
469 |                  rate_cln=1,
470 |                  noise_high=0,
471 |                  stored_kernel=False,
472 |                  pre_kernel_path=None):
473 |         self.encoder = PCAEncoder(pca_matrix).cuda() if cuda else PCAEncoder(
474 |             pca_matrix)
475 | 
476 |         self.kernel_gen = (
477 |             BatchSRKernel(
478 |                 kernel_size=ksize,
479 |                 sig=sig,
480 |                 sig_min=sig_min,
481 |                 sig_max=sig_max,
482 |                 rate_iso=rate_iso,
483 |                 random_disturb=random_disturb,
484 |             ) if not stored_kernel else BatchBlurKernel(pre_kernel_path))
485 | 
486 |         self.blur = BatchBlur(kernel_size=ksize)
487 |         self.para_in = code_length
488 |         self.kernel_size = ksize
489 |         self.noise = noise
490 |         self.scale = scale
491 |         self.cuda = cuda
492 |         self.rate_cln = rate_cln
493 |         self.noise_high = noise_high
494 |         self.random = random_kernel
495 | 
496 |     def __call__(self, hr_tensor, kernel=False):
497 |         # hr_tensor is tensor, not cuda tensor
498 | 
499 |         hr_var = Variable(hr_tensor).cuda() if self.cuda else Variable(
500 |             hr_tensor)
501 |         device = hr_var.device
502 |         B, C, H, W = hr_var.size()
503 | 
504 |         b_kernels = Variable(self.kernel_gen(self.random, B,
505 |                                              tensor=True)).to(device)
506 |         hr_blured_var = self.blur(hr_var, b_kernels)
507 | 
508 |         # B x self.para_input
509 |         kernel_code = self.encoder(b_kernels)
510 | 
511 |         # Down sample
512 |         if self.scale != 1:
513 |             lr_blured_t = b_Bicubic(hr_blured_var, self.scale)
514 |         else:
515 |             lr_blured_t = hr_blured_var
516 | 
517 |         # Noisy
518 |         if self.noise:
519 |             Noise_level = torch.FloatTensor(
520 |                 random_batch_noise(B, self.noise_high, self.rate_cln))
521 |             lr_noised_t = b_GaussianNoising(lr_blured_t, Noise_level)
522 |         else:
523 |             Noise_level = torch.zeros((B, 1))
524 |             lr_noised_t = lr_blured_t
525 | 
526 |         Noise_level = Variable(Noise_level).to(device)
527 |         re_code = (
528 |             torch.cat([kernel_code, Noise_level *
529 |                        10], dim=1) if self.noise else kernel_code)
530 |         lr_re = Variable(lr_noised_t).to(device)
531 | 
532 |         return (lr_re, re_code, b_kernels) if kernel else (lr_re, re_code)
533 | 
--------------------------------------------------------------------------------
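
How the degradation pipeline above is driven during training (a minimal sketch based on `dan.py` and `preprocess_dan_datasets.py`; the PCA matrix path is a placeholder):

```python
# Sketch: generate (LR, kernel-code) pairs from an HR batch.
# Assumes: from mmedit.models.common.dan_preprocess import SRMDPreprocessing
import torch

pca_matrix = torch.load('pca_aniso_matrix_x4.pth',
                        map_location=lambda storage, loc: storage)
prepro = SRMDPreprocessing(
    scale=4, pca_matrix=pca_matrix, random_kernel=True, ksize=21,
    code_length=10, sig_min=0.6, sig_max=5.0, rate_iso=0,
    random_disturb=True, cuda=False)

hr = torch.rand(8, 3, 256, 256)   # HR batch in [0, 1]
lr, ker_code = prepro(hr)         # lr: (8, 3, 64, 64); ker_code: (8, 10)
```
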
/mmedit/models/restorers/dan.py:
--------------------------------------------------------------------------------
  1 | import numbers
  2 | import os.path as osp
  3 | 
  4 | import mmcv
  5 | import torch
  6 | from mmcv.runner import auto_fp16
  7 | 
  8 | from mmedit.core import psnr, ssim, tensor2img
  9 | from .basic_restorer import BasicRestorer
 10 | from ..builder import build_backbone, build_loss
 11 | from ..registry import MODELS
 12 | from ..common import SRMDPreprocessing
 13 | 
 14 | 
 15 | @MODELS.register_module()
 16 | class DAN(BasicRestorer):
 17 |     """Basic model for image restoration.
 18 | 
 19 |     It must contain a generator that takes an image as inputs and outputs a
 20 |     restored image. It also has a pixel-wise loss for training.
 21 | 
 22 |     The subclasses should overwrite the function `forward_train`,
 23 |     `forward_test` and `train_step`.
 24 | 
 25 |     Args:
 26 |         generator (dict): Config for the generator structure.
 27 |         pixel_loss (dict): Config for pixel-wise loss.
 28 |         train_cfg (dict): Config for training. Default: None.
 29 |         test_cfg (dict): Config for testing. Default: None.
 30 |         pretrained (str): Path for pretrained model. Default: None.
 31 |     """
 32 |     allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}
 33 | 
 34 |     def __init__(self,
 35 |                  generator,
 36 |                  pixel_loss,
 37 |                  train_cfg=None,
 38 |                  test_cfg=None,
 39 |                  pretrained=None):
 40 |         super(BasicRestorer, self).__init__()
 41 | 
 42 |         self.train_cfg = train_cfg
 43 |         self.test_cfg = test_cfg
 44 | 
 45 |         # support fp16
 46 |         self.fp16_enabled = False
 47 | 
 48 |         # load PCA matrix of enough kernel
 49 |         if train_cfg:
 50 |             pca_matrix = torch.load(train_cfg['pca_matrix_path'], map_location=lambda storage, loc: storage)
 51 |             print('PCA matrix shape:{}'.format(pca_matrix.shape))
 52 |             self.prepro = SRMDPreprocessing(train_cfg['scale'], pca_matrix=pca_matrix, cuda=True, **train_cfg['degradation'])
 53 | 
 54 |         # generator
 55 |         self.generator = build_backbone(generator)
 59 |         self.init_weights(pretrained)
 60 | 
 61 |         # loss
 62 |         self.pixel_loss = build_loss(pixel_loss)
 63 | 
 64 |     def init_weights(self, pretrained=None, strict=True):
 65 |         """Init weights for models.
 66 | 
 67 |         Args:
 68 |             pretrained (str, optional): Path for pretrained weights. If given
 69 |                 None, pretrained weights will not be loaded. Defaults to None.
 70 |         """
 71 |         self.generator.init_weights(pretrained, strict)
 72 | 
 73 |     @auto_fp16(apply_to=('lq', ))
 74 |     def forward(self, lq, gt=None, ker_map=None, test_mode=False, **kwargs):
 75 |         """Forward function.
 76 | 
 77 |         Args:
 78 |             lq (Tensor): Input lq images.
 79 |             gt (Tensor): Ground-truth image. Default: None.
 80 |             test_mode (bool): Whether in test mode or not. Default: False.
 81 |             kwargs (dict): Other arguments.
 82 |         """
 83 | 
 84 |         if test_mode:
 85 |             return self.forward_test(lq, gt, **kwargs)
 86 | 
 87 |         return self.forward_train(lq, gt, ker_map)
 88 | 
 89 |     def forward_train(self, lq, gt, ker_map):
 90 |         """Training forward function.
 91 | 
 92 |         Args:
 93 |             lq (Tensor): LQ Tensor with shape (n, c, h, w).
 94 |             gt (Tensor): GT Tensor with shape (n, c, h, w).
 95 |             ker_map (Tensor) : ker map Tensor
 96 |         Returns:
 97 |             Tensor: Output tensor.
 98 |         """
 99 |         losses = dict()
100 |         srs, fake_kers = self.generator(lq)
101 |         output = srs[-1]
102 |         out_ker = fake_kers[-1]
103 |         loss_pix = 0
104 |         loss_pix += self.pixel_loss(out_ker, ker_map)
105 |         loss_pix += self.pixel_loss(output, gt)
106 | 
107 |         losses['loss_pix'] = loss_pix
108 |         outputs = dict(
109 |             losses=losses,
110 |             num_samples=len(gt.data),
111 |             results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
112 |         return outputs
113 | 
114 |     def evaluate(self, output, gt):
115 |         """Evaluation function.
116 | 
117 |         Args:
118 |             output (Tensor): Model output with shape (n, c, h, w).
119 |             gt (Tensor): GT Tensor with shape (n, c, h, w).
120 | 
121 |         Returns:
122 |             dict: Evaluation results.
123 |         """
124 |         crop_border = self.test_cfg.crop_border
125 | 
126 |         output = tensor2img(output)
127 |         gt = tensor2img(gt)
128 | 
129 |         eval_result = dict()
130 |         for metric in self.test_cfg.metrics:
131 |             eval_result[metric] = self.allowed_metrics[metric](output, gt,
132 |                                                                crop_border)
133 |         return eval_result
134 | 
135 |     def forward_test(self,
136 |                      lq,
137 |                      gt=None,
138 |                      meta=None,
139 |                      save_image=False,
140 |                      save_path=None,
141 |                      iteration=None):
142 |         """Testing forward function.
143 | 
144 |         Args:
145 |             lq (Tensor): LQ Tensor with shape (n, c, h, w).
146 |             gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
147 |             save_image (bool): Whether to save image. Default: False.
148 |             save_path (str): Path to save image. Default: None.
149 |             iteration (int): Iteration for the saving image name.
150 |                 Default: None.
151 | 
152 |         Returns:
153 |             dict: Output results.
154 |         """
155 |         srs, _ = self.generator(lq)
156 |         output = srs[-1]
157 |         if self.test_cfg is not None and self.test_cfg.get('metrics', None):
158 |             assert gt is not None, (
159 |                 'evaluation with metrics must have gt images.')
160 |             results = dict(eval_result=self.evaluate(output, gt))
161 |         else:
162 |             results = dict(lq=lq.cpu(), output=output.cpu())
163 |             if gt is not None:
164 |                 results['gt'] = gt.cpu()
165 | 
166 |         # save image
167 |         if save_image:
168 |             lq_path = meta[0]['lq_path']
169 |             folder_name = osp.splitext(osp.basename(lq_path))[0]
170 |             if isinstance(iteration, numbers.Number):
171 |                 save_path = osp.join(save_path, folder_name,
172 |                                      f'{folder_name}-{iteration + 1:06d}.png')
173 |             elif iteration is None:
174 |                 save_path = osp.join(save_path, f'{folder_name}.png')
175 |             else:
176 |                 raise ValueError('iteration should be number or None, '
177 |                                  f'but got {type(iteration)}')
178 |             mmcv.imwrite(tensor2img(output), save_path)
179 | 
180 |         return results
181 | 
182 |     def forward_dummy(self, GT_img):
183 |         """Degrade GT images on the fly into LR inputs and kernel codes.
184 |
185 |         Args:
186 |             GT_img (Tensor): GT image with shape (n, c, h, w).
187 |
188 |         Returns:
189 |             tuple: (GT_img, LR_img, ker_map), each moved onto the GPU.
190 |         """
191 |         LR_img, ker_map = self.prepro(GT_img)
192 |         LR_img = (LR_img * 255).round() / 255
193 |         GT_img = GT_img.cuda()
194 |         LR_img = LR_img.cuda()
195 |         ker_map = ker_map.cuda()
196 |         return GT_img, LR_img, ker_map
197 | 
198 |     def train_step(self, data_batch, optimizer):
199 |         """Train step.
200 | 
201 |         Args:
202 |             data_batch (dict): A batch of data.
203 |             optimizer (obj): Optimizer.
204 | 
205 |         Returns:
206 |             dict: Returned output.
207 |         """
208 |         gt = data_batch['gt']
209 |         GT_img, LR_img, ker_map = self.forward_dummy(gt)
210 |         outputs = self(LR_img, GT_img, ker_map, test_mode=False)
211 |         loss, log_vars = self.parse_losses(outputs.pop('losses'))
212 | 
213 |         # optimize
214 |         optimizer['generator'].zero_grad()
215 |         loss.backward()
216 |         optimizer['generator'].step()
217 | 
218 |         outputs.update({'log_vars': log_vars})
219 |         return outputs
220 | 
221 |     def val_step(self, data_batch, **kwargs):
222 |         """Validation step.
223 | 
224 |         Args:
225 |             data_batch (dict): A batch of data.
226 |             kwargs (dict): Other arguments for ``val_step``.
227 | 
228 |         Returns:
229 |             dict: Returned output.
230 |         """
231 |         output = self.forward_test(**data_batch, **kwargs)
232 |         return output
233 | 
--------------------------------------------------------------------------------
/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexZou14/DAN-Basd-on-Openmmlab/fc0d273edf7d874a641e47cedde8d5ca54abcad7/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x2.pth
--------------------------------------------------------------------------------
/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexZou14/DAN-Basd-on-Openmmlab/fc0d273edf7d874a641e47cedde8d5ca54abcad7/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth
--------------------------------------------------------------------------------
/tools/data/super-resolution/dan_datasets/preprocess_dan_datasets.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import cv2
  4 | import numpy as np
  5 | import torch
  6 | from mmedit.models.common.dan_preprocess import tensor2img, imresize, SRMDPreprocessing, img2tensor
  7 | import random
  8 | 
  9 | def set_random_seed(seed):
 10 |     random.seed(seed)
 11 |     np.random.seed(seed)
 12 |     torch.manual_seed(seed)
 13 |     torch.cuda.manual_seed_all(seed)
 14 | 
 15 | def generate_mod_LR_bic():
 16 |     # set parameters
 17 |     up_scale = 2
 18 |     mod_scale = 4
 19 |     # set data dir
 20 |     sourcedir = "/media/lab216/dbf4a469-6c52-4371-95c4-c24c882bc23b/data/benchmark/Set5/HR"
 21 |     savedir = "/media/lab216/dbf4a469-6c52-4371-95c4-c24c882bc23b/data/benchmark/Set5/generate/"
 22 | 
 23 |     # load PCA matrix of enough kernel
 24 |     print("load PCA matrix")
 25 |     pca_matrix = torch.load(
 26 |         "/home/lab216/Desktop/mmediting/tools/data/super-resolution/div2k/pca_matrix/pca_matrix_x2.pth", map_location=lambda storage, loc: storage
 27 |     )
 28 |     print("PCA matrix shape: {}".format(pca_matrix.shape))
 29 | 
 30 |     degradation_setting = {
 31 |                               "random_kernel": False,
 32 |                               "code_length": 10,
 33 |                               "ksize": 21,
 34 |                               "pca_matrix": pca_matrix,
 35 |                               "scale": up_scale,
 36 |                               "cuda": True,
 37 |                               "rate_iso": 1.0
 38 |     }
 39 | 
 40 |     # set random seed
 41 |     set_random_seed(0)
 42 | 
 43 |     saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale))
 44 |     saveLRpath = os.path.join(savedir, "LR", "x" + str(up_scale))
 45 |     saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale))
 46 |     saveLRblurpath = os.path.join(savedir, "LRblur", "x" + str(up_scale))
 47 | 
 48 |     if not os.path.isdir(sourcedir):
 49 |         print("Error: No source data found")
 50 |         sys.exit(1)
 51 |     if not os.path.isdir(savedir):
 52 |         os.mkdir(savedir)
 53 | 
 54 |     if not os.path.isdir(os.path.join(savedir, "HR")):
 55 |         os.mkdir(os.path.join(savedir, "HR"))
 56 |     if not os.path.isdir(os.path.join(savedir, "LR")):
 57 |         os.mkdir(os.path.join(savedir, "LR"))
 58 |     if not os.path.isdir(os.path.join(savedir, "Bic")):
 59 |         os.mkdir(os.path.join(savedir, "Bic"))
 60 |     if not os.path.isdir(os.path.join(savedir, "LRblur")):
 61 |         os.mkdir(os.path.join(savedir, "LRblur"))
 62 | 
 63 |     if not os.path.isdir(saveHRpath):
 64 |         os.mkdir(saveHRpath)
 65 |     else:
 66 |         print("It will cover " + str(saveHRpath))
 67 | 
 68 |     if not os.path.isdir(saveLRpath):
 69 |         os.mkdir(saveLRpath)
 70 |     else:
 71 |         print("It will cover " + str(saveLRpath))
 72 | 
 73 |     if not os.path.isdir(saveBicpath):
 74 |         os.mkdir(saveBicpath)
 75 |     else:
 76 |         print("It will cover " + str(saveBicpath))
 77 | 
 78 |     if not os.path.isdir(saveLRblurpath):
 79 |         os.mkdir(saveLRblurpath)
 80 |     else:
 81 |         print("It will cover " + str(saveLRblurpath))
 82 | 
 83 |     filepaths = sorted([f for f in os.listdir(sourcedir) if f.endswith(".png")])
 84 |     print(filepaths)
 85 |     num_files = len(filepaths)
 86 | 
 87 |     # kernel_map_tensor = torch.zeros((num_files, 1, 10)) # each kernel map: 1*10
 88 | 
 89 |     # prepare data with augmentation
 90 | 
 91 |     for i in range(num_files):
 92 |         filename = filepaths[i]
 93 |         print("No.{} -- Processing {}".format(i, filename))
 94 |         # read image
 95 |         image = cv2.imread(os.path.join(sourcedir, filename))
 96 | 
 97 |         width = int(np.floor(image.shape[1] / mod_scale))
 98 |         height = int(np.floor(image.shape[0] / mod_scale))
 99 |         # modcrop
100 |         if len(image.shape) == 3:
101 |             image_HR = image[0: mod_scale * height, 0: mod_scale * width, :]
102 |         else:
103 |             image_HR = image[0: mod_scale * height, 0: mod_scale * width]
104 |         # LR_blur, by random gaussian kernel
105 |         img_HR = img2tensor(image_HR)
106 |         C, H, W = img_HR.size()
107 | 
108 |         for sig in np.linspace(1.8, 3.2, 8):
109 |             prepro = SRMDPreprocessing(sig=sig, **degradation_setting)
110 | 
111 |             LR_img, ker_map = prepro(img_HR.view(1, C, H, W))
112 |             image_LR_blur = tensor2img(LR_img)
113 |             cv2.imwrite(os.path.join(saveLRblurpath, 'sig{}_{}'.format(sig, filename)), image_LR_blur)
114 |             cv2.imwrite(os.path.join(saveHRpath, 'sig{}_{}'.format(sig, filename)), image_HR)
115 |         # LR
116 |         image_LR = imresize(image_HR, 1 / up_scale, True)
117 |         # bic
118 |         image_Bic = imresize(image_LR, up_scale, True)
119 | 
120 |         # cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
121 |         cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
122 |         cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
123 | 
124 |         # kernel_map_tensor[i] = ker_map
125 |     # save dataset corresponding kernel maps
126 |     # torch.save(kernel_map_tensor, './Set5_sig2.6_kermap.pth')
127 |     print("Image blurring & downsampling done: X" + str(up_scale))
128 | 
129 | 
130 | if __name__ == "__main__":
131 |     generate_mod_LR_bic()
132 | 
--------------------------------------------------------------------------------
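The script above pairs every HR image with eight blurred LR copies (`sig` swept from 1.8 to 3.2). Below is a minimal sketch of driving that same degradation for a single image, assuming the repo's `SRMDPreprocessing`, `img2tensor`, and `tensor2img` helpers are importable; only the `degradation_setting` keys visible in the excerpt are shown, and the remaining scale/kernel keys are left out as unknowns.

```python
# A sketch, not the repo's exact entry point: blur-and-downsample one
# modcropped HR image the way generate_mod_LR_bic() does above.
# SRMDPreprocessing / img2tensor / tensor2img are this repo's helpers;
# '0001.png' and the partial degradation_setting are placeholders.
import cv2

image_HR = cv2.imread('0001.png')            # HWC, BGR uint8
img_HR = img2tensor(image_HR)                # -> CHW float tensor
C, H, W = img_HR.size()

degradation_setting = {"cuda": True, "rate_iso": 1.0}  # plus scale/kernel keys
prepro = SRMDPreprocessing(sig=2.6, **degradation_setting)

LR_img, ker_map = prepro(img_HR.view(1, C, H, W))      # batch of one image
cv2.imwrite('sig2.6_0001.png', tensor2img(LR_img))
```

With `rate_iso` at 1.0 every sampled kernel is isotropic, which is consistent with sweeping a single kernel width `sig` in the loop above.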
/tools/data/super-resolution/dan_datasets/preprocess_div2k_dataset.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import os.path as osp
  4 | import re
  5 | import sys
  6 | from multiprocessing import Pool
  7 | 
  8 | import cv2
  9 | import lmdb
 10 | import mmcv
 11 | import numpy as np
 12 | 
 13 | 
 14 | def main_extract_subimages(args):
 15 |     """A multiprocessing tool to crop large images into sub-images for faster IO.
 16 | 
 17 |     It is used for the DIV2K dataset.
 18 | 
 19 |     opt (dict): Configuration dict built from ``args``. It contains:
 20 |         n_thread (int): Number of worker processes.
 21 |         compression_level (int): cv2.IMWRITE_PNG_COMPRESSION, from 0 to 9.
 22 |             A higher value means a smaller size but a longer compression
 23 |             time. Use 0 for faster CPU decompression. Default: 3 (as cv2).
 24 | 
 25 |         input_folder (str): Path to the input folder.
 26 |         save_folder (str): Path to save folder.
 27 |         crop_size (int): Crop size.
 28 |         step (int): Step for overlapped sliding window.
 29 |         thresh_size (int): Threshold size. Patches whose size is lower
 30 |             than thresh_size will be dropped.
 31 | 
 32 |     Usage:
 33 |         Run this script once per folder. Typically, four folders are
 34 |         processed for the DIV2K dataset:
 35 |             DIV2K_train_HR
 36 |             DIV2K_train_LR_bicubic/X2
 37 |             DIV2K_train_LR_bicubic/X3
 38 |             DIV2K_train_LR_bicubic/X4
 39 |         As configured below, only the HR folder (DF2K_train_HR) is handled.
 40 |         After processing, each sub-folder should contain the same number
 41 |         of sub-images. Remember to modify opt according to your settings.
 42 |     """
 43 | 
 44 |     opt = {}
 45 |     opt['n_thread'] = args.n_thread
 46 |     opt['compression_level'] = args.compression_level
 47 | 
 48 |     # HR images
 49 |     opt['input_folder'] = osp.join(args.data_root, 'DF2K_train_HR')
 50 |     opt['save_folder'] = osp.join(args.data_root, 'DF2K_train_HR_sub')
 51 |     opt['crop_size'] = args.crop_size
 52 |     opt['step'] = args.step
 53 |     opt['thresh_size'] = args.thresh_size
 54 |     extract_subimages(opt)
 55 | 
 56 | 
 57 | def extract_subimages(opt):
 58 |     """Crop images to subimages.
 59 | 
 60 |     Args:
 61 |         opt (dict): Configuration dict. It contains:
 62 |             input_folder (str): Path to the input folder.
 63 |             save_folder (str): Path to save folder.
 64 |             n_thread (int): Thread number.
 65 |     """
 66 |     input_folder = opt['input_folder']
 67 |     save_folder = opt['save_folder']
 68 |     if not osp.exists(save_folder):
 69 |         os.makedirs(save_folder)
 70 |         print(f'mkdir {save_folder} ...')
 71 |     else:
 72 |         print(f'Folder {save_folder} already exists. Exit.')
 73 |         sys.exit(1)
 74 | 
 75 |     img_list = list(mmcv.scandir(input_folder))
 76 |     img_list = [osp.join(input_folder, v) for v in img_list]
 77 | 
 78 |     prog_bar = mmcv.ProgressBar(len(img_list))
 79 |     pool = Pool(opt['n_thread'])
 80 |     for path in img_list:
 81 |         pool.apply_async(
 82 |             worker, args=(path, opt), callback=lambda arg: prog_bar.update())
 83 |     pool.close()
 84 |     pool.join()
 85 |     print('All processes done.')
 86 | 
 87 | 
 88 | def worker(path, opt):
 89 |     """Worker for each process.
 90 | 
 91 |     Args:
 92 |         path (str): Image path.
 93 |         opt (dict): Configuration dict. It contains:
 94 |             crop_size (int): Crop size.
 95 |             step (int): Step for overlapped sliding window.
 96 |             thresh_size (int): Threshold size. Patches whose size is smaller
 97 |                 than thresh_size will be dropped.
 98 |             save_folder (str): Path to save folder.
 99 |             compression_level (int): Level for cv2.IMWRITE_PNG_COMPRESSION.
100 | 
101 |     Returns:
102 |         process_info (str): Process information displayed in progress bar.
103 |     """
104 |     crop_size = opt['crop_size']
105 |     step = opt['step']
106 |     thresh_size = opt['thresh_size']
107 |     img_name, extension = osp.splitext(osp.basename(path))
108 | 
109 |     # remove the x2, x3, x4 and x8 in the filename for DIV2K
110 |     img_name = re.sub('x[2348]', '', img_name)
111 | 
112 |     img = mmcv.imread(path, flag='unchanged')
113 | 
114 |     if img.ndim == 2 or img.ndim == 3:
115 |         h, w = img.shape[:2]
116 |     else:
117 |         raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
118 | 
119 |     h_space = np.arange(0, h - crop_size + 1, step)
120 |     if h - (h_space[-1] + crop_size) > thresh_size:
121 |         h_space = np.append(h_space, h - crop_size)
122 |     w_space = np.arange(0, w - crop_size + 1, step)
123 |     if w - (w_space[-1] + crop_size) > thresh_size:
124 |         w_space = np.append(w_space, w - crop_size)
125 | 
126 |     index = 0
127 |     for x in h_space:
128 |         for y in w_space:
129 |             index += 1
130 |             cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
131 |             cv2.imwrite(
132 |                 osp.join(opt['save_folder'],
133 |                          f'{img_name}_s{index:03d}{extension}'), cropped_img,
134 |                 [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
135 |     process_info = f'Processing {img_name} ...'
136 |     return process_info
137 | 
138 | 
139 | def make_lmdb_for_div2k(data_root):
140 |     """Create lmdb files for DIV2K dataset.
141 | 
142 |     Args:
143 |         data_root (str): Data root path.
144 | 
145 |     Usage:
146 |         Typically, four folders are processed for the DIV2K dataset:
147 |             DIV2K_train_HR_sub
148 |             DIV2K_train_LR_bicubic/X2_sub
149 |             DIV2K_train_LR_bicubic/X3_sub
150 |             DIV2K_train_LR_bicubic/X4_sub
151 |         Here only DF2K_train_HR_sub is converted; modify paths as needed.
152 |     """
153 | 
154 |     folder_paths = [
155 |         osp.join(data_root, 'DF2K_train_HR_sub')
156 |     ]
157 |     lmdb_paths = [
158 |         osp.join(data_root, 'DF2K_train_HR_sub.lmdb')
159 |     ]
160 | 
161 |     for folder_path, lmdb_path in zip(folder_paths, lmdb_paths):
162 |         img_path_list, keys = prepare_keys_div2k(folder_path)
163 |         make_lmdb(folder_path, lmdb_path, img_path_list, keys)
164 | 
165 | 
166 | def prepare_keys_div2k(folder_path):
167 |     """Prepare image path list and keys for DIV2K dataset.
168 | 
169 |     Args:
170 |         folder_path (str): Folder path.
171 | 
172 |     Returns:
173 |         list[str]: Image path list.
174 |         list[str]: Key list.
175 |     """
176 |     print('Reading image path list ...')
177 |     img_path_list = sorted(
178 |         list(mmcv.scandir(folder_path, suffix='png', recursive=False)))
179 |     keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)]
180 | 
181 |     return img_path_list, keys
182 | 
183 | 
184 | def make_lmdb(data_path,
185 |               lmdb_path,
186 |               img_path_list,
187 |               keys,
188 |               batch=5000,
189 |               compress_level=1,
190 |               multiprocessing_read=False,
191 |               n_thread=40):
192 |     """Make lmdb.
193 | 
194 |     Contents of lmdb. The file structure is:
195 |     example.lmdb
196 |     ├── data.mdb
197 |     ├── lock.mdb
198 |     ├── meta_info.txt
199 | 
200 |     The data.mdb and lock.mdb are standard lmdb files and you can refer to
201 |     https://lmdb.readthedocs.io/en/release/ for more details.
202 | 
203 |     The meta_info.txt is a plain-text file that records the meta information
204 |     of the dataset. It is created automatically when preparing datasets
205 |     with the provided dataset tools.
206 |     Each line in the txt file records 1) image name (with extension),
207 |     2) image shape, and 3) compression level, separated by whitespace.
208 | 
209 |     For example, the meta information could be:
210 |     `000_00000000.png (720,1280,3) 1`, which means:
211 |     1) image name (with extension): 000_00000000.png;
212 |     2) image shape: (720,1280,3);
213 |     3) compression level: 1
214 | 
215 |     We use the image name without extension as the lmdb key.
216 | 
217 |     If `multiprocessing_read` is True, it will read all the images to memory
218 |     using multiprocessing. Thus, your server needs to have enough memory.
219 | 
220 |     Args:
221 |         data_path (str): Data path for reading images.
222 |         lmdb_path (str): Lmdb save path.
223 |         img_path_list (list[str]): Image path list.
224 |         keys (list[str]): Used for lmdb keys.
225 |         batch (int): After processing batch images, lmdb commits.
226 |             Default: 5000.
227 |         compress_level (int): Compress level when encoding images. Default: 1.
228 |         multiprocessing_read (bool): Whether use multiprocessing to read all
229 |             the images to memory. Default: False.
230 |         n_thread (int): Number of processes for the multiprocessing read.
231 |     """
232 |     assert len(img_path_list) == len(keys), (
233 |         'img_path_list and keys should have the same length, '
234 |         f'but got {len(img_path_list)} and {len(keys)}')
235 |     print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
236 |     print(f'Total images: {len(img_path_list)}')
237 |     if not lmdb_path.endswith('.lmdb'):
238 |         raise ValueError("lmdb_path must end with '.lmdb'.")
239 |     if osp.exists(lmdb_path):
240 |         print(f'Folder {lmdb_path} already exists. Exit.')
241 |         sys.exit(1)
242 | 
243 |     if multiprocessing_read:
244 |         # read all the images to memory (multiprocessing)
245 |         dataset = {}  # dict keyed by image name; async results arrive out of order
246 |         shapes = {}
247 |         print(f'Read images with multiprocessing, #thread: {n_thread} ...')
248 |         prog_bar = mmcv.ProgressBar(len(img_path_list))
249 | 
250 |         def callback(arg):
251 |             """get the image data and update prog_bar."""
252 |             key, dataset[key], shapes[key] = arg
253 |             prog_bar.update()
254 | 
255 |         pool = Pool(n_thread)
256 |         for path, key in zip(img_path_list, keys):
257 |             pool.apply_async(
258 |                 read_img_worker,
259 |                 args=(osp.join(data_path, path), key, compress_level),
260 |                 callback=callback)
261 |         pool.close()
262 |         pool.join()
263 |         print(f'Finish reading {len(img_path_list)} images.')
264 | 
265 |     # create lmdb environment
266 |     # obtain data size for one image
267 |     img = mmcv.imread(osp.join(data_path, img_path_list[0]), flag='unchanged')
268 |     _, img_byte = cv2.imencode('.png', img,
269 |                                [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
270 |     data_size_per_img = img_byte.nbytes
271 |     print('Data size per image is: ', data_size_per_img)
272 |     data_size = data_size_per_img * len(img_path_list)
273 |     env = lmdb.open(lmdb_path, map_size=data_size * 10)
274 | 
275 |     # write data to lmdb
276 |     prog_bar = mmcv.ProgressBar(len(img_path_list))
277 |     txn = env.begin(write=True)
278 |     txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
279 |     for idx, (path, key) in enumerate(zip(img_path_list, keys)):
280 |         prog_bar.update()
281 |         key_byte = key.encode('ascii')
282 |         if multiprocessing_read:
283 |             img_byte = dataset[key]
284 |             h, w, c = shapes[key]
285 |         else:
286 |             _, img_byte, img_shape = read_img_worker(
287 |                 osp.join(data_path, path), key, compress_level)
288 |             h, w, c = img_shape
289 | 
290 |         txn.put(key_byte, img_byte)
291 |         # write meta information
292 |         txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
293 |         if idx % batch == 0:
294 |             txn.commit()
295 |             txn = env.begin(write=True)
296 |     txn.commit()
297 |     env.close()
298 |     txt_file.close()
299 |     print('\nFinish writing lmdb.')
300 | 
301 | 
302 | def read_img_worker(path, key, compress_level):
303 |     """Read image worker.
304 | 
305 |     Args:
306 |         path (str): Image path.
307 |         key (str): Image key.
308 |         compress_level (int): Compress level when encoding images.
309 | 
310 |     Returns:
311 |         str: Image key.
312 |         byte: Image byte.
313 |         tuple[int]: Image shape.
314 |     """
315 |     img = mmcv.imread(path, flag='unchanged')
316 |     if img.ndim == 2:
317 |         h, w = img.shape
318 |         c = 1
319 |     else:
320 |         h, w, c = img.shape
321 |     _, img_byte = cv2.imencode('.png', img,
322 |                                [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
323 |     return (key, img_byte, (h, w, c))
324 | 
325 | 
326 | def parse_args():
327 |     parser = argparse.ArgumentParser(
328 |         description='Prepare DIV2K dataset',
329 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
330 |     parser.add_argument('--data-root', help='dataset root')
331 |     parser.add_argument(
332 |         '--crop-size',
333 |         nargs='?',
334 |         default=256, type=int,
335 |         help='cropped size for HR images')
336 |     parser.add_argument(
337 |         '--step', nargs='?', default=128, type=int, help='step size for HR images')
338 |     parser.add_argument(
339 |         '--thresh-size',
340 |         nargs='?',
341 |         default=0, type=int,
342 |         help='threshold size for HR images')
343 |     parser.add_argument(
344 |         '--compression-level',
345 |         nargs='?',
346 |         default=3, type=int,
347 |         help='compression level when saving png images')
348 |     parser.add_argument(
349 |         '--n-thread',
350 |         nargs='?',
351 |         default=4, type=int,
352 |         help='thread number when using multiprocessing')
353 |     parser.add_argument(
354 |         '--make-lmdb',
355 |         action='store_true',
356 |         help='whether to prepare lmdb files')
357 |     args = parser.parse_args()
358 |     return args
359 | 
360 | 
361 | if __name__ == '__main__':
362 |     args = parse_args()
363 |     # extract subimages
364 |     main_extract_subimages(args)
365 |     # prepare lmdb files if necessary
366 |     if args.make_lmdb:
367 |         make_lmdb_for_div2k(args.data_root)
368 | 
--------------------------------------------------------------------------------
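For completeness, a consumer-side sketch of the lmdb produced by `make_lmdb()`: it reads the first key from `meta_info.txt` and decodes the corresponding image. The key and `meta_info.txt` conventions follow the docstring above; the lmdb path is a placeholder.

```python
# Sketch: fetch and decode one image from an lmdb written by make_lmdb().
# Keys are image names without extension; meta_info.txt lines look like
# "0001_s001.png (480,480,3) 1", as documented above.
import cv2
import lmdb
import numpy as np

lmdb_path = 'DF2K_train_HR_sub.lmdb'         # placeholder path
with open(f'{lmdb_path}/meta_info.txt') as f:
    key = f.readline().split('.png')[0]      # e.g. '0001_s001'

env = lmdb.open(lmdb_path, readonly=True, lock=False)
with env.begin() as txn:
    img_byte = txn.get(key.encode('ascii'))  # PNG-encoded bytes
img = cv2.imdecode(np.frombuffer(img_byte, np.uint8), cv2.IMREAD_UNCHANGED)
print(key, img.shape)
```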