├── README.md
├── ckpt_convert
│   ├── BiSeNetV2.py
│   ├── ckpt_convert.py
│   ├── convert_onnx.py
│   ├── ema.py
│   ├── nafa_archv1.py
│   ├── readme
│   └── 说明文档.txt
├── compute_mask.py
├── data
│   └── dataloader.py
├── gauss.py
├── loss
│   ├── Loss.py
│   └── PSNRLoss.py
├── losses.py
├── models
│   ├── Model.py
│   ├── discriminator.py
│   ├── idr.py
│   ├── networks.py
│   ├── non_local.py
│   ├── readme
│   ├── sa_aidr.py
│   └── sa_gan.py
├── predict.py
├── submit_dehw.zip
├── test.py
├── test.sh
├── train.py
├── train.sh
├── utils.py
└── zip.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # First-place solution for Handwritten Text Erasure and first-place solution for the Smart Watermark Removal competition
 2 | First place in the Handwritten Text Erasure track and first place in the Smart Watermark Removal track
 3 | Competition link: [Handwritten Text Erasure](https://aistudio.baidu.com/aistudio/competition/detail/129/0/introduction)
 4 | Competition link: [Smart Watermark Removal](https://aistudio.baidu.com/aistudio/competition/detail/209/0/introduction)
 5 | 
 6 | 2024-04-23: On poor results outside the competition dataset:
 7 | The competition data is synthetic and does not simulate real scenes, so performance on real-world images is relatively poor; fine-tuning on a private dataset is required.
 8 | 
 9 | 2024-04-12: The models have been uploaded at https://github.com/zdyshine/Baidu-netdisk-AI-Image-processing-Challenge-handwriting/releases/tag/checkpoints
10 | ## 1. Task background
11 | Process the given exam-paper images containing handwritten marks, erase the handwriting, and restore the original appearance of the image
12 | ![](https://ai-studio-static-online.cdn.bcebos.com/af2816877d054080987de1f47679fa656e5f498fd39744f5a9f94cc6c5a4fb9d)
13 | 
14 | ## 2. Data analysis
15 | **Data split**: 1000 images are used for training and 81 for validation.
16 | The organizers provide 1081 training pairs and 200 images each for test sets A and B. The data has the following characteristics:
17 | 1. Image resolutions are generally large
18 | 2. Handwriting comes in several colors (red, black, blue), while printed text is essentially black
19 | 3. Besides normal text, the handwriting also includes hand-drawn lines, figures, etc.
20 | 4. Stains and smudges on the paper also need to be removed
21 | 5. Handwritten and printed text overlap
22 | 
23 | **mask**: mask data is generated from the difference between the input image and the ground-truth label
24 | Compute the mean difference over the RGB channels
25 | Pixels with a mean difference above 20 are set to 1
26 | Pixels with a mean difference below 20 are set to difference/20
27 | 
28 | ![](https://ai-studio-static-online.cdn.bcebos.com/255b0b9dd6e8426fae2d9f01c6bd17229fd4dbb37a5741539ba8d8ea87fd10f3)
29 | 
30 | ## 3. Model design
31 | The network is based on the open-source EraseNet, ported entirely to Paddle. We also tried the recent PERT, a region-based iterative scene-text erasure network. Comparative experiments showed that EraseNet works better on this dataset. As the architecture diagram shows, EraseNet is a multi-branch, multi-stage network, containing a mask-generation branch and a two-stage image-generation branch, and the whole network is built on multi-scale structures. For the loss functions, the original EraseNet uses a perceptual loss and a GAN loss to generate more realistic backgrounds; since the backgrounds in this task are pure white, these two losses are unnecessary and can simply be removed. In addition, because EraseNet is composed of multi-scale networks, drawing on experience from a demoiréing competition, I replaced EraseNet's Refinement module with the multi-scale network used in that competition.
32 | Two-model ensemble:
33 | Model 1: EraseNet with the discriminator removed, keeping only the generator
34 | ![](https://ai-studio-static-online.cdn.bcebos.com/7546d26870a44fce9b5f118b8fc8e8501b7f4ed1e807468ebece4c9d21209ac0)
35 | Model 2: EraseNet whose second-stage network is replaced with a Non-Local-based deep encoder-decoder
36 | ![](https://ai-studio-static-online.cdn.bcebos.com/67f2b22dca8a491cad844354f2ba81601190f4bda4e44524a115b8c715bedbfb)
37 | 
38 | ## 4. Training details
39 | 
40 | **Training data:**
41 | Augmentation uses only horizontal flips and small-angle rotations, preserving the text prior
42 | Images are randomly cropped into 512x512 patches for training
43 | 
44 | **Training has two stages:**
45 | Stage 1 loss: dice_loss + L1 loss
46 | Stage 2 loss: L1 loss only
47 | 
48 | ## 5. Testing details
49 | 
50 | Test-time tricks:
51 | **Tiled testing**: split the image into 512x512 tiles for prediction, consistent with training
52 | **Overlapping tiled testing**: mirror-pad the test image and let adjacent tiles overlap at their edges, keeping only the center of each tile's prediction; image borders carry less information, so predictions there are worse than at the center
53 | At test time, horizontal mirror **augmentation** is applied to the **test** data
54 | At test time, the predictions of the two **models** are **ensembled**
55 | 
56 | ## 6. Score-improvement strategy
57 | 
58 | ![](https://ai-studio-static-online.cdn.bcebos.com/88dd53709c1f47aca80f9ce63e344e8494c44c59b9534367b7aa4b5b0034caad)
59 | 
60 | ## 7. Miscellaneous
61 | data: data loading
62 | loss: loss functions
63 | model: network models
64 | compute_mask.py: generates the mask files
65 | test.py: test script
66 | train.py: training script
67 | 
68 | To run the code:
69 | 1. Set the data directory
70 | 2. Run sh train.sh to generate masks and start training
71 | 3. Set the test directory and model path, then run sh test.sh to start testing
72 | ## Pretrained models
73 | https://aistudio.baidu.com/aistudio/projectdetail/3439691
74 | Run the project to download the pretrained models; online testing is also available there.
75 | 
--------------------------------------------------------------------------------
/ckpt_convert/BiSeNetV2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # import os 16 | 17 | import paddle 18 | import paddle.nn as nn 19 | import paddle.nn.functional as F 20 | 21 | # from paddleseg import utils 22 | # from paddleseg.cvlibs import manager, param_init 23 | # from .layer import layers 24 | 25 | 26 | # @manager.MODELS.add_component 27 | class BiSeNetV2(nn.Layer): 28 | """ 29 | The BiSeNet V2 implementation based on PaddlePaddle. 30 | 31 | The original article refers to 32 | Yu, Changqian, et al. "BiSeNet V2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation" 33 | (https://arxiv.org/abs/2004.02147) 34 | 35 | Args: 36 | num_classes (int): The unique number of target classes. 37 | lambd (float, optional): A factor for controlling the size of semantic branch channels. Default: 0.25. 38 | pretrained (str, optional): The path or url of pretrained model. Default: None. 39 | """ 40 | 41 | def __init__(self, 42 | num_classes, 43 | lambd=0.25, 44 | align_corners=False, 45 | pretrained=None): 46 | super().__init__() 47 | 48 | C1, C2, C3 = 64, 64, 128 49 | db_channels = (C1, C2, C3) 50 | C1, C3, C4, C5 = int(C1 * lambd), int(C3 * lambd), 64, 128 51 | sb_channels = (C1, C3, C4, C5) 52 | mid_channels = 128 53 | 54 | self.db = DetailBranch(db_channels) 55 | self.sb = SemanticBranch(sb_channels) 56 | 57 | self.bga = BGA(mid_channels, align_corners) 58 | self.aux_head1 = SegHead(C1, C1, num_classes) 59 | self.aux_head2 = SegHead(C3, C3, num_classes) 60 | self.aux_head3 = SegHead(C4, C4, num_classes) 61 | self.aux_head4 = SegHead(C5, C5, num_classes) 62 | self.head = SegHead(mid_channels, mid_channels, num_classes) 63 | 64 | self.align_corners = align_corners 65 | self.pretrained = pretrained 66 | 67 | def forward(self, x): 68 | dfm = self.db(x) 69 | feat1, feat2, feat3, feat4, sfm = self.sb(x) 70 | logit = self.head(self.bga(dfm, sfm)) 71 | 72 | if not self.training: 73 | logit_list = [logit] 74 | else: 75 | logit1 = self.aux_head1(feat1) 76 | logit2 = self.aux_head2(feat2) 77 | logit3 = self.aux_head3(feat3) 78 | logit4 = self.aux_head4(feat4) 79 | logit_list = [logit, logit1, logit2, logit3, logit4] 80 | 81 | logit_list = [ 82 | F.interpolate( 83 | logit, 84 | paddle.shape(x)[2:], 85 | mode='bilinear', 86 | align_corners=self.align_corners) for logit in logit_list 87 | ] 88 | 89 | return logit_list 90 | 91 | # def init_weight(self): 92 | # if self.pretrained is not None: 93 | # utils.load_entire_model(self, self.pretrained) 94 | # else: 95 | # for sublayer in self.sublayers(): 96 | # if isinstance(sublayer, nn.Conv2D): 97 | # param_init.kaiming_normal_init(sublayer.weight) 98 | # elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): 99 | # param_init.constant_init(sublayer.weight, value=1.0) 100 | # param_init.constant_init(sublayer.bias, value=0.0) 101 | 102 | 103 | class Activation(nn.Layer): 104 | 105 | def __init__(self, act=None): 106 | super(Activation, self).__init__() 107 | 108 | self._act = act 109 | 
upper_act_names = nn.layer.activation.__dict__.keys() 110 | lower_act_names = [act.lower() for act in upper_act_names] 111 | act_dict = dict(zip(lower_act_names, upper_act_names)) 112 | 113 | if act is not None: 114 | if act in act_dict.keys(): 115 | act_name = act_dict[act] 116 | self.act_func = eval("nn.layer.activation.{}()".format( 117 | act_name)) 118 | else: 119 | raise KeyError("{} does not exist in the current {}".format( 120 | act, act_dict.keys())) 121 | 122 | def forward(self, x): 123 | if self._act is not None: 124 | return self.act_func(x) 125 | else: 126 | return x 127 | import os 128 | 129 | def SyncBatchNorm(*args, **kwargs): 130 | """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead""" 131 | if paddle.get_device() == 'cpu' or os.environ.get('PADDLESEG_EXPORT_STAGE'): 132 | return nn.BatchNorm2D(*args, **kwargs) 133 | elif paddle.distributed.ParallelEnv().nranks == 1: 134 | return nn.BatchNorm2D(*args, **kwargs) 135 | else: 136 | return nn.SyncBatchNorm(*args, **kwargs) 137 | 138 | class ConvBNReLU(nn.Layer): 139 | def __init__(self, 140 | in_channels, 141 | out_channels, 142 | kernel_size, 143 | padding='same', 144 | **kwargs): 145 | super().__init__() 146 | 147 | self._conv = nn.Conv2D( 148 | in_channels, out_channels, kernel_size, padding=padding, **kwargs) 149 | 150 | if 'data_format' in kwargs: 151 | data_format = kwargs['data_format'] 152 | else: 153 | data_format = 'NCHW' 154 | self._batch_norm = SyncBatchNorm(out_channels, data_format=data_format) 155 | self._relu = Activation("relu") 156 | 157 | def forward(self, x): 158 | x = self._conv(x) 159 | x = self._batch_norm(x) 160 | x = self._relu(x) 161 | return x 162 | 163 | class StemBlock(nn.Layer): 164 | def __init__(self, in_dim, out_dim): 165 | super(StemBlock, self).__init__() 166 | 167 | self.conv = ConvBNReLU(in_dim, out_dim, 3, stride=2) 168 | 169 | self.left = nn.Sequential( 170 | ConvBNReLU(out_dim, out_dim // 2, 1), 171 | ConvBNReLU( 172 | out_dim // 2, out_dim, 3, stride=2)) 173 | 174 | self.right = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) 175 | 176 | self.fuse = ConvBNReLU(out_dim * 2, out_dim, 3) 177 | 178 | def forward(self, x): 179 | x = self.conv(x) 180 | left = self.left(x) 181 | right = self.right(x) 182 | concat = paddle.concat([left, right], axis=1) 183 | return self.fuse(concat) 184 | 185 | class ConvBN(nn.Layer): 186 | def __init__(self, 187 | in_channels, 188 | out_channels, 189 | kernel_size, 190 | padding='same', 191 | **kwargs): 192 | super().__init__() 193 | self._conv = nn.Conv2D( 194 | in_channels, out_channels, kernel_size, padding=padding, **kwargs) 195 | if 'data_format' in kwargs: 196 | data_format = kwargs['data_format'] 197 | else: 198 | data_format = 'NCHW' 199 | self._batch_norm = SyncBatchNorm(out_channels, data_format=data_format) 200 | 201 | def forward(self, x): 202 | x = self._conv(x) 203 | x = self._batch_norm(x) 204 | return x 205 | 206 | 207 | class Add(nn.Layer): 208 | def __init__(self): 209 | super().__init__() 210 | 211 | def forward(self, x, y, name=None): 212 | return paddle.add(x, y, name) 213 | 214 | 215 | class ContextEmbeddingBlock(nn.Layer): 216 | def __init__(self, in_dim, out_dim): 217 | super(ContextEmbeddingBlock, self).__init__() 218 | 219 | self.gap = nn.AdaptiveAvgPool2D(1) 220 | self.bn = SyncBatchNorm(in_dim) 221 | 222 | self.conv_1x1 = ConvBNReLU(in_dim, out_dim, 1) 223 | self.add = Add() 224 | self.conv_3x3 = nn.Conv2D(out_dim, out_dim, 3, 1, 1) 225 | 226 | def forward(self, x): 227 | gap = self.gap(x) 228 | 
bn = self.bn(gap) 229 | conv1 = self.add(self.conv_1x1(bn), x) 230 | return self.conv_3x3(conv1) 231 | 232 | class DepthwiseConvBN(nn.Layer): 233 | def __init__(self, 234 | in_channels, 235 | out_channels, 236 | kernel_size, 237 | padding='same', 238 | **kwargs): 239 | super().__init__() 240 | self.depthwise_conv = ConvBN( 241 | in_channels, 242 | out_channels=out_channels, 243 | kernel_size=kernel_size, 244 | padding=padding, 245 | groups=in_channels, 246 | **kwargs) 247 | 248 | def forward(self, x): 249 | x = self.depthwise_conv(x) 250 | return x 251 | 252 | class GatherAndExpansionLayer1(nn.Layer): 253 | """Gather And Expansion Layer with stride 1""" 254 | 255 | def __init__(self, in_dim, out_dim, expand): 256 | super().__init__() 257 | 258 | expand_dim = expand * in_dim 259 | 260 | self.conv = nn.Sequential( 261 | ConvBNReLU(in_dim, in_dim, 3), 262 | DepthwiseConvBN(in_dim, expand_dim, 3), 263 | ConvBN(expand_dim, out_dim, 1)) 264 | self.relu = Activation("relu") 265 | 266 | def forward(self, x): 267 | return self.relu(self.conv(x) + x) 268 | 269 | 270 | class GatherAndExpansionLayer2(nn.Layer): 271 | """Gather And Expansion Layer with stride 2""" 272 | 273 | def __init__(self, in_dim, out_dim, expand): 274 | super().__init__() 275 | 276 | expand_dim = expand * in_dim 277 | 278 | self.branch_1 = nn.Sequential( 279 | ConvBNReLU(in_dim, in_dim, 3), 280 | DepthwiseConvBN( 281 | in_dim, expand_dim, 3, stride=2), 282 | DepthwiseConvBN(expand_dim, expand_dim, 3), 283 | ConvBN(expand_dim, out_dim, 1)) 284 | 285 | self.branch_2 = nn.Sequential( 286 | DepthwiseConvBN( 287 | in_dim, in_dim, 3, stride=2), 288 | ConvBN(in_dim, out_dim, 1)) 289 | 290 | self.relu = Activation("relu") 291 | 292 | def forward(self, x): 293 | return self.relu(self.branch_1(x) + self.branch_2(x)) 294 | 295 | 296 | class DetailBranch(nn.Layer): 297 | """The detail branch of BiSeNet, which has wide channels but shallow layers.""" 298 | 299 | def __init__(self, in_channels): 300 | super().__init__() 301 | 302 | C1, C2, C3 = in_channels 303 | 304 | self.convs = nn.Sequential( 305 | # stage 1 306 | ConvBNReLU( 307 | 3, C1, 3, stride=2), 308 | ConvBNReLU(C1, C1, 3), 309 | # stage 2 310 | ConvBNReLU( 311 | C1, C2, 3, stride=2), 312 | ConvBNReLU(C2, C2, 3), 313 | ConvBNReLU(C2, C2, 3), 314 | # stage 3 315 | ConvBNReLU( 316 | C2, C3, 3, stride=2), 317 | ConvBNReLU(C3, C3, 3), 318 | ConvBNReLU(C3, C3, 3), ) 319 | 320 | def forward(self, x): 321 | return self.convs(x) 322 | 323 | 324 | class SemanticBranch(nn.Layer): 325 | """The semantic branch of BiSeNet, which has narrow channels but deep layers.""" 326 | 327 | def __init__(self, in_channels): 328 | super().__init__() 329 | C1, C3, C4, C5 = in_channels 330 | 331 | self.stem = StemBlock(3, C1) 332 | 333 | self.stage3 = nn.Sequential( 334 | GatherAndExpansionLayer2(C1, C3, 6), 335 | GatherAndExpansionLayer1(C3, C3, 6)) 336 | 337 | self.stage4 = nn.Sequential( 338 | GatherAndExpansionLayer2(C3, C4, 6), 339 | GatherAndExpansionLayer1(C4, C4, 6)) 340 | 341 | self.stage5_4 = nn.Sequential( 342 | GatherAndExpansionLayer2(C4, C5, 6), 343 | GatherAndExpansionLayer1(C5, C5, 6), 344 | GatherAndExpansionLayer1(C5, C5, 6), 345 | GatherAndExpansionLayer1(C5, C5, 6)) 346 | 347 | self.ce = ContextEmbeddingBlock(C5, C5) 348 | 349 | def forward(self, x): 350 | stage2 = self.stem(x) 351 | stage3 = self.stage3(stage2) 352 | stage4 = self.stage4(stage3) 353 | stage5_4 = self.stage5_4(stage4) 354 | fm = self.ce(stage5_4) 355 | return stage2, stage3, stage4, stage5_4, fm 356 | 357 | 358 | class 
BGA(nn.Layer):
359 |     """The Bilateral Guided Aggregation Layer, used to fuse the semantic features and spatial features."""
360 | 
361 |     def __init__(self, out_dim, align_corners):
362 |         super().__init__()
363 | 
364 |         self.align_corners = align_corners
365 | 
366 |         self.db_branch_keep = nn.Sequential(
367 |             DepthwiseConvBN(out_dim, out_dim, 3),
368 |             nn.Conv2D(out_dim, out_dim, 1))
369 | 
370 |         self.db_branch_down = nn.Sequential(
371 |             ConvBN(
372 |                 out_dim, out_dim, 3, stride=2),
373 |             nn.AvgPool2D(
374 |                 kernel_size=3, stride=2, padding=1))
375 | 
376 |         self.sb_branch_keep = nn.Sequential(
377 |             DepthwiseConvBN(out_dim, out_dim, 3),
378 |             nn.Conv2D(out_dim, out_dim, 1),
379 |             Activation(act='sigmoid'))
380 | 
381 |         self.sb_branch_up = ConvBN(out_dim, out_dim, 3)
382 | 
383 |         self.conv = ConvBN(out_dim, out_dim, 3)
384 | 
385 |     def forward(self, dfm, sfm):
386 |         db_feat_keep = self.db_branch_keep(dfm)
387 |         db_feat_down = self.db_branch_down(dfm)
388 |         sb_feat_keep = self.sb_branch_keep(sfm)
389 | 
390 |         sb_feat_up = self.sb_branch_up(sfm)
391 |         sb_feat_up = F.interpolate(
392 |             sb_feat_up,
393 |             paddle.shape(db_feat_keep)[2:],
394 |             mode='bilinear',
395 |             align_corners=self.align_corners)
396 | 
397 |         sb_feat_up = F.sigmoid(sb_feat_up)
398 |         db_feat = db_feat_keep * sb_feat_up
399 | 
400 |         sb_feat = db_feat_down * sb_feat_keep
401 |         sb_feat = F.interpolate(
402 |             sb_feat,
403 |             paddle.shape(db_feat)[2:],
404 |             mode='bilinear',
405 |             align_corners=self.align_corners)
406 | 
407 |         return self.conv(db_feat + sb_feat)
408 | 
409 | 
410 | class SegHead(nn.Layer):
411 |     def __init__(self, in_dim, mid_dim, num_classes):
412 |         super().__init__()
413 | 
414 |         self.conv_3x3 = nn.Sequential(
415 |             ConvBNReLU(in_dim, mid_dim, 3), nn.Dropout(0.1))
416 | 
417 |         self.conv_1x1 = nn.Conv2D(mid_dim, num_classes, 1, 1)
418 | 
419 |     def forward(self, x):
420 |         conv1 = self.conv_3x3(x)
421 |         conv2 = self.conv_1x1(conv1)
422 |         return conv2
423 | 
--------------------------------------------------------------------------------
/ckpt_convert/ckpt_convert.py:
--------------------------------------------------------------------------------
 1 | '''
 2 | Convert the trained dynamic-graph models to static-graph models
 3 | '''
 4 | 
 5 | import paddle
 6 | from paddle.static import InputSpec
 7 | from BiSeNetV2 import BiSeNetV2
 8 | from nafa_archv1 import NAFNet
 9 | ########################################################################
10 | model = BiSeNetV2(num_classes=2)
11 | weights = paddle.load('../maskseg/output_enhance/best_model/model.pdparams')  # 0923, the 0.70293 model
12 | model.load_dict(weights)
13 | model.eval()
14 | 
15 | # step 2: define the InputSpec
16 | x_spec = InputSpec(shape=[1, 3, 1024, 1024], dtype='float32', name='x')
17 | 
18 | # step 3: call the jit.save API
19 | net = paddle.jit.save(model, path='../submit/stac/seg', input_spec=[x_spec])  # dynamic-to-static conversion
20 | ########################################################################
21 | model = NAFNet(img_channel=3, width=32, middle_blk_num=8,
22 |                enc_blk_nums=[1, 1, 2, 2], dec_blk_nums=[2, 2, 1, 1], decmask_blk_nums=[1, 1, 1, 1])
23 | weights = paddle.load("model_ema.pdparams")
24 | model.load_dict(weights)
25 | model.eval()
26 | 
27 | # step 2: define the InputSpec
28 | x_spec = InputSpec(shape=[1, 3, 480, 480], dtype='float32', name='x')
29 | 
30 | # step 3: call the jit.save API
31 | net = paddle.jit.save(model, path='../submit/stac/ema', input_spec=[x_spec])  # dynamic-to-static conversion
32 | # ########################################################################
33 | 
34 | # # Install the environment as required
35 | # !pip install onnx==1.10.1 onnxruntime-gpu==1.10 paddle2onnx
36 | #
37 | # 
!paddle2onnx --model_dir ./stac_restormer --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 13 --save_file result_restormer.onnx -------------------------------------------------------------------------------- /ckpt_convert/convert_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | import torch 4 | import torchvision.transforms as T 5 | import onnx 6 | import os 7 | from torchviz import make_dot 8 | import onnxruntime as ort 9 | import numpy as np 10 | import copy 11 | 12 | parser = argparse.ArgumentParser(description='Test inpainting') 13 | parser.add_argument("--image", type=str, 14 | default="examples/inpaint/case1.png", help="path to the image file") 15 | parser.add_argument("--mask", type=str, 16 | default="examples/inpaint/case1_mask.png", help="path to the mask file") 17 | parser.add_argument("--out", type=str, 18 | default="examples/inpaint/case1_out_test.png", help="path for the output file") 19 | parser.add_argument("--checkpoint", type=str, 20 | default="pretrained/states_tf_places2.pth", help="path to the checkpoint file") 21 | 22 | 23 | def main(): 24 | 25 | args = parser.parse_args() 26 | 27 | generator_state_dict = torch.load(args.checkpoint)['G'] 28 | 29 | if 'stage1.conv1.conv.weight' in generator_state_dict.keys(): 30 | from model.networks_test import Generator 31 | else: 32 | from model.networks_tf import Generator 33 | 34 | use_cuda_if_available = False 35 | device = torch.device('cuda' if torch.cuda.is_available() 36 | and use_cuda_if_available else 'cpu') 37 | 38 | # set up network 39 | generator = Generator(cnum_in=5, cnum=48, return_flow=False).to(device) 40 | 41 | generator_state_dict = torch.load(args.checkpoint)['G'] 42 | generator.load_state_dict(generator_state_dict, strict=True) 43 | generator.eval() 44 | 45 | print('Converting checkpoint to onnx format') 46 | ckpt_name = os.path.basename(args.checkpoint).split('.')[0] 47 | ckpt_dir = os.path.dirname(args.checkpoint) 48 | 49 | # load image and mask 50 | image = Image.open(args.image) 51 | mask = Image.open(args.mask) 52 | 53 | # prepare input 54 | image = T.ToTensor()(image).float().to(device) 55 | mask = T.ToTensor()(mask).float().to(device) 56 | 57 | _, h, w = image.shape 58 | grid = 8 59 | test_size = 512 60 | image = image[:3, :h//grid*grid, :w//grid*grid].unsqueeze(0) 61 | mask = mask[0:1, :h//grid*grid, :w//grid*grid].unsqueeze(0) 62 | image = torch.nn.functional.interpolate(image, size=(test_size,test_size)) 63 | mask = torch.nn.functional.interpolate(mask, size=(test_size,test_size)) 64 | print(f"Shape of image: {image.shape}") 65 | 66 | image = (image*2 - 1.) 
# map image values to [-1, 1] range
 67 |     mask = (mask > 0.5).float()  # 1.: masked  0.: unmasked
 68 | 
 69 |     image_masked = image * (1.-mask)  # mask image
 70 | 
 71 |     ones_x = torch.ones_like(image_masked)[:, 0:1, :, :].to(device)
 72 |     x = torch.cat([image_masked, ones_x, ones_x*mask],
 73 |                   dim=1)  # concatenate channels
 74 |     # x = x.repeat(8, 1, 1, 1)
 75 |     # mask = mask.repeat(8, 1, 1, 1)
 76 |     # with torch.inference_mode():
 77 |     #     _, x_stage2 = generator(x, mask)
 78 | 
 79 | 
 80 |     # complete image
 81 | 
 82 |     # image_inpainted = image * (1.-mask) + x_stage2 * mask
 83 |     # save inpainted image
 84 | 
 85 |     # img_out = ((image_inpainted[0].permute(1, 2, 0) + 1)*127.5)
 86 |     # img_out = img_out.to(device='cpu', dtype=torch.uint8)
 87 |     # img_out = Image.fromarray(img_out.numpy())
 88 |     # img_out.save(args.out)
 89 | 
 90 |     print(f"Saved output file at: {args.out}")
 91 | 
 92 |     torch.onnx.export(
 93 |         generator,
 94 |         (x, mask),
 95 |         os.path.join(ckpt_dir, ckpt_name+'.onnx'),
 96 |         verbose=True,
 97 |         export_params=True,
 98 |         input_names=['img', 'mask'],
 99 |         output_names=['output_stage1', 'output_stage2'],
100 |         do_constant_folding=True,
101 |         opset_version=11)
102 |         # dynamic_axes={'img':[0],
103 |         #               'mask':[0],
104 |         #               'output_stage1':[0],
105 |         #               'output_stage2':[0]})
106 | 
107 |     print('testing onnx model...')
108 |     model = onnx.load(os.path.join(ckpt_dir, ckpt_name+'.onnx'))
109 |     onnx.checker.check_model(model)
110 | 
111 |     session = ort.InferenceSession(os.path.join(ckpt_dir, ckpt_name+'.onnx'))
112 |     x = x.cpu().numpy().astype(np.float32)  # note: the input type must be np.float32
113 |     mask = mask.cpu().numpy().astype(np.float32)
114 |     out_stage1, out_stage2 = session.run(None, { 'img': x, 'mask': mask })
115 |     print('onnx model test finished')
116 |     print(out_stage1.shape)
117 |     print(out_stage2.shape)
118 | 
119 |     with torch.inference_mode():
120 |         x_stage1, x_stage2 = generator(torch.from_numpy(x).to(device).float(), torch.from_numpy(mask).to(device).float())
121 |     # np.testing.assert_allclose(x_stage2.cpu().numpy()[0], out_stage2[0], rtol=1e-03, atol=1e-05)
122 |     # print("Exported model has been tested with ONNXRuntime, and the result looks good!")
123 | 
124 |     ort_inpainted = image.cpu().numpy() * (1.-mask) + out_stage2 * mask
125 |     # save inpainted image
126 |     img_out = ((ort_inpainted[0].transpose(1, 2, 0) + 1)*127.5)
127 |     img_out = img_out.astype(np.uint8)
128 |     img_out = Image.fromarray(img_out)
129 |     img_out.save(args.out+'.onnx.png')
130 | 
131 |     print(f"Saved ort output file at: {args.out}")
132 | 
133 | 
134 | if __name__ == '__main__':
135 |     main()
--------------------------------------------------------------------------------
/ckpt_convert/ema.py:
--------------------------------------------------------------------------------
 1 | # -------------------------------------------------------------------------------------------------------------------
 2 | # Dependency imports
 3 | # -------------------------------------------------------------------------------------------------------------------
 4 | 
 5 | import os
 6 | import sys
 7 | import time
 8 | import math
 9 | import glob
10 | import copy
11 | import random
12 | 
13 | import cv2
14 | import numpy as np
15 | import pandas as pd
16 | 
17 | import paddle
18 | import paddle.nn as nn
19 | import paddle.nn.functional as F
20 | 
21 | from nafa_archv1 import NAFNet
22 | Net = NAFNet(img_channel=3, width=32, middle_blk_num=8,
23 |              enc_blk_nums=[1, 1, 2, 2], dec_blk_nums=[2, 2, 1, 1], decmask_blk_nums=[1, 1, 1, 1])
24 | 
25 | SEED = 42
26 | 
27 | paddle.seed(SEED)
28 | random.seed(SEED)
29 | np.random.seed(SEED)
30 | 
31 | if paddle.is_compiled_with_cuda():
32 |     paddle.set_device('gpu:0')
33 | else:
34 |     paddle.set_device('cpu')
35 | 
36 | # -------------------------------------------------------------------------------------------------------------------
37 | # EMA
38 | # -------------------------------------------------------------------------------------------------------------------
39 | 
40 | def EMA(model, ema_model_path, model_path_list):
41 | 
42 |     ema_model = copy.deepcopy(model)
43 |     ema_n = 0
44 | 
45 |     with paddle.no_grad():
46 |         for _ckpt in model_path_list:
47 |             model.load_dict(paddle.load(_ckpt))
48 |             tmp_para_dict = dict(model.named_parameters())
49 |             alpha = 1. / (ema_n + 1.)
50 |             for name, para in ema_model.named_parameters():
51 |                 new_para = tmp_para_dict[name].clone() * alpha + para.clone() * (1. - alpha)
52 |                 para.set_value(new_para.clone())
53 |             ema_n += 1
54 | 
55 |     paddle.save(ema_model.state_dict(), ema_model_path)
56 |     print('ema finished !!!')
57 | 
58 |     return ema_model
59 | 
60 | # ---------------------------------------------------------------------------------------------------------------------------------
61 | # Main function definition
62 | # ---------------------------------------------------------------------------------------------------------------------------------
63 | 
64 | def process():
65 | 
66 | 
67 |     ema_model_path = "./model_ema.pdparams"
68 | 
69 |     model_path_list = [f"../repaire/nafa_v1/99_0.7561.pdparams",  # best phase1
70 |                        f"../repaire/nafa_v1/97_0.7560.pdparams",  # best phase1
71 |                        f"../repaire/nafa_v1_psnr/50_0.7590.pdparams",  # best phase2
72 |                        f"../repaire/nafa_v1_psnr/51_0.7588.pdparams",  # best phase2
73 |                        ]
74 | 
75 |     model = EMA(Net, ema_model_path, model_path_list)
76 | 
77 |     # paddle.jit.save(model, path='./stac/ema', input_spec=[paddle.static.InputSpec(shape=[1, 3, patch_size, patch_size], dtype='float32')])
78 | 
79 | # ---------------------------------------------------------------------------------------------------------------------------------
80 | # Main function call
81 | # ---------------------------------------------------------------------------------------------------------------------------------
82 | 
83 | if __name__ == "__main__":
84 |     process()
85 | 
86 | 
--------------------------------------------------------------------------------
/ckpt_convert/nafa_archv1.py:
--------------------------------------------------------------------------------
 1 | import paddle
 2 | import paddle.nn as nn
 3 | import paddle.nn.functional as F
 4 | import numpy as np
 5 | from paddle.autograd import PyLayer
 6 | import numbers
 7 | 
 8 | 
 9 | class Identity(nn.Layer):
10 |     def __init__(self):
11 |         super().__init__()
12 | 
13 |     def forward(self, x):
14 |         return x
15 | 
16 | def to_3d(x):
17 |     b, c, h, w = x.shape
18 |     x = paddle.reshape(x, [b, c, h * w])
19 |     x = paddle.transpose(x, [0, 2, 1])
20 |     return x
21 |     # return rearrange(x, 'b c h w -> b (h w) c')
22 | 
23 | 
24 | def to_4d(x, h, w):
25 |     b, hw, c = x.shape
26 |     x = paddle.reshape(x, [b, h, w, c])
27 |     x = paddle.transpose(x, [0, 3, 1, 2])
28 |     return x
29 |     # return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
30 | 
31 | class BiasFree_LayerNorm(nn.Layer):
32 |     def __init__(self, normalized_shape):
33 |         super(BiasFree_LayerNorm, self).__init__()
34 |         if isinstance(normalized_shape, numbers.Integral):
35 |             normalized_shape = (normalized_shape,)
36 |             # normalized_shape = [normalized_shape]
37 | 
38 |         assert len(normalized_shape) == 1
39 | 
40 |         # self.weight = 
nn.Parameter(torch.ones(normalized_shape)) 41 | self.weight = paddle.create_parameter(shape=normalized_shape,dtype='float32', 42 | default_initializer=nn.initializer.Constant(1.0)) 43 | self.normalized_shape = normalized_shape 44 | 45 | def forward(self, x): 46 | sigma = x.var(-1, keepdim=True, unbiased=False) 47 | return x / paddle.sqrt(sigma+1e-5) * self.weight 48 | 49 | 50 | class WithBias_LayerNorm(nn.Layer): 51 | def __init__(self, normalized_shape): 52 | super(WithBias_LayerNorm, self).__init__() 53 | if isinstance(normalized_shape, numbers.Integral): 54 | normalized_shape = (normalized_shape,) 55 | # normalized_shape = normalized_shape.shape 56 | 57 | assert len(normalized_shape) == 1 58 | 59 | # self.weight = nn.Parameter(torch.ones(normalized_shape)) 60 | self.weight = paddle.create_parameter(shape=normalized_shape,dtype='float32', 61 | default_initializer=nn.initializer.Constant(1.0)) 62 | # self.bias = nn.Parameter(torch.zeros(normalized_shape)) 63 | self.bias = paddle.create_parameter(shape=normalized_shape,dtype='float32', 64 | default_initializer=nn.initializer.Constant(0.0)) 65 | self.normalized_shape = normalized_shape 66 | 67 | def forward(self, x): 68 | mu = x.mean(-1, keepdim=True) 69 | sigma = x.var(-1, keepdim=True, unbiased=False) 70 | return (x - mu) / paddle.sqrt(sigma+1e-6) * self.weight + self.bias 71 | 72 | 73 | class LayerNorm(nn.Layer): 74 | def __init__(self, dim, LayerNorm_type='WithBias'): 75 | super(LayerNorm, self).__init__() 76 | if LayerNorm_type =='BiasFree': 77 | self.body = BiasFree_LayerNorm(dim) 78 | else: 79 | self.body = WithBias_LayerNorm(dim) 80 | 81 | def forward(self, x): 82 | h, w = x.shape[-2:] 83 | return to_4d(self.body(to_3d(x)), h, w) 84 | 85 | 86 | class SimpleGate(nn.Layer): 87 | def forward(self, x): 88 | x1, x2 = x.chunk(2, axis=1) 89 | return x1 * x2 90 | 91 | 92 | class NAFBlock(nn.Layer): 93 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 94 | super().__init__() 95 | dw_channel = c * DW_Expand 96 | self.conv1 = nn.Conv2D(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, 97 | bias_attr=True) 98 | self.conv2 = nn.Conv2D(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, 99 | groups=dw_channel, 100 | bias_attr=True) 101 | self.conv3 = nn.Conv2D(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 102 | groups=1, bias_attr=True) 103 | 104 | # Simplified Channel Attention 105 | self.sca = nn.Sequential( 106 | nn.AdaptiveAvgPool2D(1), 107 | nn.Conv2D(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 108 | groups=1, bias_attr=True), 109 | ) 110 | 111 | # SimpleGate 112 | self.sg = SimpleGate() 113 | 114 | ffn_channel = FFN_Expand * c 115 | self.conv4 = nn.Conv2D(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, 116 | bias_attr=True) 117 | self.conv5 = nn.Conv2D(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 118 | groups=1, bias_attr=True) 119 | 120 | self.norm1 = LayerNorm(c) 121 | self.norm2 = LayerNorm(c) 122 | 123 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else Identity() 124 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else Identity() 125 | 126 | self.beta = paddle.create_parameter(shape=[1, c, 1, 1], 127 | dtype='float32', 128 | default_initializer=paddle.nn.initializer.Assign(paddle.zeros([1, c, 1, 1]))) 129 | self.gamma = paddle.create_parameter(shape=[1, c, 1, 1], 130 | dtype='float32', 131 | default_initializer=paddle.nn.initializer.Assign(paddle.zeros([1, c, 1, 1]))) 132 | 133 | def forward(self, inp): 134 | x = inp 135 | 136 | x = self.norm1(x) 137 | 138 | x = self.conv1(x) 139 | x = self.conv2(x) 140 | x = self.sg(x) 141 | x = x * self.sca(x) 142 | x = self.conv3(x) 143 | 144 | x = self.dropout1(x) 145 | 146 | y = inp + x * self.beta 147 | 148 | x = self.conv4(self.norm2(y)) 149 | x = self.sg(x) 150 | x = self.conv5(x) 151 | 152 | x = self.dropout2(x) 153 | 154 | return y + x * self.gamma 155 | 156 | 157 | def get_pad(in_, ksize, stride, atrous=1): 158 | out_ = np.ceil(float(in_) / stride) 159 | return int(((out_ - 1) * stride + atrous * (ksize - 1) + 1 - in_) / 2) 160 | 161 | 162 | class ConvWithActivation(nn.Layer): 163 | ''' 164 | SN convolution for spetral normalization conv 165 | ''' 166 | 167 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 168 | activation=nn.LeakyReLU(0.2)): 169 | super(ConvWithActivation, self).__init__() 170 | self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 171 | dilation=dilation, groups=groups, bias_attr=bias) 172 | self.conv2d = nn.utils.spectral_norm(self.conv2d) 173 | 174 | self.activation = activation 175 | for m in self.sublayers(): 176 | if isinstance(m, nn.Conv2D): 177 | n = m.weight.shape[0] * m.weight.shape[1] * m.weight.shape[2] 178 | v = np.random.normal(loc=0., scale=np.sqrt(2. / n), size=m.weight.shape).astype('float32') 179 | m.weight.set_value(v) 180 | 181 | def forward(self, input): 182 | x = self.conv2d(input) 183 | if self.activation is not None: 184 | return self.activation(x) 185 | else: 186 | return x 187 | 188 | 189 | class DeConvWithActivation(nn.Layer): 190 | ''' 191 | SN convolution for spetral normalization conv 192 | ''' 193 | 194 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 195 | output_padding=1, bias=True, activation=nn.LeakyReLU(0.2)): 196 | super(DeConvWithActivation, self).__init__() 197 | self.conv2d = nn.Conv2DTranspose(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 198 | padding=padding, dilation=dilation, groups=groups, 199 | output_padding=output_padding, bias_attr=bias) 200 | self.conv2d = nn.utils.spectral_norm(self.conv2d) 201 | self.activation = activation 202 | 203 | def forward(self, input): 204 | 205 | x = self.conv2d(input) 206 | 207 | if self.activation is not None: 208 | return self.activation(x) 209 | else: 210 | return x 211 | 212 | class Residual(nn.Layer): 213 | def __init__(self, in_channels, out_channels, same_shape=True, **kwargs): 214 | super(Residual, self).__init__() 215 | self.same_shape = same_shape 216 | strides = 1 if same_shape else 2 217 | self.conv1 = nn.Conv2D(in_channels, in_channels, kernel_size=3, padding=1, stride=strides) 218 | self.conv2 = nn.Conv2D(in_channels, out_channels, kernel_size=3, padding=1) 219 | if not same_shape: 220 | self.conv3 = nn.Conv2D(in_channels, out_channels, kernel_size=1, 221 | # self.conv3 = nn.Conv2D(channels, kernel_size=3, padding=1, 222 | stride=strides) 223 | self.batch_norm2d = nn.BatchNorm2D(out_channels) 224 | 225 | def forward(self, x): 226 | out = 
F.relu(self.conv1(x)) 227 | out = self.conv2(out) 228 | if not self.same_shape: 229 | x = self.conv3(x) 230 | out = self.batch_norm2d(out + x) 231 | # out = out + x 232 | return F.relu(out) 233 | 234 | 235 | class NAFNet(nn.Layer): 236 | 237 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], decm_blk_nums=[], decmask_blk_nums=[]): 238 | super().__init__() 239 | 240 | self.intro = nn.Conv2D(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, 241 | groups=1, 242 | bias_attr=True) 243 | self.ending = nn.Conv2D(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, 244 | groups=1, 245 | bias_attr=True) 246 | 247 | self.encoders = nn.LayerList() 248 | self.decoders = nn.LayerList() 249 | self.decoders_mask = nn.LayerList() 250 | self.middle_blks = nn.LayerList() 251 | self.ups = nn.LayerList() 252 | self.ups_mask = nn.LayerList() 253 | self.downs = nn.LayerList() 254 | 255 | chan = width 256 | 257 | for num in enc_blk_nums: 258 | self.encoders.append( 259 | nn.Sequential( 260 | *[NAFBlock(chan) for _ in range(num)] 261 | ) 262 | ) 263 | self.downs.append( 264 | nn.Conv2D(chan, 2 * chan, 2, 2) 265 | ) 266 | chan = chan * 2 267 | 268 | chan_mask = chan 269 | 270 | self.middle_blks = \ 271 | nn.Sequential( 272 | *[NAFBlock(chan) for _ in range(middle_blk_num)] 273 | ) 274 | 275 | for num in dec_blk_nums: 276 | self.ups.append( 277 | nn.Sequential( 278 | nn.Conv2D(chan, chan * 2, 1, bias_attr=False), 279 | nn.PixelShuffle(2) 280 | ) 281 | ) 282 | chan = chan // 2 283 | self.decoders.append( 284 | nn.Sequential( 285 | *[NAFBlock(chan) for _ in range(num)] 286 | ) 287 | ) 288 | 289 | self.padder_size = 2 ** len(self.encoders) 290 | 291 | ### mask branch decoder ### 292 | self.res_mask = Residual(chan_mask, chan_mask) 293 | for num in decmask_blk_nums: 294 | self.ups_mask.append( 295 | nn.Sequential( 296 | nn.Conv2D(chan_mask, chan_mask * 2, 1, bias_attr=False), 297 | nn.PixelShuffle(2) 298 | ) 299 | ) 300 | chan_mask = chan_mask // 2 301 | self.decoders_mask.append( 302 | nn.Sequential( 303 | *[NAFBlock(chan_mask) for _ in range(num)] 304 | ) 305 | ) 306 | self.mask_conv_d = nn.Conv2D(chan_mask, 1, kernel_size=1) # 3->1 307 | self.sig = nn.Sigmoid() 308 | 309 | 310 | 311 | def forward(self, inp): 312 | B, C, H, W = inp.shape 313 | inp = self.check_image_size(inp) 314 | 315 | x = self.intro(inp) 316 | inp_res = x 317 | 318 | encs = [] 319 | 320 | for encoder, down in zip(self.encoders, self.downs): 321 | x = encoder(x) 322 | encs.append(x) 323 | x = down(x) 324 | 325 | x_mask = self.res_mask(x) 326 | 327 | x = self.middle_blks(x) 328 | # x4 = x 329 | # x_mask = self.res_mask(x) 330 | 331 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 332 | x = up(x) 333 | x = x + enc_skip 334 | x = decoder(x) 335 | 336 | x = x + inp_res 337 | x = self.ending(x) 338 | 339 | # ### mask branch ### 340 | for decoderm, upm, enc_skip in zip(self.decoders_mask, self.ups_mask, encs[::-1]): 341 | x_mask = upm(x_mask) 342 | x_mask = x_mask + enc_skip 343 | x_mask = decoderm(x_mask) 344 | 345 | mm = self.mask_conv_d(x_mask) # 32 -> 3, h, w 346 | mm = self.sig(mm) 347 | 348 | return x[:, :, :H, :W], mm[:, :, :H, :W] 349 | 350 | def check_image_size(self, x): 351 | _, _, h, w = x.shape 352 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 353 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 354 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 355 | return x 
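    # Note on check_image_size above: with the default four encoder stages
    # (len(enc_blk_nums) == 4), padder_size is 2**4 = 16, so e.g. a 500x500
    # input is zero-padded to 512x512 before the encoder, and forward() crops
    # both outputs back to the original HxW via x[:, :, :H, :W].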
356 | 
357 | class Local_Base():
358 |     def convert(self, *args, train_size, **kwargs):
359 |         replace_layers(self, *args, train_size=train_size, **kwargs)
360 |         imgs = paddle.rand(train_size)
361 |         with paddle.no_grad():
362 |             self.forward(imgs)
363 | 
364 | class NAFNetLocal(Local_Base, NAFNet):
365 |     def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
366 |         Local_Base.__init__(self)
367 |         NAFNet.__init__(self, *args, **kwargs)
368 | 
369 |         N, C, H, W = train_size
370 |         base_size = (int(H * 1.5), int(W * 1.5))
371 | 
372 |         self.eval()
373 |         with paddle.no_grad():
374 |             self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
375 | 
376 | 
377 | class AvgPool2d(nn.Layer):
378 |     def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
379 |         super().__init__()
380 |         self.kernel_size = kernel_size
381 |         self.base_size = base_size
382 |         self.auto_pad = auto_pad
383 | 
384 |         # only used for fast implementation
385 |         self.fast_imp = fast_imp
386 |         self.rs = [5, 4, 3, 2, 1]
387 |         self.max_r1 = self.rs[0]
388 |         self.max_r2 = self.rs[0]
389 |         self.train_size = train_size
390 | 
391 |     def extra_repr(self) -> str:
392 |         return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
393 |             self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
394 |         )
395 | 
396 |     def forward(self, x):
397 |         if self.kernel_size is None and self.base_size:
398 |             train_size = self.train_size
399 |             if isinstance(self.base_size, int):
400 |                 self.base_size = (self.base_size, self.base_size)
401 |             self.kernel_size = list(self.base_size)
402 |             self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
403 |             self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
404 | 
405 |             # only used for fast implementation
406 |             self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
407 |             self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
408 | 
409 |         if self.kernel_size[0] >= x.shape[-2] and self.kernel_size[1] >= x.shape[-1]:
410 |             return F.adaptive_avg_pool2d(x, 1)
411 | 
412 |         if self.fast_imp:  # Non-equivalent implementation but faster
413 |             h, w = x.shape[2:]
414 |             if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
415 |                 out = F.adaptive_avg_pool2d(x, 1)
416 |             else:
417 |                 r1 = [r for r in self.rs if h % r == 0][0]
418 |                 r2 = [r for r in self.rs if w % r == 0][0]
419 |                 # reduction_constraint
420 |                 r1 = min(self.max_r1, r1)
421 |                 r2 = min(self.max_r2, r2)
422 |                 s = x[:, :, ::r1, ::r2].cumsum(axis=-1).cumsum(axis=-2)
423 |                 n, c, h, w = s.shape
424 |                 k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
425 |                 out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
426 |                 out = paddle.nn.functional.interpolate(out, scale_factor=(r1, r2))
427 |         else:
428 |             n, c, h, w = x.shape
429 |             s = x.cumsum(axis=-1).cumsum(axis=-2)  # 2-D summed-area table (paddle cumsum takes `axis`, not torch's `dim`)
430 |             s = paddle.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
431 |             k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
432 |             s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
433 |             out = s4 + s1 - s2 - s3
434 |             out = out / (k1 * k2)
435 | 
436 |         if self.auto_pad:
437 |             n, c, h, w = x.shape
438 |             _h, _w = out.shape[2:]
439 |             # print(x.shape, self.kernel_size)
440 |             pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
441 |             out = paddle.nn.functional.pad(out, pad2d, mode='replicate')
442 | 
443 |         return out
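# How the summed-area-table branch above works: after
# s = pad(cumsum(cumsum(x, axis=-1), axis=-2)), the sum of any k1 x k2 window
# of x is recovered from four corner reads, s4 + s1 - s2 - s3, so one pass
# yields every box average at once. Minimal usage sketch (illustrative, not
# part of the original repo; shapes assume the default auto_pad=True):
#   pool = AvgPool2d(base_size=(8, 8), train_size=(1, 3, 256, 256))
#   y = pool(paddle.rand([1, 3, 256, 256]))  # kernel_size resolves to [8, 8]
#   print(y.shape)                           # [1, 3, 256, 256] after auto_pad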
444 | 
445 | 
446 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
447 |     for n, m in model.named_children():
448 |         if len(list(m.children())) > 0:
449 |             ## compound Layer, go inside it
450 |             replace_layers(m, base_size, train_size, fast_imp, **kwargs)
451 | 
452 |         if isinstance(m, nn.AdaptiveAvgPool2D):
453 |             pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
454 |             assert m.output_size == 1
455 |             setattr(model, n, pool)
456 | 
457 | 
458 | '''
459 | ref.
460 | @article{chu2021tlsc,
461 |     title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
462 |     author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
463 |     journal={arXiv preprint arXiv:2112.04491},
464 |     year={2021}
465 | }
466 | '''
467 | 
468 | 
469 | if __name__ == '__main__':
470 |     import os
471 |     import numpy as np
472 |     net = NAFNet(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 18], dec_blk_nums=[1, 1, 1, 1], decmask_blk_nums=[1, 1, 1, 1])
473 |     x = np.random.randn(*[1, 3, 128, 128])
474 |     x = x.astype('float32')
475 |     x = paddle.to_tensor(x)
476 |     out, mm = net(x)
477 |     print(out.shape, mm.shape)
--------------------------------------------------------------------------------
/ckpt_convert/readme:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/ckpt_convert/说明文档.txt:
--------------------------------------------------------------------------------
1 | 1. repaire model fusion
2 |    From repair/ckpts, pick the two best nafa_v1 checkpoints and the two best nafa_v1_psnr checkpoints, and change lines 69-73 of ema.py to the corresponding paths.
3 |    Run python ema.py to fuse the models.
4 | 
5 | 2. Model conversion (ckpt_convert.py converts the trained dynamic-graph models to static-graph models)
--------------------------------------------------------------------------------
/compute_mask.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import os, glob, shutil
 3 | import numpy as np
 4 | path = './dataset/task2/dehw_train_dataset'
 5 | save_path = './dataset/task2/dehw_train_dataset/mask_331_25/'
 6 | gts = sorted(glob.glob(path + '/gts/*'))
 7 | images = sorted(glob.glob(path + '/images/*'))
 8 | os.makedirs(save_path, exist_ok=True)
 9 | # print(gts, images)
10 | for gt_f, im_f in zip(gts, images):
11 |     print(gt_f)
12 |     gt = cv2.imread(gt_f)
13 |     im = cv2.imread(im_f)
14 |     # mask = np.where(abs(gt.astype(np.float32) - im.astype(np.float32)) > 40, 0, 1)
15 |     kernel = np.ones((3, 3), np.uint8)
16 |     # mask = cv2.erode(np.uint8(mask), kernel, iterations=2)
17 |     threshold = 25
18 |     diff_image = np.abs(im.astype(np.float32) - gt.astype(np.float32))
19 |     mean_image = np.mean(diff_image, axis=-1)
20 |     mask = np.greater(mean_image, threshold).astype(np.uint8)
21 |     mask = (1 - mask) * 255
22 |     mask = cv2.erode(np.uint8(mask), kernel, iterations=1)
23 |     cv2.imwrite(save_path + os.path.basename(gt_f), np.uint8(mask))
24 |     # print(gt[622,513], im[622,513], mask[622,513])
25 |     # kernel = np.ones((2,2),np.uint8)
26 |     # erosion = cv2.morphologyEx(np.uint8(mask), cv2.MORPH_OPEN, kernel)
27 |     # # break
28 |     # cv2.imshow('gt', gt)
29 |     # cv2.imshow('im', im)
30 |     # cv2.imshow('mask', np.uint8(mask))
31 |     # # # cv2.imshow('erosion', np.uint8(erosion*255))
32 |     # cv2.waitKey(0)
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import paddle
2 | import numpy as np
3 | import cv2
4 | from os import listdir, walk
5 | from os.path import join
6 | import random
  7 | from PIL import Image
  8 | 
  9 | from paddle.vision.transforms import Compose, RandomCrop, ToTensor, CenterCrop
 10 | from paddle.vision.transforms import functional as F
 11 | 
 12 | 
 13 | def random_horizontal_flip(imgs):
 14 |     if random.random() < 0.3:
 15 |         for i in range(len(imgs)):
 16 |             imgs[i] = imgs[i].transpose(Image.FLIP_LEFT_RIGHT)
 17 |     return imgs
 18 | 
 19 | def random_rotate(imgs):
 20 |     if random.random() < 0.3:
 21 |         max_angle = 10
 22 |         angle = random.random() * 2 * max_angle - max_angle
 23 |         # print(angle)
 24 |         for i in range(len(imgs)):
 25 |             img = np.array(imgs[i])
 26 |             h, w = img.shape[:2]
 27 |             rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
 28 |             img_rotation = cv2.warpAffine(img, rotation_matrix, (w, h))
 29 |             imgs[i] = Image.fromarray(img_rotation)
 30 |     return imgs
 31 | 
 32 | def CheckImageFile(filename):
 33 |     return any(filename.endswith(extention) for extention in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'])
 34 | 
 35 | def ImageTransform():
 36 |     return Compose([
 37 |         # CenterCrop(size=loadSize),
 38 |         ToTensor(),
 39 |     ])
 40 | def ImageTransformTest(loadSize):
 41 |     return Compose([
 42 |         CenterCrop(size=loadSize),
 43 |         ToTensor(),
 44 |     ])
 45 | 
 46 | class PairedRandomCrop(RandomCrop):
 47 |     def __init__(self, size, keys=None):
 48 |         super().__init__(size, keys=keys)
 49 | 
 50 |         if isinstance(size, int):
 51 |             self.size = (size, size)
 52 |         else:
 53 |             self.size = size
 54 | 
 55 |     def _get_params(self, inputs):
 56 |         image = inputs[self.keys.index('image')]
 57 |         params = {}
 58 |         params['crop_prams'] = self._get_param(image, self.size)
 59 |         return params
 60 | 
 61 |     def _apply_image(self, img):
 62 |         i, j, h, w = self.params['crop_prams']
 63 |         return F.crop(img, i, j, h, w)
 64 | 
 65 | class ErasingData(paddle.io.Dataset):
 66 |     def __init__(self, dataRoot, loadSize, training=True, mask_dir='mask'):
 67 |         super(ErasingData, self).__init__()
 68 |         self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \
 69 |                            for files in filenames if CheckImageFile(files)]
 70 |         self.loadSize = loadSize
 71 |         self.ImgTrans = ImageTransform()
 72 |         self.training = training
 73 |         self.mask_dir = mask_dir
 74 |         self.RandomCropparam = RandomCrop(self.loadSize)
 75 | 
 76 |     def __getitem__(self, index):
 77 |         img = Image.open(self.imageFiles[index])
 78 |         # print(self.imageFiles[index].replace('images', self.mask_dir).replace('jpg', 'png'))
 79 |         mask = Image.open(self.imageFiles[index].replace('images', self.mask_dir).replace('jpg', 'png'))
 80 |         gt = Image.open(self.imageFiles[index].replace('images', 'gts').replace('jpg', 'png'))
 81 |         # import pdb;pdb.set_trace()
 82 |         if self.training:
 83 |             ### for data augmentation
 84 |             all_input = [img, mask, gt]
 85 |             all_input = random_horizontal_flip(all_input)
 86 |             all_input = random_rotate(all_input)
 87 |             img = all_input[0]
 88 |             mask = all_input[1]
 89 |             gt = all_input[2]
 90 |         ### end of data augmentation
 91 |         # param = RandomCrop.get_params(img.convert('RGB'), self.loadSize)
 92 |         param = self.RandomCropparam._get_param(img.convert('RGB'), self.loadSize)
 93 |         # print(param)
 94 |         inputImage = F.crop(img.convert('RGB'), *param)
 95 |         maskIn = F.crop(mask.convert('RGB'), *param)
 96 |         groundTruth = F.crop(gt.convert('RGB'), *param)
 97 |         del img
 98 |         del gt
 99 |         del mask
100 | 
101 |         inputImage = self.ImgTrans(inputImage)
102 |         maskIn = self.ImgTrans(maskIn)
103 |         groundTruth = self.ImgTrans(groundTruth)
104 |         path = self.imageFiles[index].split('/')[-1]
105 |         # import pdb;pdb.set_trace()
106 | 
107 |         return inputImage, groundTruth, maskIn, path
108 | 
109 |     def __len__(self):
110 |         return len(self.imageFiles)
111 | 
112 | class devdata(paddle.io.Dataset):
113 |     def __init__(self, dataRoot, gtRoot, loadSize=512):
114 |         super(devdata, self).__init__()
115 |         self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \
116 |                            for files in filenames if CheckImageFile(files)]
117 |         self.gtFiles = [join(gtRootK, files) for gtRootK, dn, filenames in walk(gtRoot) \
118 |                         for files in filenames if CheckImageFile(files)]
119 |         self.loadSize = loadSize
120 |         self.ImgTrans = ImageTransform()
121 |         # self.ImgTrans = ImageTransformTest(loadSize)
122 | 
123 |     def __getitem__(self, index):
124 |         img = Image.open(self.imageFiles[index])
125 |         gt = Image.open(self.gtFiles[index])
126 |         # print(self.imageFiles[index], self.gtFiles[index])
127 |         # import pdb;pdb.set_trace()
128 |         inputImage = self.ImgTrans(img.convert('RGB'))
129 | 
130 |         groundTruth = self.ImgTrans(gt.convert('RGB'))
131 |         path = self.imageFiles[index].split('/')[-1]
132 | 
133 |         return inputImage, groundTruth, path
134 | 
135 |     def __len__(self):
136 |         return len(self.imageFiles)
--------------------------------------------------------------------------------
/gauss.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | """Module providing functionality surrounding the gaussian function.
 3 | """
 4 | SVN_REVISION = '$LastChangedRevision: 16541 $'
 5 | 
 6 | import sys
 7 | import numpy
 8 | 
 9 | def gaussian2(size, sigma):
10 |     """Returns a normalized circularly symmetric 2D gauss kernel array
11 | 
12 |     f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where
13 | 
14 |     A = 1/(2*pi*sigma^2)
15 | 
16 |     as defined by Wolfram MathWorld
17 |     http://mathworld.wolfram.com/GaussianFunction.html
18 |     """
19 |     A = 1/(2.0*numpy.pi*sigma**2)
20 |     x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
21 |     g = A*numpy.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2))))
22 |     return g
23 | 
24 | def fspecial_gauss(size, sigma):
25 |     """Function to mimic the 'fspecial' gaussian MATLAB function
26 |     """
27 |     x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
28 |     g = numpy.exp(-((x**2 + y**2)/(2.0*sigma**2)))
29 |     return g/g.sum()
30 | 
31 | def main():
32 |     """Show simple use cases for functionality provided by this module."""
33 |     from mpl_toolkits.mplot3d.axes3d import Axes3D
34 |     import pylab
35 |     argv = sys.argv
36 |     if len(argv) != 3:
37 |         print('usage: python -m pim.sp.gauss size sigma', file=sys.stderr)
38 |         sys.exit(2)
39 |     size = int(argv[1])
40 |     sigma = float(argv[2])
41 |     x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
42 | 
43 |     fig = pylab.figure()
44 |     fig.suptitle('Some 2-D Gauss Functions')
45 |     ax = fig.add_subplot(2, 1, 1, projection='3d')
46 |     ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1,
47 |                     linewidth=0, antialiased=False, cmap=pylab.jet())
48 |     ax = fig.add_subplot(2, 1, 2, projection='3d')
49 |     ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1,
50 |                     linewidth=0, antialiased=False, cmap=pylab.jet())
51 |     pylab.show()
52 |     return 0
53 | 
54 | if __name__ == '__main__':
55 |     sys.exit(main())
--------------------------------------------------------------------------------
/loss/Loss.py:
--------------------------------------------------------------------------------
1 | import paddle
2 | from paddle import nn
3 | import paddle.nn.functional as F
4 | # from tensorboardX import SummaryWriter
5 | 
6 | from PIL import Image
 7 | import numpy as np
 8 | 
 9 | def gram_matrix(feat):
10 |     # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
11 |     (b, ch, h, w) = feat.shape
12 |     feat = feat.reshape([b, ch, h * w])
13 |     feat_t = feat.transpose([0, 2, 1])
14 |     gram = paddle.bmm(feat, feat_t) / (ch * h * w)
15 |     return gram
16 | 
17 | def visual(image):
18 |     im = image.transpose([0, 2, 3, 1]).detach().cpu().numpy()  # NCHW -> NHWC
19 |     Image.fromarray(im[0].astype(np.uint8)).show()
20 | 
21 | def dice_loss(input, target):
22 |     input = F.sigmoid(input)
23 | 
24 |     input = input.reshape([input.shape[0], -1])
25 |     target = target.reshape([target.shape[0], -1])
26 | 
27 |     a = paddle.sum(input * target, 1)
28 |     b = paddle.sum(input * input, 1) + 0.001
29 |     c = paddle.sum(target * target, 1) + 0.001
30 |     d = (2 * a) / (b + c)
31 |     dice_loss = paddle.mean(d)
32 |     return 1 - dice_loss
33 | 
34 | def bce_loss(input, target):
35 |     input = F.sigmoid(input)
36 | 
37 |     input = input.reshape([input.shape[0], -1])
38 |     target = target.reshape([target.shape[0], -1])
39 | 
40 |     bce = paddle.nn.BCELoss()
41 | 
42 |     return bce(input, target)
43 | 
44 | class LossWithGAN_STE(nn.Layer):
45 |     # def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):
46 |     def __init__(self, Lamda, lr, betasInit=(0.5, 0.9)):
47 |         super(LossWithGAN_STE, self).__init__()
48 |         self.l1 = nn.L1Loss()
49 |         # self.extractor = extractor
50 |         # self.discriminator = Discriminator_STE(3)  ## local_global sn patch gan
51 |         # self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)
52 |         # self.cudaAvailable = torch.cuda.is_available()
53 |         # self.numOfGPUs = torch.cuda.device_count()
54 |         # self.lamda = Lamda
55 |         # self.writer = SummaryWriter(logPath)
56 | 
57 |     def forward(self, input, mask, x_o1, x_o2, x_o3, output, mm, gt, count, epoch):
58 |         # self.discriminator.zero_grad()
59 |         # D_real = self.discriminator(gt, mask)
60 |         # D_real = D_real.mean().sum() * -1
61 |         # D_fake = self.discriminator(output, mask)
62 |         # D_fake = D_fake.mean().sum() * 1
63 |         # D_loss = torch.mean(F.relu(1.+D_real)) + torch.mean(F.relu(1.+D_fake))  # SN-patch-GAN loss
64 |         # D_fake = -torch.mean(D_fake)  # SN-Patch-GAN loss
65 | 
66 |         # self.D_optimizer.zero_grad()
67 |         # D_loss.backward(retain_graph=True)
68 |         # self.D_optimizer.step()
69 | 
70 |         # self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count)
71 | 
72 |         # output_comp = mask * input + (1 - mask) * output
73 |         holeLoss = self.l1((1 - mask) * output, (1 - mask) * gt)
74 |         validAreaLoss = self.l1(mask * output, mask * gt)
75 |         mask_loss = bce_loss(mm, 1 - mask)
76 | 
77 |         # GLoss = msrloss + holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake + 1*mask_loss
78 |         GLoss = mask_loss + holeLoss + validAreaLoss
79 |         # self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)
80 |         return GLoss.sum()
--------------------------------------------------------------------------------
/loss/PSNRLoss.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import numpy as np
 3 | import paddle
 4 | import paddle.nn as nn
 5 | 
 6 | class PSNRLoss(nn.Layer):
 7 | 
 8 |     def __init__(self, loss_weight=1.0, reduction='mean'):
 9 |         super(PSNRLoss, self).__init__()
10 |         assert reduction == 'mean'
11 |         self.loss_weight = loss_weight
12 |         self.scale = 10 / np.log(10)
13 | 
14 |     def forward(self, pred, target):
15 |         assert len(pred.shape) == 4
16 | 
17 |         return self.loss_weight * self.scale * paddle.log(((pred - target) ** 2).mean(axis=(1, 2, 3)) + 1e-8).mean()
18 | 
19 | # import torch
20 | # import torch.nn as nn
21 | #
22 | # class PSNRLoss1(nn.Module):
23 | #
24 | #     def __init__(self, loss_weight=1.0, reduction='mean'):
25 | #         super(PSNRLoss1, self).__init__()
26 | #         assert reduction == 'mean'
27 | #         self.loss_weight = loss_weight
28 | #         self.scale = 10 / np.log(10)
29 | #
30 | #     def forward(self, pred, target):
31 | #         assert len(pred.size()) == 4
32 | #
33 | #         return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     import cv2
38 |     img1 = cv2.imread('../result/results/bg_image_00013_0001.jpg')
39 |     img1 = cv2.resize(img1, (256, 256))
40 |     img1 = img1[np.newaxis, :, :, :]
41 | 
42 |     img1_paddle = paddle.to_tensor(np.transpose(img1, (0, 3, 1, 2))) / 255.
43 | 
44 |     img2 = cv2.imread('../result/results/bg_image_00016_0016.jpg')
45 |     img2 = cv2.resize(img2, (256, 256))
46 | 
47 |     img2 = img2[np.newaxis, :, :, :]
48 |     img2_paddle = paddle.to_tensor(np.transpose(img2, (0, 3, 1, 2))) / 255.
49 | 
50 |     # The torch cross-check below needs the commented-out PSNRLoss1 above:
51 |     # import torch
52 |     # img1_torch = torch.from_numpy(np.transpose(img1, (0, 3, 1, 2))) / 255.
53 |     # img2_torch = torch.from_numpy(np.transpose(img2, (0, 3, 1, 2))) / 255.
54 |     # psnr1 = PSNRLoss1()
55 |     # print(psnr1(img2_torch, img1_torch))
56 | 
57 |     psnr = PSNRLoss()
58 |     print(psnr(img2_paddle, img1_paddle))
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
 1 | from typing import ClassVar
 2 | import paddle
 3 | import paddle.nn as nn
 4 | 
 5 | from models.Model import vgg19
 6 | 
 7 | class pre_network(nn.Layer):
 8 |     """Reference:
 9 |     https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/3
10 |     """
11 | 
12 |     def __init__(self, pretrained: str = None):
13 |         super(pre_network, self).__init__()
14 |         self.vgg_layers = vgg19(pretrained=pretrained).features
15 |         self.layer_name_mapping = {
16 |             '3': 'relu1',
17 |             '8': 'relu2',
18 |             '13': 'relu3',
19 |             # '22': 'relu4',
20 |             # '31': 'relu5',
21 |         }
22 | 
23 |     def forward(self, x):
24 |         output = {}
25 | 
26 |         for name, module in self.vgg_layers._sub_layers.items():
27 |             x = module(x)
28 |             if name in self.layer_name_mapping:
29 |                 output[self.layer_name_mapping[name]] = x
30 |         return output
--------------------------------------------------------------------------------
/models/Model.py:
--------------------------------------------------------------------------------
 1 | import paddle
 2 | import paddle.nn as nn
 3 | from typing import Union, List, Dict, Any, cast
 4 | from utils import load_pretrained_model
 5 | 
 6 | __all__ = [
 7 |     'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
 8 |     'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19'
 9 | ]
10 | 
11 | model_urls = {
12 |     'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
13 |     'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
14 |     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
15 |     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
16 |     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
17 |     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
18 |     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
-------------------------------------------------------------------------------- /models/Model.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.nn as nn 3 | from typing import Union, List, Dict, Any, cast 4 | from utils import load_pretrained_model 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 8 | 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19' 9 | ] 10 | 11 | # torchvision checkpoint URLs; convert these to Paddle weights before loading 12 | # (see ckpt_convert.py) 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Layer): 26 | def __init__( 27 | self, 28 | features: nn.Layer, 29 | num_classes: int = 1000, 30 | init_weights: bool = True 31 | ) -> None: 32 | super(VGG, self).__init__() 33 | self.features = features 34 | self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) 35 | self.classifier = nn.Sequential( 36 | nn.Linear(512 * 7 * 7, 4096), 37 | nn.ReLU(), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(), 41 | nn.Dropout(), 42 | nn.Linear(4096, num_classes), 43 | ) 44 | if init_weights: 45 | self._initialize_weights() 46 | 47 | def forward(self, x): 48 | x = self.features(x) 49 | x = self.avgpool(x) 50 | x = paddle.flatten(x, 1) 51 | x = self.classifier(x) 52 | return x 53 | 54 | def _initialize_weights(self) -> None: 55 | # weights come from converted checkpoints here, so random init is skipped 56 | pass 57 | 58 | 59 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 60 | layers: List[nn.Layer] = [] 61 | in_channels = 3 62 | for v in cfg: 63 | if v == 'M': 64 | layers += [nn.MaxPool2D(kernel_size=2, stride=2)] 65 | else: 66 | v = cast(int, v) 67 | conv2d = nn.Conv2D(in_channels, v, kernel_size=3, padding=1) 68 | if batch_norm: 69 | layers += [conv2d, nn.BatchNorm2D(v), nn.ReLU()] 70 | else: 71 | layers += [conv2d, nn.ReLU()] 72 | in_channels = v 73 | return nn.Sequential(*layers) 74 | 75 | 76 | cfgs: Dict[str, List[Union[str, int]]] = { 77 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 78 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 79 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 80 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 81 | } 82 | 83 | 84 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: str, progress: bool, **kwargs: Any) -> VGG: 85 | if pretrained: 86 | kwargs['init_weights'] = False 87 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 88 | if pretrained: # was `is not None`, which wrongly treated pretrained=False as a path 89 | load_pretrained_model(model, pretrained) 90 | return model 91 | 92 | 93 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 94 | r"""VGG 11-layer model (configuration "A") from 95 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 96 | 97 | Args: 98 | pretrained (bool): If True, returns a model pre-trained on ImageNet 99 | progress (bool): If True, displays a progress bar of the download to stderr 100 | """ 101 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 102 | 103 | 104 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 105 | r"""VGG 11-layer model (configuration "A") with batch normalization 106 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 107 | 108 | Args: 109 | pretrained (bool): If True, returns a model pre-trained on ImageNet 110 | progress (bool): If True, displays a progress bar of the download to stderr 111 | """ 112 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 113 | 
114 | 115 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 116 | r"""VGG 13-layer model (configuration "B") 117 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 118 | 119 | Args: 120 | pretrained (bool): If True, returns a model pre-trained on ImageNet 121 | progress (bool): If True, displays a progress bar of the download to stderr 122 | """ 123 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 124 | 125 | 126 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 127 | r"""VGG 13-layer model (configuration "B") with batch normalization 128 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 129 | 130 | Args: 131 | pretrained (bool): If True, returns a model pre-trained on ImageNet 132 | progress (bool): If True, displays a progress bar of the download to stderr 133 | """ 134 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 135 | 136 | 137 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 138 | r"""VGG 16-layer model (configuration "D") 139 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 140 | 141 | Args: 142 | pretrained (bool): If True, returns a model pre-trained on ImageNet 143 | progress (bool): If True, displays a progress bar of the download to stderr 144 | """ 145 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 146 | 147 | 148 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 149 | r"""VGG 16-layer model (configuration "D") with batch normalization 150 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 151 | 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | progress (bool): If True, displays a progress bar of the download to stderr 155 | """ 156 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 157 | 158 | 159 | def vgg19(pretrained: str = None, progress: bool = True, **kwargs: Any) -> VGG: 160 | r"""VGG 19-layer model (configuration "E") 161 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 162 | 163 | Args: 164 | pretrained (str): path to converted Paddle VGG19 weights; None skips loading 165 | progress (bool): If True, displays a progress bar of the download to stderr 166 | """ 167 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 168 | 169 | 170 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 171 | r"""VGG 19-layer model (configuration 'E') with batch normalization 172 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_. 
171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | progress (bool): If True, displays a progress bar of the download to stderr 175 | """ 176 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 177 | 178 | 179 | # import paddle 180 | # import paddle.nn as nn 181 | # from paddle.vision import models 182 | # #VGG16 feature extract 183 | # class VGG16FeatureExtractor(nn.Layer): 184 | # def __init__(self): 185 | # super(VGG16FeatureExtractor, self).__init__() 186 | # vgg16 = models.vgg16(pretrained=True) 187 | # # vgg16.load_state_dict(torch.load('./vgg16-397923af.pth')) 188 | # self.enc_1 = nn.Sequential(*vgg16.features[:5]) 189 | # self.enc_2 = nn.Sequential(*vgg16.features[5:10]) 190 | # self.enc_3 = nn.Sequential(*vgg16.features[10:17]) 191 | # 192 | # # fix the encoder 193 | # for i in range(3): 194 | # for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): 195 | # param.requires_grad = False 196 | # 197 | # def forward(self, image): 198 | # results = [image] 199 | # for i in range(3): 200 | # func = getattr(self, 'enc_{:d}'.format(i + 1)) 201 | # results.append(func(results[-1])) 202 | # return results[1:] 203 | 204 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.nn as nn 3 | from .networks import ConvWithActivation, get_pad 4 | 5 | ##discriminator 6 | class Discriminator_STE(nn.Layer): 7 | def __init__(self, inputChannels): 8 | super(Discriminator_STE, self).__init__() 9 | cnum =32 10 | self.globalDis = nn.Sequential( 11 | ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), 12 | ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), 13 | ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), 14 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), 15 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), 16 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), 17 | ) 18 | 19 | self.localDis = nn.Sequential( 20 | ConvWithActivation(3, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), 21 | ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), 22 | ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), 23 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), 24 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), 25 | ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), 26 | ) 27 | 28 | self.fusion = nn.Sequential( 29 | nn.Conv2D(512, 1, kernel_size=4), 30 | nn.Sigmoid() 31 | ) 32 | 33 | def forward(self, input, masks): 34 | global_feat = self.globalDis(input) 35 | local_feat = self.localDis(input * (1 - masks)) 36 | 37 | concat_feat = paddle.concat((global_feat, local_feat), 1) 38 | concat_feat = self.fusion(concat_feat) 39 | concat_feat = concat_feat.reshape([input.shape[0], -1]) 40 | 41 | return concat_feat 42 | -------------------------------------------------------------------------------- /models/idr.py: -------------------------------------------------------------------------------- 1 | # from x2paddle import torch2paddle 2 | import paddle 3 | import paddle.nn as nn 4 | from models.non_local import NonLocalBlock 5 | import paddle.nn.functional as F 6 | 7 | class AIDR(nn.Layer): 8 | 9 | def __init__(self, in_channels=3, out_channels=3, num_c=48): 10 | super(AIDR, 
self).__init__() 11 | self.en_block1 = nn.Sequential( 12 | nn.Conv2D(in_channels, num_c, 3, padding=1, bias_attr=True), 13 | nn.LeakyReLU(negative_slope=0.1), 14 | nn.Conv2D(num_c, num_c, 3, padding=1, bias_attr=True), 15 | nn.LeakyReLU(negative_slope=0.1), 16 | nn.MaxPool2D(2)) 17 | 18 | self.en_block2 = nn.Sequential( 19 | nn.Conv2D(num_c, num_c, 3, padding=1,bias_attr=True), 20 | nn.LeakyReLU(negative_slope=0.1), 21 | nn.MaxPool2D(2)) 22 | 23 | self.en_block3 = nn.Sequential( 24 | nn.Conv2D(num_c, num_c, 3, padding=1, bias_attr=True), 25 | nn.LeakyReLU(negative_slope=0.1), 26 | nn.MaxPool2D(2)) 27 | 28 | self.en_block4 = nn.Sequential( 29 | nn.Conv2D(num_c, num_c, 3, padding=1, bias_attr=True), 30 | nn.LeakyReLU(negative_slope=0.1), 31 | nn.MaxPool2D(2)) 32 | 33 | self.en_block5 = nn.Sequential( 34 | nn.Conv2D(num_c, num_c, 3, padding=1, bias_attr=True), 35 | nn.LeakyReLU(negative_slope=0.1), 36 | NonLocalBlock(num_c), 37 | nn.LeakyReLU(negative_slope=0.1), 38 | nn.MaxPool2D(2), 39 | nn.Conv2D(num_c, num_c, 3, padding=1, bias_attr=True), 40 | nn.LeakyReLU(negative_slope=0.1), 41 | NonLocalBlock(num_c), 42 | nn.LeakyReLU(negative_slope=0.1), 43 | nn.Upsample(scale_factor=2, mode='nearest')) 44 | 45 | self.de_block1 = nn.Sequential( 46 | nn.Conv2D(num_c*2 + 256, num_c*2, 3, padding=1, bias_attr=True), 47 | nn.LeakyReLU(negative_slope=0.1), 48 | NonLocalBlock(num_c*2), 49 | nn.LeakyReLU(negative_slope=0.1), 50 | nn.Conv2D(num_c*2, num_c*2, 3, padding=1, bias_attr=True), 51 | nn.LeakyReLU(negative_slope=0.1), 52 | nn.Upsample(scale_factor=2,mode='nearest')) 53 | 54 | self.de_block2 = nn.Sequential( 55 | nn.Conv2D(num_c*3 + 128, num_c*2, 3, padding=1,bias_attr=True), 56 | nn.LeakyReLU(negative_slope=0.1), 57 | nn.Conv2D(num_c*2, num_c*2, 3, padding=1, bias_attr=True), 58 | nn.LeakyReLU(negative_slope=0.1), 59 | nn.Upsample(scale_factor=2,mode='nearest')) 60 | 61 | self.de_block3 = nn.Sequential( 62 | nn.Conv2D(num_c*3 + 64, num_c*2, 3, padding=1,bias_attr=True), 63 | nn.LeakyReLU(negative_slope=0.1), 64 | nn.Conv2D(num_c*2, num_c*2, 3, padding=1, bias_attr=True), 65 | nn.LeakyReLU(negative_slope=0.1), 66 | nn.Upsample(scale_factor=2, mode='nearest')) 67 | 68 | self.de_block4 = nn.Sequential( 69 | nn.Conv2D(num_c*3, num_c*2, 3, padding=1,bias_attr=True), 70 | nn.LeakyReLU(negative_slope=0.1), 71 | nn.Conv2D(num_c*2, num_c*2, 3, padding=1, bias_attr=True), 72 | nn.LeakyReLU(negative_slope=0.1), 73 | nn.Upsample(scale_factor=2,mode='nearest')) 74 | 75 | self.de_block5 = nn.Sequential( 76 | nn.Conv2D(num_c*2 + in_channels, 64, 3,padding=1, bias_attr=True), 77 | nn.LeakyReLU(negative_slope=0.1), 78 | nn.Conv2D(64, 32, 3, padding=1, bias_attr=True), 79 | nn.LeakyReLU(negative_slope=0.1), 80 | nn.Conv2D(32, out_channels, 3, padding=1, bias_attr=True)) 81 | 82 | def forward(self, x, con_x2, con_x3, con_x4): 83 | # x -> x_o_unet: h, w 84 | # con_x1: h/2, w/2 # [1, 32, 32, 32] 85 | # con_x2: h/4, w/4 # [1, 64, 16, 16] 86 | # con_x3: h/8, w/8 # [1, 128, 8, 8] 87 | # con_x4: h/16, w/16 # [1, 256, 4, 4] 88 | pool1 = self.en_block1(x) # h/2, w/2 89 | pool2 = self.en_block2(pool1) # h/4, w/4 90 | pool3 = self.en_block3(pool2) # h/8, w/8 91 | pool4 = self.en_block4(pool3) # h/16, w/16 92 | # print('11111111111', con_x2.shape, con_x3.shape, con_x4.shape) 93 | # print('11111111111', pool2.shape, pool3.shape, pool4.shape) 94 | upsample5 = self.en_block5(pool4) 95 | concat5 = paddle.concat((upsample5, pool4, con_x4), axis=1) 96 | upsample4 = self.de_block1(concat5) 97 | concat4 = paddle.concat((upsample4, 
pool3, con_x3), axis=1) 98 | upsample3 = self.de_block2(concat4) # h/8, w/8 99 | concat3 = paddle.concat((upsample3, pool2, con_x2), axis=1) 100 | upsample2 = self.de_block3(concat3) # h/4, w/4 101 | concat2 = paddle.concat((upsample2, pool1), axis=1) 102 | upsample1 = self.de_block4(concat2) # h/2, w/2 103 | concat1 = paddle.concat((upsample1, x), axis=1) 104 | out = self.de_block5(concat1) 105 | return out 106 | 107 | if __name__ == '__main__': 108 | # AIDR also needs the encoder skip features con_x2/3/4 (64/128/256 channels 109 | # at 1/4, 1/8 and 1/16 resolution), so build dummy ones for a smoke test 110 | x = paddle.rand([1, 3, 256, 256]) 111 | con_x2 = paddle.rand([1, 64, 64, 64]) 112 | con_x3 = paddle.rand([1, 128, 32, 32]) 113 | con_x4 = paddle.rand([1, 256, 16, 16]) 114 | model = AIDR(num_c=96) 115 | with paddle.no_grad(): 116 | out = model(x, con_x2, con_x3, con_x4) 117 | print(out.shape) 118 | 
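In the full model these skip tensors are not dummies: STRAIDR's encoder produces con_x2/con_x3/con_x4 and its coarse decoder produces the input that this refinement stage polishes. A short end-to-end usage sketch:

```python
# End-to-end smoke test of the two-stage generator (cf. models/sa_aidr.py):
# the encoder skips con_x2/3/4 and the coarse output feed AIDR internally.
import paddle
from models.sa_aidr import STRAIDR

net = STRAIDR(num_c=96)
img = paddle.rand([1, 3, 512, 512])  # H and W must be multiples of 32
x_o1, x_o2, x_coarse, refined, mask_pred = net(img)
print(refined.shape, mask_pred.shape)  # [1, 3, 512, 512] for both
```
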
-------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import numpy as np 3 | import paddle.nn.functional as F 4 | import paddle.nn as nn 5 | import math 6 | 7 | 8 | def get_pad(in_, ksize, stride, atrous=1): 9 | out_ = np.ceil(float(in_) / stride) 10 | return int(((out_ - 1) * stride + atrous * (ksize - 1) + 1 - in_) / 2) 11 | 12 | 13 | class ConvWithActivation(nn.Layer): 14 | ''' 15 | Convolution wrapped with spectral normalization, followed by an activation 16 | ''' 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, 19 | activation=nn.LeakyReLU(0.2)): 20 | super(ConvWithActivation, self).__init__() 21 | self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 22 | dilation=dilation, groups=groups, bias_attr=bias) 23 | self.conv2d = nn.utils.spectral_norm(self.conv2d) 24 | self.activation = activation 25 | # He-style random init of the conv weights 26 | for m in self.sublayers(): 27 | if isinstance(m, nn.Conv2D): 28 | n = m.weight.shape[0] * m.weight.shape[1] * m.weight.shape[2] 29 | v = np.random.normal(loc=0., scale=np.sqrt(2. / n), size=m.weight.shape).astype('float32') 30 | m.weight.set_value(v) 31 | 32 | def forward(self, input): 33 | x = self.conv2d(input) 34 | if self.activation is not None: 35 | return self.activation(x) 36 | else: 37 | return x 38 | 39 | 40 | class DeConvWithActivation(nn.Layer): 41 | ''' 42 | Transposed convolution wrapped with spectral normalization, followed by an activation 43 | ''' 44 | 45 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 46 | output_padding=1, bias=True, activation=nn.LeakyReLU(0.2)): 47 | super(DeConvWithActivation, self).__init__() 48 | self.conv2d = nn.Conv2DTranspose(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 49 | padding=padding, dilation=dilation, groups=groups, 50 | output_padding=output_padding, bias_attr=bias) 51 | self.conv2d = nn.utils.spectral_norm(self.conv2d) 52 | self.activation = activation 53 | 54 | def forward(self, input): 55 | x = self.conv2d(input) 56 | if self.activation is not None: 57 | return self.activation(x) 58 | else: 59 | return x
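`get_pad` returns the padding that keeps `out = ceil(in / stride)` for a given kernel, stride and dilation ("atrous") factor; two spot checks that match call sites elsewhere in this repo:

```python
# Two spot checks of get_pad, matching call sites in this repo.
from models.networks import get_pad

print(get_pad(64, 3, 1, 2))  # -> 2: the dilated 3x3, stride-1 convs keep 64x64
print(get_pad(256, 5, 2))    # -> 1: the padding used by the discriminator convs
```
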
-------------------------------------------------------------------------------- /models/non_local.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | from paddle import nn 3 | from paddle.nn import functional as F 4 | # import math 5 | # import numpy as np 6 | # from context_block import ContextBlock 7 | 8 | # def get_nonlocal_block(block_type): 9 | #     block_dict = {'nl': NonLocal, 'bat': BATBlock, 'gc': ContextBlock} 10 | #     if block_type in block_dict: 11 | #         return block_dict[block_type] 12 | #     else: 13 | #         raise ValueError("UNKNOWN NONLOCAL BLOCK TYPE:", block_type) 14 | 15 | class NonLocalBlock(nn.Layer): 16 | def __init__(self, channel): 17 | super(NonLocalBlock, self).__init__() 18 | self.inter_channel = channel // 2 19 | self.conv_phi = nn.Conv2D(channel, self.inter_channel, kernel_size=1, stride=1, bias_attr=False) 20 | self.conv_theta = nn.Conv2D(channel, self.inter_channel, kernel_size=1, stride=1, bias_attr=False) 21 | self.conv_g = nn.Conv2D(channel, self.inter_channel, kernel_size=1, stride=1, bias_attr=False) 22 | self.softmax = nn.Softmax(axis=1) 23 | self.conv_mask = nn.Conv2D(self.inter_channel, channel, kernel_size=1, stride=1, bias_attr=False) 24 | 25 | def forward(self, x): 26 | # x: [N, C, H, W] 27 | b, c, h, w = x.shape 28 | # phi features, [N, C/2, H*W]: keep batch/channel dims, flatten over HW 29 | x_phi = self.conv_phi(x) 30 | x_phi = paddle.reshape(x_phi, (b, self.inter_channel, -1)) 31 | # theta features, [N, H*W, C/2] 32 | x_theta = self.conv_theta(x) 33 | x_theta = paddle.transpose(paddle.reshape(x_theta, (b, self.inter_channel, -1)), (0, 2, 1)) 34 | # g features, [N, H*W, C/2] 35 | x_g = self.conv_g(x) 36 | x_g = paddle.transpose(paddle.reshape(x_g, (b, self.inter_channel, -1)), (0, 2, 1)) 37 | # theta x phi -> [N, H*W, H*W] affinity matrix 38 | mul_theta_phi = paddle.matmul(x_theta, x_phi) 39 | # softmax normalizes the affinities to (0, 1) 40 | mul_theta_phi = self.softmax(mul_theta_phi) 41 | # aggregate g by the affinities -> [N, H*W, C/2] 42 | mul_theta_phi_g = paddle.matmul(mul_theta_phi, x_g) 43 | # back to [N, C/2, H, W] 44 | mul_theta_phi_g = paddle.transpose(mul_theta_phi_g, (0, 2, 1)) 45 | mul_theta_phi_g = paddle.reshape(mul_theta_phi_g, (b, self.inter_channel, h, w)) 46 | # 1x1 conv restores the channel count 47 | mask = self.conv_mask(mul_theta_phi_g) 48 | out = mask + x  # residual connection 49 | return out 50 | 51 | class NonLocalModule(nn.Layer): 52 | 53 | def __init__(self, in_channels, **kwargs): 54 | super(NonLocalModule, self).__init__() 55 | 56 | def init_modules(self): 57 | for m in self.sublayers(): 58 | if len(m.sublayers()) > 0: 59 | continue 60 | if isinstance(m, nn.Conv2D): 61 | m.weight = m.create_parameter(m.weight.shape, default_initializer=nn.initializer.KaimingNormal()) 62 | if len(list(m.parameters())) > 1: 63 | m.bias.set_value(paddle.zeros(m.bias.shape)) 64 | elif isinstance(m, nn.BatchNorm2D): 65 | m.weight.set_value(paddle.zeros(m.weight.shape)) 66 | m.bias.set_value(paddle.zeros(m.bias.shape)) 67 | elif isinstance(m, nn.GroupNorm): 68 | m.weight.set_value(paddle.zeros(m.weight.shape)) 69 | m.bias.set_value(paddle.zeros(m.bias.shape)) 70 | elif len(list(m.parameters())) > 0: 71 | raise ValueError("UNKNOWN NONLOCAL LAYER TYPE:", m) 72 | 73 | 74 | class NonLocal(NonLocalModule): 75 | """Spatial NL block for image classification. 76 | [https://github.com/facebookresearch/video-nonlocal-net]. 77 | """ 78 | 79 | def __init__(self, inplanes, use_scale=False, **kwargs): 80 | planes = inplanes // 2 81 | self.use_scale = use_scale 82 | 83 | super(NonLocal, self).__init__(inplanes) 84 | self.t = nn.Conv2D(inplanes, planes, kernel_size=1, stride=1, bias_attr=True) 85 | self.p = nn.Conv2D(inplanes, planes, kernel_size=1, stride=1, bias_attr=True) 86 | self.g = nn.Conv2D(inplanes, planes, kernel_size=1, stride=1, bias_attr=True) 87 | self.softmax = nn.Softmax(axis=2) 88 | self.z = nn.Conv2D(planes, inplanes, kernel_size=1, stride=1, bias_attr=True) 89 | self.bn = nn.BatchNorm2D(inplanes) 90 | 91 | def forward(self, x): 92 | residual = x 93 | 94 | t = self.t(x) 95 | p = self.p(x) 96 | g = self.g(x) 97 | 98 | b, c, h, w = t.shape 99 | 100 | t = paddle.transpose(paddle.reshape(t, (b, c, -1)), (0, 2, 1)) 101 | p = paddle.reshape(p, (b, c, -1)) 102 | g = paddle.transpose(paddle.reshape(g, (b, c, -1)), (0, 2, 1)) 103 | att = paddle.bmm(t, p) 104 | if self.use_scale: 105 | att = paddle.divide(att, paddle.to_tensor(c ** 0.5)) 106 | att = self.softmax(att) 107 | x = paddle.bmm(att, g) 108 | 109 | x = paddle.transpose(x, (0, 2, 1)) 110 | x = paddle.reshape(x, (b, c, h, w)) 111 | 112 | x = self.z(x) 113 | x = self.bn(x) + residual 114 | 115 | return x 116 | 117 | 118 | class BATransform(nn.Layer): 119 | 120 | def __init__(self, in_channels, s, k): 121 | super(BATransform, self).__init__() 122 | 123 | self.conv1 = nn.Sequential(nn.Conv2D(in_channels, k, 1), 124 | nn.BatchNorm2D(k), 125 | nn.ReLU()) 126 | self.conv_p = nn.Conv2D(k, s * s * k, [s, 1]) 127 | self.conv_q = nn.Conv2D(k, s * s * k, [1, s]) 128 | self.conv2 = nn.Sequential(nn.Conv2D(in_channels, in_channels, 1), 129 | nn.BatchNorm2D(in_channels), 130 | nn.ReLU()) 131 | self.s = s 132 | self.k = k 133 | self.in_channels = in_channels 134 | 135 | def extra_repr(self): 136 | return 'BATransform({in_channels}, s={s}, k={k})'.format(**self.__dict__) 137 | 138 | def resize_mat(self, x, t): 139 | # expand an (s x s) attention matrix into block form of size (s*t) x (s*t) 140 | n, c, s, s1 = x.shape 141 | assert s == s1 142 | if t <= 1: 143 | return x 144 | x = paddle.reshape(x, (n * c, -1, 1, 1)) 145 | x = x * paddle.eye(t, t, dtype=x.dtype) 146 | x = paddle.reshape(x, (n * c, s, s, t, t)) 147 | x = paddle.concat(paddle.split(x, x.shape[1], axis=1), axis=3) 148 | x = paddle.concat(paddle.split(x, x.shape[2], axis=2), axis=4) 149 | x = paddle.reshape(x, (n, c, s * t, s * t)) 150 | return x 151 | 152 | def forward(self, x): 153 | out = self.conv1(x) 154 | rp = F.adaptive_max_pool2d(out, (self.s, 1)) 155 | cp = F.adaptive_max_pool2d(out, (1, self.s)) 156 | p = paddle.reshape(self.conv_p(rp), (x.shape[0], self.k, self.s, self.s)) 157 | q = paddle.reshape(self.conv_q(cp), (x.shape[0], self.k, self.s, self.s)) 158 | p = F.sigmoid(p) 159 | q = F.sigmoid(q) 160 | p = p / paddle.sum(p, axis=3, keepdim=True) 161 | q = q / paddle.sum(q, axis=2, keepdim=True) 162 | 163 | p = paddle.reshape(p, (x.shape[0], self.k, 1, self.s, self.s)) 164 | p = paddle.expand(p, (x.shape[0], self.k, x.shape[1] // self.k, self.s, self.s)) 165 | p = paddle.reshape(p, (x.shape[0], x.shape[1], self.s, self.s)) 166 | 167 | q = paddle.reshape(q, (x.shape[0], self.k, 1, self.s, self.s)) 168 | q = paddle.expand(q, (x.shape[0], self.k, x.shape[1] // self.k, self.s, self.s)) 169 | q = paddle.reshape(q, (x.shape[0], x.shape[1], self.s, self.s)) 170 | 171 | p = self.resize_mat(p, x.shape[2] // self.s) 172 | q = self.resize_mat(q, x.shape[2] // self.s) 173 | y = paddle.matmul(p, x) 174 | y = paddle.matmul(y, q) 175 | 176 | y = self.conv2(y) 177 | return y 178 | 179 | 180 | class BATBlock(NonLocalModule): 181 | 182 | def __init__(self, in_channels, r=2, s=4, k=4, dropout=0.2, **kwargs): 183 | super().__init__(in_channels) 184 | 185 | inter_channels = in_channels // r 186 | self.conv1 = nn.Sequential(nn.Conv2D(in_channels, inter_channels, 1), 187 | nn.BatchNorm2D(inter_channels), 188 | nn.ReLU()) 189 | self.batransform = BATransform(inter_channels, s, k) 190 | self.conv2 = nn.Sequential(nn.Conv2D(inter_channels, in_channels, 1), 191 | nn.BatchNorm2D(in_channels), 192 | nn.ReLU()) 193 | self.dropout = nn.Dropout2D(p=dropout) 194 | 195 | def forward(self, x): 196 | xl = self.conv1(x) 197 | y = self.batransform(xl) 198 | y = self.conv2(y) 199 | y = self.dropout(y) 200 | return y + x 201 | 202 | def init_modules(self): 203 | for m in self.sublayers(): 204 | if isinstance(m, nn.Conv2D): 205 | m.weight = m.create_parameter(m.weight.shape, default_initializer=nn.initializer.KaimingNormal()) 206 | elif isinstance(m, nn.BatchNorm2D): 207 | m.weight.set_value(paddle.ones(m.weight.shape)) 208 | m.bias.set_value(paddle.zeros(m.bias.shape)) 209 | 210 | if __name__ == '__main__': 211 | x = paddle.rand([1, 64, 128, 128]) 212 | net = NonLocal(inplanes=64) 213 | # net = BATBlock(in_channels=128) 214 | # net = NonLocalBlock(channel=64) 215 | out = net(x) 216 | print(out.shape) 217 | 
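The embedded-Gaussian block above preserves the input's shape, which is why AIDR can drop it into its blocks unchanged; since the affinity matrix is (H·W)×(H·W), it is only applied after several poolings. A quick shape check:

```python
# NonLocalBlock keeps [N, C, H, W] unchanged; the affinity matrix inside is
# (H*W) x (H*W), so keep H and W small.
import paddle
from models.non_local import NonLocalBlock

block = NonLocalBlock(channel=96)
feat = paddle.rand([1, 96, 32, 32])
print(block(feat).shape)  # [1, 96, 32, 32]
```
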
-------------------------------------------------------------------------------- /models/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/sa_aidr.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import paddle.nn as nn 3 | import paddle.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | from models.networks import get_pad, ConvWithActivation, DeConvWithActivation 7 | from models.idr import AIDR 8 | 9 | def img2photo(imgs): 10 | return paddle.transpose((imgs + 1) * 127.5, [0, 2, 3, 1]).detach().cpu().numpy()  # NCHW -> NHWC 11 | 12 | 13 | def visual(imgs): 14 | im = img2photo(imgs) 15 | Image.fromarray(im[0].astype(np.uint8)).show() 16 | 17 | 18 | class Residual(nn.Layer): 19 | def __init__(self, in_channels, out_channels, same_shape=True, **kwargs): 20 | super(Residual, self).__init__() 21 | self.same_shape = same_shape 22 | strides = 1 if same_shape else 
2 23 | self.conv1 = nn.Conv2D(in_channels, in_channels, kernel_size=3, padding=1, stride=strides) 24 | self.conv2 = nn.Conv2D(in_channels, out_channels, kernel_size=3, padding=1) 25 | # self.conv2 = torch.nn.utils.spectral_norm(self.conv2) 26 | if not same_shape: 27 | self.conv3 = nn.Conv2D(in_channels, out_channels, kernel_size=1, 28 | # self.conv3 = nn.Conv2D(channels, kernel_size=3, padding=1, 29 | stride=strides) 30 | # self.conv3 = torch.nn.utils.spectral_norm(self.conv3) 31 | self.batch_norm2d = nn.BatchNorm2D(out_channels) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.conv1(x)) 35 | out = self.conv2(out) 36 | if not self.same_shape: 37 | x = self.conv3(x) 38 | out = self.batch_norm2d(out + x) 39 | # out = out + x 40 | return F.relu(out) 41 | 42 | 43 | class ASPP(nn.Layer): 44 | def __init__(self, in_channel=512, depth=256): 45 | super(ASPP, self).__init__() 46 | self.mean = nn.AdaptiveAvgPool2D((1, 1)) 47 | self.conv = nn.Conv2D(in_channel, depth, 1, 1) 48 | # k=1 s=1 no pad 49 | self.atrous_block1 = nn.Conv2D(in_channel, depth, 1, 1) 50 | self.atrous_block6 = nn.Conv2D(in_channel, depth, 3, 1, padding=6, dilation=6) 51 | self.atrous_block12 = nn.Conv2D(in_channel, depth, 3, 1, padding=12, dilation=12) 52 | self.atrous_block18 = nn.Conv2D(in_channel, depth, 3, 1, padding=18, dilation=18) 53 | 54 | self.conv_1x1_output = nn.Conv2D(depth * 5, depth, 1, 1) 55 | 56 | def forward(self, x): 57 | size = x.shape[2:] 58 | 59 | image_features = self.mean(x) 60 | image_features = self.conv(image_features) 61 | image_features = F.upsample(image_features, size=size, mode='bilinear') 62 | 63 | atrous_block1 = self.atrous_block1(x) 64 | 65 | atrous_block6 = self.atrous_block6(x) 66 | 67 | atrous_block12 = self.atrous_block12(x) 68 | 69 | atrous_block18 = self.atrous_block18(x) 70 | 71 | net = self.conv_1x1_output(paddle.concat([image_features, atrous_block1, atrous_block6, 72 | atrous_block12, atrous_block18], axis=1)) 73 | return net 74 | 75 | 76 | class STRAIDR(nn.Layer): 77 | def __init__(self, n_in_channel=3, num_c=48): 78 | super(STRAIDR, self).__init__() 79 | #### U-Net #### 80 | # downsample 81 | self.conv1 = ConvWithActivation(3, 32, kernel_size=4, stride=2, padding=1) 82 | self.conva = ConvWithActivation(32, 32, kernel_size=3, stride=1, padding=1) 83 | self.convb = ConvWithActivation(32, 64, kernel_size=4, stride=2, padding=1) 84 | self.res1 = Residual(64, 64) 85 | self.res2 = Residual(64, 64) 86 | self.res3 = Residual(64, 128, same_shape=False) 87 | self.res4 = Residual(128, 128) 88 | self.res5 = Residual(128, 256, same_shape=False) 89 | # self.nn = ConvWithActivation(256, 512, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)) 90 | self.res6 = Residual(256, 256) 91 | self.res7 = Residual(256, 512, same_shape=False) 92 | self.res8 = Residual(512, 512) 93 | self.conv2 = ConvWithActivation(512, 512, kernel_size=1) 94 | 95 | # upsample 96 | self.deconv1 = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2) 97 | self.deconv2 = DeConvWithActivation(256 * 2, 128, kernel_size=3, padding=1, stride=2) 98 | self.deconv3 = DeConvWithActivation(128 * 2, 64, kernel_size=3, padding=1, stride=2) 99 | self.deconv4 = DeConvWithActivation(64 * 2, 32, kernel_size=3, padding=1, stride=2) 100 | self.deconv5 = DeConvWithActivation(64, 3, kernel_size=3, padding=1, stride=2) 101 | 102 | # lateral connection 103 | self.lateral_connection1 = nn.Sequential( 104 | nn.Conv2D(256, 256, kernel_size=1, padding=0, stride=1), 105 | nn.Conv2D(256, 512, kernel_size=3, padding=1, stride=1), 106 | 
nn.Conv2D(512, 512, kernel_size=3, padding=1, stride=1), 107 | nn.Conv2D(512, 256, kernel_size=1, padding=0, stride=1), ) 108 | self.lateral_connection2 = nn.Sequential( 109 | nn.Conv2D(128, 128, kernel_size=1, padding=0, stride=1), 110 | nn.Conv2D(128, 256, kernel_size=3, padding=1, stride=1), 111 | nn.Conv2D(256, 256, kernel_size=3, padding=1, stride=1), 112 | nn.Conv2D(256, 128, kernel_size=1, padding=0, stride=1), ) 113 | self.lateral_connection3 = nn.Sequential( 114 | nn.Conv2D(64, 64, kernel_size=1, padding=0, stride=1), 115 | nn.Conv2D(64, 128, kernel_size=3, padding=1, stride=1), 116 | nn.Conv2D(128, 128, kernel_size=3, padding=1, stride=1), 117 | nn.Conv2D(128, 64, kernel_size=1, padding=0, stride=1), ) 118 | self.lateral_connection4 = nn.Sequential( 119 | nn.Conv2D(32, 32, kernel_size=1, padding=0, stride=1), 120 | nn.Conv2D(32, 64, kernel_size=3, padding=1, stride=1), 121 | nn.Conv2D(64, 64, kernel_size=3, padding=1, stride=1), 122 | nn.Conv2D(64, 32, kernel_size=1, padding=0, stride=1), ) 123 | 124 | # self.relu = nn.elu(alpha=1.0) 125 | self.conv_o1 = nn.Conv2D(64, 3, kernel_size=1) 126 | self.conv_o2 = nn.Conv2D(32, 3, kernel_size=1) 127 | ##### U-Net ##### 128 | 129 | ### ASPP ### 130 | # self.aspp = ASPP(512, 256) 131 | ### ASPP ### 132 | 133 | ### mask branch decoder ### 134 | self.mask_deconv_a = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2) 135 | self.mask_conv_a = ConvWithActivation(256, 128, kernel_size=3, padding=1, stride=1) 136 | self.mask_deconv_b = DeConvWithActivation(256, 128, kernel_size=3, padding=1, stride=2) 137 | self.mask_conv_b = ConvWithActivation(128, 64, kernel_size=3, padding=1, stride=1) 138 | self.mask_deconv_c = DeConvWithActivation(128, 64, kernel_size=3, padding=1, stride=2) 139 | self.mask_conv_c = ConvWithActivation(64, 32, kernel_size=3, padding=1, stride=1) 140 | self.mask_deconv_d = DeConvWithActivation(64, 32, kernel_size=3, padding=1, stride=2) 141 | self.mask_conv_d = nn.Conv2D(32, 3, kernel_size=1) 142 | ### mask branch ### 143 | 144 | ##### Refine sub-network ###### 145 | self.refine = AIDR(num_c=num_c) 146 | self.c1 = nn.Conv2D(32, 64, kernel_size=1) 147 | self.c2 = nn.Conv2D(64, 128, kernel_size=1) 148 | self.sig = nn.Sigmoid() 149 | 150 | def forward(self, x): 151 | # x: 3, h, w 152 | # downsample 153 | x = self.conv1(x) # 32, h/2,w/2 154 | x = self.conva(x) # 32, h/2,w/2 155 | con_x1 = x 156 | # print('con_x1: ',con_x1.shape) 157 | # import pdb;pdb.set_trace() 158 | x = self.convb(x) # 64, h/4,w/4 159 | x = self.res1(x) # 64, h/4,w/4 160 | con_x2 = x 161 | # print('con_x2: ', con_x2.shape) 162 | x = self.res2(x) # 64, h/4,w/4 163 | x = self.res3(x) # 128, h/8,w/8 164 | con_x3 = x 165 | # print('con_x3: ', con_x3.shape) 166 | x = self.res4(x) # 128, h/8,w/8 167 | x = self.res5(x) # 256, h/16,w/16 168 | con_x4 = x 169 | # print('con_x4: ', con_x4.shape) 170 | x = self.res6(x) # 256, h/16,w/16 171 | # x_mask = self.nn(con_x4) ### for mask branch aspp 172 | # x_mask = self.aspp(x_mask) ### for mask branch aspp 173 | x_mask = x ### no aspp 174 | # print('x_mask: ', x_mask.shape) 175 | # import pdb;pdb.set_trace() 176 | x = self.res7(x) # 512, h/32,w/32 177 | x = self.res8(x) # 512, h/32,w/32 178 | x = self.conv2(x) # 512, h/32,w/32 179 | # upsample 180 | x = self.deconv1(x) # 256, h/16,w/16 181 | # print(x.shape,con_x4.shape, self.lateral_connection1(con_x4).shape) 182 | x = paddle.concat([self.lateral_connection1(con_x4), x], axis=1) # 256 + 256 183 | x = self.deconv2(x) # 512->128, h/8,w/8 184 | x = 
paddle.concat([self.lateral_connection2(con_x3), x], axis=1) # 128 + 128 185 | x = self.deconv3(x) # 256->64, h/4,w/4 186 | xo1 = x 187 | x = paddle.concat([self.lateral_connection3(con_x2), x], axis=1) # 64 + 64 188 | x = self.deconv4(x) # 128->32, h/2,w/2 189 | xo2 = x 190 | x = paddle.concat([self.lateral_connection4(con_x1), x], axis=1) # 32 + 32 191 | # import pdb;pdb.set_trace() 192 | x = self.deconv5(x) # 64->3, h, w 193 | x_o1 = self.conv_o1(xo1) # 64->3, h/4,w/4 194 | x_o2 = self.conv_o2(xo2) # 32->3, h/2,w/2 195 | x_o_unet = x 196 | 197 | ### mask branch ### 198 | mm = self.mask_deconv_a(paddle.concat([x_mask, con_x4], axis=1)) # 256 + 256 -> 256 , h/8,w/8 199 | mm = self.mask_conv_a(mm) # 256 -> 128, h/8,w/8 200 | mm = self.mask_deconv_b(paddle.concat([mm, con_x3], axis=1)) # 128 + 128 -> 128, h/4,w/4 201 | mm = self.mask_conv_b(mm) # 128 -> 64, h/4,w/4 202 | mm = self.mask_deconv_c(paddle.concat([mm, con_x2], axis=1)) # 64 + 64 -> 64, h/2, w/2 203 | mm = self.mask_conv_c(mm) # 64 -> 32, h/2, w/2 204 | mm = self.mask_deconv_d(paddle.concat([mm, con_x1], axis=1)) # 32 +32 -> 32, h, w 205 | mm = self.mask_conv_d(mm) # 32 -> 3, h, w 206 | mm = self.sig(mm) 207 | ### mask branch end ### 208 | 209 | ###refine sub-network 210 | x = self.refine(x_o_unet, con_x2, con_x3, con_x4) 211 | return x_o1, x_o2, x_o_unet, x, mm 212 | 213 | 214 | if __name__ == '__main__': 215 | net = STRAIDR() 216 | x = paddle.rand([1, 3, 64, 64]) 217 | x_o1, x_o2, x_o_unet, x, mm = net(x) 218 | print(x.shape, mm.shape) 219 | -------------------------------------------------------------------------------- /models/sa_gan.py: -------------------------------------------------------------------------------- 1 | # from x2paddle import torch2paddle 2 | import paddle 3 | import paddle.nn as nn 4 | import paddle.nn.functional as F 5 | import numpy as np 6 | from PIL import Image 7 | from paddle import to_tensor 8 | from models.networks import get_pad 9 | from models.networks import ConvWithActivation 10 | from models.networks import DeConvWithActivation 11 | 12 | 13 | class Residual(nn.Layer): 14 | 15 | def __init__(self, in_channels, out_channels, same_shape=True, **kwargs): 16 | super(Residual, self).__init__() 17 | self.same_shape = same_shape 18 | strides = 1 if same_shape else 2 19 | self.conv1 = nn.Conv2D(in_channels, in_channels, kernel_size=3, 20 | padding=1, stride=strides) 21 | self.conv2 = nn.Conv2D(in_channels, out_channels, kernel_size=3, 22 | padding=1) 23 | if not same_shape: 24 | self.conv3 = nn.Conv2D(in_channels, out_channels, kernel_size=1, 25 | stride=strides) 26 | self.batch_norm2d = nn.BatchNorm2D(out_channels) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.conv1(x)) 30 | out = self.conv2(out) 31 | if not self.same_shape: 32 | x = self.conv3(x) 33 | out = self.batch_norm2d(out + x) 34 | return F.relu(out) 35 | 36 | 37 | class STRnet2(nn.Layer): 38 | 39 | def __init__(self, n_in_channel=3): 40 | super(STRnet2, self).__init__() 41 | self.conv1 = ConvWithActivation(3, 32, kernel_size=4, stride=2, 42 | padding=1) 43 | self.conva = ConvWithActivation(32, 32, kernel_size=3, stride=1, 44 | padding=1) 45 | self.convb = ConvWithActivation(32, 64, kernel_size=4, stride=2, 46 | padding=1) 47 | self.res1 = Residual(64, 64) 48 | self.res2 = Residual(64, 64) 49 | self.res3 = Residual(64, 128, same_shape=False) 50 | self.res4 = Residual(128, 128) 51 | self.res5 = Residual(128, 256, same_shape=False) 52 | self.res6 = Residual(256, 256) 53 | self.res7 = Residual(256, 512, same_shape=False) 54 | self.res8 = 
Residual(512, 512) 55 | self.conv2 = ConvWithActivation(512, 512, kernel_size=1) 56 | self.deconv1 = DeConvWithActivation(512, 256, kernel_size=3, 57 | padding=1, stride=2) 58 | self.deconv2 = DeConvWithActivation(256 * 2, 128, kernel_size=3, 59 | padding=1, stride=2) 60 | self.deconv3 = DeConvWithActivation(128 * 2, 64, kernel_size=3, 61 | padding=1, stride=2) 62 | self.deconv4 = DeConvWithActivation(64 * 2, 32, kernel_size=3, 63 | padding=1, stride=2) 64 | self.deconv5 = DeConvWithActivation(64, 3, kernel_size=3, padding=1, 65 | stride=2) 66 | self.lateral_connection1 = nn.Sequential(nn.Conv2D(256, 256, 67 | kernel_size=1, padding=0, stride=1), nn.Conv2D(256, 512, 68 | kernel_size=3, padding=1, stride=1), nn.Conv2D(512, 512, 69 | kernel_size=3, padding=1, stride=1), nn.Conv2D(512, 256, 70 | kernel_size=1, padding=0, stride=1)) 71 | self.lateral_connection2 = nn.Sequential(nn.Conv2D(128, 128, 72 | kernel_size=1, padding=0, stride=1), nn.Conv2D(128, 256, 73 | kernel_size=3, padding=1, stride=1), nn.Conv2D(256, 256, 74 | kernel_size=3, padding=1, stride=1), nn.Conv2D(256, 128, 75 | kernel_size=1, padding=0, stride=1)) 76 | self.lateral_connection3 = nn.Sequential(nn.Conv2D(64, 64, 77 | kernel_size=1, padding=0, stride=1), nn.Conv2D(64, 128, 78 | kernel_size=3, padding=1, stride=1), nn.Conv2D(128, 128, 79 | kernel_size=3, padding=1, stride=1), nn.Conv2D(128, 64, 80 | kernel_size=1, padding=0, stride=1)) 81 | self.lateral_connection4 = nn.Sequential(nn.Conv2D(32, 32, 82 | kernel_size=1, padding=0, stride=1), nn.Conv2D(32, 64, 83 | kernel_size=3, padding=1, stride=1), nn.Conv2D(64, 64, 84 | kernel_size=3, padding=1, stride=1), nn.Conv2D(64, 32, 85 | kernel_size=1, padding=0, stride=1)) 86 | self.conv_o1 = nn.Conv2D(64, 3, kernel_size=1) 87 | self.conv_o2 = nn.Conv2D(32, 3, kernel_size=1) 88 | self.mask_deconv_a = DeConvWithActivation(512, 256, kernel_size=3, 89 | padding=1, stride=2) 90 | self.mask_conv_a = ConvWithActivation(256, 128, kernel_size=3, 91 | padding=1, stride=1) 92 | self.mask_deconv_b = DeConvWithActivation(256, 128, kernel_size=3, 93 | padding=1, stride=2) 94 | self.mask_conv_b = ConvWithActivation(128, 64, kernel_size=3, 95 | padding=1, stride=1) 96 | self.mask_deconv_c = DeConvWithActivation(128, 64, kernel_size=3, 97 | padding=1, stride=2) 98 | self.mask_conv_c = ConvWithActivation(64, 32, kernel_size=3, 99 | padding=1, stride=1) 100 | self.mask_deconv_d = DeConvWithActivation(64, 32, kernel_size=3, 101 | padding=1, stride=2) 102 | self.mask_conv_d = nn.Conv2D(32, 3, kernel_size=1) 103 | n_in_channel = 3 104 | cnum = 32 105 | self.coarse_conva = ConvWithActivation(n_in_channel, cnum, 106 | kernel_size=5, stride=1, padding=2) 107 | self.coarse_convb = ConvWithActivation(cnum, 2 * cnum, kernel_size=\ 108 | 4, stride=2, padding=1) 109 | self.coarse_convc = ConvWithActivation(2 * cnum, 2 * cnum, 110 | kernel_size=3, stride=1, padding=1) 111 | self.coarse_convd = ConvWithActivation(2 * cnum, 4 * cnum, 112 | kernel_size=4, stride=2, padding=1) 113 | self.coarse_conve = ConvWithActivation(4 * cnum, 4 * cnum, 114 | kernel_size=3, stride=1, padding=1) 115 | self.coarse_convf = ConvWithActivation(4 * cnum, 4 * cnum, 116 | kernel_size=3, stride=1, padding=1) 117 | self.astrous_net = nn.Sequential(ConvWithActivation(4 * cnum, 4 * 118 | cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)), 119 | ConvWithActivation(4 * cnum, 4 * cnum, 3, 1, dilation=4, 120 | padding=get_pad(64, 3, 1, 4)), ConvWithActivation(4 * cnum, 4 * 121 | cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)), 122 | 
ConvWithActivation(4 * cnum, 4 * cnum, 3, 1, dilation=16, 123 | padding=get_pad(64, 3, 1, 16))) 124 | self.coarse_convk = ConvWithActivation(4 * cnum, 4 * cnum, 125 | kernel_size=3, stride=1, padding=1) 126 | self.coarse_convl = ConvWithActivation(4 * cnum, 4 * cnum, 127 | kernel_size=3, stride=1, padding=1) 128 | self.coarse_deconva = DeConvWithActivation(4 * cnum * 3, 2 * cnum, 129 | kernel_size=3, padding=1, stride=2) 130 | self.coarse_convm = ConvWithActivation(2 * cnum, 2 * cnum, 131 | kernel_size=3, stride=1, padding=1) 132 | self.coarse_deconvb = DeConvWithActivation(2 * cnum * 3, cnum, 133 | kernel_size=3, padding=1, stride=2) 134 | self.coarse_convn = nn.Sequential(ConvWithActivation(cnum, cnum // 135 | 2, kernel_size=3, stride=1, padding=1), ConvWithActivation(cnum // 136 | 2, 3, kernel_size=3, stride=1, padding=1, activation=None)) 137 | self.c1 = nn.Conv2D(32, 64, kernel_size=1) 138 | self.c2 = nn.Conv2D(64, 128, kernel_size=1) 139 | self.sig = nn.Sigmoid() 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.conva(x) 144 | con_x1 = x 145 | x = self.convb(x) 146 | x = self.res1(x) 147 | con_x2 = x 148 | x = self.res2(x) 149 | x = self.res3(x) 150 | con_x3 = x 151 | x = self.res4(x) 152 | x = self.res5(x) 153 | con_x4 = x 154 | x = self.res6(x) 155 | x_mask = x 156 | x = self.res7(x) 157 | x = self.res8(x) 158 | x = self.conv2(x) 159 | x = self.deconv1(x) 160 | x = paddle.concat([self.lateral_connection1(con_x4), x], axis=1) 161 | x = self.deconv2(x) 162 | x = paddle.concat([self.lateral_connection2(con_x3), x], axis=1) 163 | x = self.deconv3(x) 164 | xo1 = x 165 | x = paddle.concat([self.lateral_connection3(con_x2), x], axis=1) 166 | x = self.deconv4(x) 167 | xo2 = x 168 | x = paddle.concat([self.lateral_connection4(con_x1), x], axis=1) 169 | x = self.deconv5(x) 170 | x_o1 = self.conv_o1(xo1) 171 | x_o2 = self.conv_o2(xo2) 172 | x_o_unet = x 173 | mm = self.mask_deconv_a(paddle.concat([x_mask, con_x4], axis=1)) 174 | mm = self.mask_conv_a(mm) 175 | mm = self.mask_deconv_b(paddle.concat([mm, con_x3], axis=1)) 176 | mm = self.mask_conv_b(mm) 177 | mm = self.mask_deconv_c(paddle.concat([mm, con_x2], axis=1)) 178 | mm = self.mask_conv_c(mm) 179 | mm = self.mask_deconv_d(paddle.concat([mm, con_x1], axis=1)) 180 | mm = self.mask_conv_d(mm) 181 | mm = self.sig(mm) 182 | x = self.coarse_conva(x_o_unet) 183 | x = self.coarse_convb(x) 184 | x = self.coarse_convc(x) 185 | x_c1 = x 186 | x = self.coarse_convd(x) 187 | x = self.coarse_conve(x) 188 | x = self.coarse_convf(x) 189 | x_c2 = x 190 | x = self.astrous_net(x) 191 | x = self.coarse_convk(x) 192 | x = self.coarse_convl(x) 193 | x = self.coarse_deconva(paddle.concat([x, x_c2, self.c2(con_x2)], axis=1)) 194 | x = self.coarse_convm(x) 195 | x = self.coarse_deconvb(paddle.concat([x, x_c1, self.c1(con_x1)], axis=1)) 196 | x = self.coarse_convn(x) 197 | return x_o1, x_o2, x_o_unet, x, mm 198 | 199 | if __name__ == '__main__': 200 | net = STRnet2() 201 | x = paddle.rand([1, 3, 64, 64]) 202 | x_o1, x_o2, x_o_unet, x, mm = net(x) 203 | print(x.shape, mm.shape) 204 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # 代码示例 2 | # python predict.py /dataset/baidu/watermark_test_datasets/images results 3 | 4 | import os 5 | import sys 6 | 7 | import cv2 8 | import paddle 9 | import paddle.nn as nn 10 | import paddle.nn.functional as F 11 | from paddle.io import DataLoader 12 | import numpy as np 13 | 
from test_dataloader import devdata 14 | from models.sa_aidr import STRAIDR 15 | import time 16 | TIME = [] 17 | 18 | def pd_tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 19 | img = tensor.squeeze().cpu().numpy() 20 | img = img.clip(min_max[0], min_max[1]) 21 | img = (img - min_max[0]) / (min_max[1] - min_max[0]) 22 | if out_type == np.uint8: 23 | # scaling 24 | img = img * 255.0 25 | img = np.transpose(img, (1, 2, 0)) 26 | img = img.round() 27 | img = img[:, :, ::-1] 28 | return img.astype(out_type) 29 | 30 | def process(src_image_dir, save_dir): 31 | start = time.time() 32 | netG = STRAIDR(num_c=96) 33 | weights = paddle.load("model_best.pdparams") 34 | netG.load_dict(weights) 35 | 36 | netG.eval() 37 | 38 | Erase_data = devdata(dataRoot=src_image_dir, gtRoot=src_image_dir) 39 | Erase_data = DataLoader(Erase_data, batch_size=1, shuffle=False, num_workers=0, drop_last=False) 40 | 41 | print('OK!') 42 | 43 | for index, (imgs, path) in enumerate(Erase_data): 44 | _, _, h, w = imgs.shape 45 | if h < 1600 and w < 1600: # 1000, 0.2399 46 | pad_size = 128 47 | h_padded = False 48 | w_padded = False 49 | if h % pad_size != 0: 50 | pad_h = pad_size - (h % pad_size) 51 | imgs = F.pad(imgs, (0, 0, 0, pad_h), mode='reflect') 52 | h_padded = True 53 | 54 | if w % pad_size != 0: 55 | pad_w = pad_size - (w % pad_size) 56 | imgs = F.pad(imgs, (0, pad_w, 0, 0), mode='reflect') 57 | w_padded = True 58 | print(index, imgs.shape, path) 59 | with paddle.no_grad(): 60 | res = netG(imgs)[3] 61 | res += paddle.flip(netG(paddle.flip(imgs, axis=[2]))[3], axis=[2]) 62 | res += paddle.flip(netG(paddle.flip(imgs, axis=[3]))[3], axis=[3]) 63 | res += paddle.flip(netG(paddle.flip(imgs, axis=[2, 3]))[3], axis=[2, 3]) 64 | res = res / 4 # 16 + 480 65 | if h_padded: 66 | res = res[:, :, 0:h, :] 67 | if w_padded: 68 | res = res[:, :, :, 0:w] 69 | res = pd_tensor2img(res) 70 | cv2.imwrite(os.path.join(save_dir, path[0]), res, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 71 | else: 72 | pad = 112 73 | m = nn.Pad2D(pad, mode='reflect') 74 | imgs = m(imgs) 75 | print(index, imgs.shape, path) 76 | _, _, h, w = imgs.shape 77 | step = 800 78 | res = paddle.zeros_like(imgs) 79 | for i in range(0, h, step): 80 | for j in range(0, w, step): 81 | if h - i < step + 2 * pad: 82 | i = h - (step + 2 * pad) 83 | if w - j < step + 2 * pad: 84 | j = w - (step + 2 * pad) 85 | clip = imgs[:, :, i:i + step + 2 * pad, j:j + step + 2 * pad] 86 | clip = clip.cuda() 87 | with paddle.no_grad(): 88 | g_images_clip = netG(clip)[3] 89 | g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[2]))[3], axis=[2]) 90 | g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[3]))[3], axis=[3]) 91 | g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[2, 3]))[3], axis=[2, 3]) 92 | g_images_clip = g_images_clip / 4 # 16 + 480 93 | res[:, :, i + pad:i + step + pad, j + pad:j + step + pad] = g_images_clip[:, :, pad:-pad, pad:-pad] 94 | res = res[:, :, pad:-pad, pad:-pad] 95 | res = pd_tensor2img(res) 96 | cv2.imwrite(os.path.join(save_dir, path[0]), res, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 97 | print('Total time: ', (time.time() - start) / len(Erase_data)) 98 | 99 | if __name__ == "__main__": 100 | assert len(sys.argv) == 3 101 | 102 | src_image_dir = sys.argv[1] 103 | save_dir = sys.argv[2] 104 | 105 | if not os.path.exists(save_dir): 106 | os.makedirs(save_dir) 107 | 108 | dataRoot = src_image_dir 109 | savePath = save_dir 110 | 111 | # set gpu 112 | if paddle.is_compiled_with_cuda(): 113 | paddle.set_device('gpu:0') 114 | else: 115 | 
paddle.set_device('cpu') 116 | process(src_image_dir, save_dir) 117 | -------------------------------------------------------------------------------- /submit_dehw.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdyshine/Baidu-netdisk-AI-Image-processing-Challenge-handwriting/ebc524e7646c37ef502455b89c9a9571f324b6e7/submit_dehw.zip -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import paddle 5 | import paddle.nn as nn 6 | from paddle.io import DataLoader 7 | import numpy as np 8 | from data.dataloader import ErasingData, devdata 9 | from models.sa_gan import STRnet2 10 | from models.sa_aidr import STRAIDR 11 | 12 | # paddle.enable_static() 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--numOfWorkers', type=int, default=0, 15 | help='workers for dataloader') 16 | parser.add_argument('--modelsSavePath', type=str, default='', 17 | help='path for saving models') 18 | parser.add_argument('--logPath', type=str, 19 | default='') 20 | parser.add_argument('--batchSize', type=int, default=16) 21 | parser.add_argument('--loadSize', type=int, default=512, 22 | help='image loading size') 23 | parser.add_argument('--dataRoot', type=str, 24 | default='') 25 | parser.add_argument('--pretrained', type=str, default='', help='pretrained models for finetuning') 26 | parser.add_argument('--savePath', type=str, default='./results/sn_tv/') 27 | parser.add_argument('--net', type=str, default='str') 28 | args = parser.parse_args() 29 | 30 | 31 | def pd_tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 32 | img = tensor.squeeze().cpu().numpy() 33 | img = img.clip(min_max[0], min_max[1]) 34 | img = (img - min_max[0]) / (min_max[1] - min_max[0]) 35 | if out_type == np.uint8: 36 | # scaling 37 | img = img * 255.0 38 | img = np.transpose(img, (1, 2, 0)) 39 | img = img.round() 40 | img = img[:, :, ::-1] 41 | return img.astype(out_type) 42 | 43 | 44 | # set gpu 45 | if paddle.is_compiled_with_cuda(): 46 | paddle.set_device('gpu:0') 47 | else: 48 | paddle.set_device('cpu') 49 | 50 | batchSize = args.batchSize 51 | loadSize = (args.loadSize, args.loadSize) 52 | dataRoot = args.dataRoot 53 | savePath = args.savePath 54 | result_with_mask = savePath + 'WithMaskOutput/' 55 | result_straight = savePath + 'StrOuput/' 56 | # import pdb;pdb.set_trace() 57 | 58 | if not os.path.exists(savePath): 59 | os.makedirs(savePath) 60 | os.makedirs(result_with_mask) 61 | os.makedirs(result_straight) 62 | 63 | Erase_data = devdata(dataRoot=dataRoot, gtRoot=dataRoot) 64 | Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=False, num_workers=args.numOfWorkers, drop_last=False) 65 | 66 | # netG = STRAIDR(num_c=96) 67 | if args.net == 'str': 68 | netG = STRnet2(3) 69 | weights = paddle.load('STE_best_38.6789.pdparams') 70 | netG.load_dict(weights) 71 | print('load:', 'STE_best_38.6789.pdparams') 72 | netG.eval() 73 | for param in netG.parameters(): 74 | param.requires_grad = False 75 | elif args.net == 'idr': 76 | netG = STRAIDR(num_c=96) 77 | weights = paddle.load('STE_idr_38.0642.pdparams') 78 | netG.load_dict(weights) 79 | print('load:', 'STE_idr_38.0642.pdparams') 80 | netG.eval() 81 | for param in netG.parameters(): 82 | param.requires_grad = False 83 | elif args.net == 'mix': 84 | netG1 = STRAIDR(num_c=96) 85 | netG2 = STRnet2(3) 86 | # weights1 = 
paddle.load('STE_best_37.99.pdparams') # 668 87 | # weights1 = paddle.load('STE_idr_38.0642.pdparams') # 668 88 | weights1 = paddle.load('STE_idr_best.pdparams') # 668 89 | 90 | # weights2 = paddle.load('STE_best_38.6789.pdparams') 91 | # weights2 = paddle.load('STE_best_38.6016_new.pdparams') 92 | weights2 = paddle.load('STE_str_best.pdparams') # 668 93 | 94 | netG1.load_dict(weights1) 95 | netG2.load_dict(weights2) 96 | print('load:', 'STE_idr_best.pdparams', 'STE_str_best.pdparams') 97 | netG1.eval() 98 | netG2.eval() 99 | for param in netG1.parameters(): 100 | param.requires_grad = False 101 | for param in netG2.parameters(): 102 | param.requires_grad = False 103 | 104 | print('OK!') 105 | 106 | import time 107 | TIME = [] 108 | 109 | for index, (imgs, gt, path) in enumerate(Erase_data): 110 | pad = 106  # overlapped tiling: predict with extra context, keep only each tile's centre 111 | m = nn.Pad2D(pad, mode='reflect') 112 | imgs = m(imgs) 113 | print(index, imgs.shape, gt.shape, path) 114 | _, _, h, w = imgs.shape 115 | rh, rw = h, w 116 | step = 300 117 | res = paddle.zeros_like(imgs) 118 | for i in range(0, h, step): 119 | for j in range(0, w, step): 120 | if h - i < step + 2 * pad: 121 | i = h - (step + 2 * pad) 122 | if w - j < step + 2 * pad: 123 | j = w - (step + 2 * pad) 124 | clip = imgs[:, :, i:i + step + 2 * pad, j:j + step + 2 * pad] 125 | clip = clip.cuda() 126 | start = time.time() 127 | with paddle.no_grad(): 128 | if args.net == 'mix': 129 | g_images_clip1 = netG1(clip)[3] 130 | g_images_clip1 += paddle.flip(netG1(paddle.flip(clip, axis=[3]))[3], axis=[3]) 131 | g_images_clip1 = g_images_clip1 / 2 132 | g_images_clip2 = netG2(clip)[3] 133 | g_images_clip2 += paddle.flip(netG2(paddle.flip(clip, axis=[3]))[3], axis=[3]) 134 | g_images_clip2 = g_images_clip2 / 2 135 | g_images_clip = (g_images_clip1 + g_images_clip2) / 2 136 | else: 137 | g_images_clip = netG(clip)[3] 138 | # g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[2]))[3], axis=[2]) 139 | g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[3]))[3], axis=[3]) 140 | # g_images_clip += paddle.flip(netG(paddle.flip(clip, axis=[2, 3]))[3], axis=[2, 3]) 141 | g_images_clip = g_images_clip / 2 142 | res[:, :, i + pad:i + step + pad, j + pad:j + step + pad] = g_images_clip[:, :, pad:-pad, pad:-pad] 143 | res = res[:, :, pad:-pad, pad:-pad] 144 | TIME.append(time.time() - start) 145 | # res = res.clamp_(0, 1) 146 | res = pd_tensor2img(res) 147 | cv2.imwrite(result_with_mask + path[0].replace('.jpg', '.png'), res) 148 | print('total time: {}, avg_time: {}.'.format(np.sum(TIME), np.mean(TIME))) -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test.py --dataRoot '../data/dehw_testB_dataset' \ 2 | --batchSize 1 \ 3 | --pretrained 'STE_best.pdparams' \ 4 | --savePath 'res/' \ 5 | --net 'mix' 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import paddle 4 | import paddle.nn as nn 5 | import paddle.nn.functional as F 6 | from paddle.io import DataLoader 7 | from data.dataloader import ErasingData, devdata 8 | from loss.Loss import LossWithGAN_STE 9 | from models.sa_gan import STRnet2 10 | from models.sa_aidr import STRAIDR 11 | import utils 12 | import random 13 | import numpy as np 14 | 15 | parser = argparse.ArgumentParser() 16 | 
parser.add_argument('--numOfWorkers', type=int, default=16, help='workers for dataloader') 17 | parser.add_argument('--modelsSavePath', type=str, default='', help='path for saving models') 18 | parser.add_argument('--logPath', type=str, default='') 19 | parser.add_argument('--batchSize', type=int, default=16) 20 | parser.add_argument('--loadSize', type=int, default=512, help='image loading size') 21 | parser.add_argument('--dataRoot', type=str, default='') 22 | parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning') 23 | parser.add_argument('--num_epochs', type=int, default=5000, help='epochs') 24 | parser.add_argument('--net', type=str, default='str') 25 | parser.add_argument('--lr', type=float, default=1e-4) 26 | parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay') 27 | parser.add_argument('--lr_decay_iters', type=int, default=400000, help='learning rate decay per N iters') 28 | parser.add_argument('--mask_dir', type=str, default='mask') 29 | parser.add_argument('--seed', type=int, default=2022) 30 | args = parser.parse_args() 31 | 32 | log_file = os.path.join('./log', args.net + '_log.txt') 33 | logging = utils.setup_logger(output=log_file, name=args.net) 34 | logging.info(args) 35 | 36 | # set gpu 37 | if paddle.is_compiled_with_cuda(): 38 | paddle.set_device('gpu:0') 39 | else: 40 | paddle.set_device('cpu') 41 | 42 | # set random seed 43 | logging.info('========> Random Seed: {}'.format(args.seed)) 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | paddle.seed(args.seed) 47 | paddle.framework.random._manual_program_seed(args.seed) 48 | 49 | 50 | batchSize = args.batchSize 51 | loadSize = (args.loadSize, args.loadSize) 52 | 53 | if not os.path.exists(args.modelsSavePath): 54 | os.makedirs(args.modelsSavePath) 55 | 56 | dataRoot = args.dataRoot 57 | 58 | Erase_data = ErasingData(dataRoot, loadSize, training=True, mask_dir=args.mask_dir) 59 | Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=True, num_workers=args.numOfWorkers, drop_last=False) 60 | val_dataRoot='./dataset/task2/dehw_val_dataset/images' 61 | Erase_val_data = devdata(dataRoot=val_dataRoot, gtRoot=val_dataRoot.replace('images','gts')) 62 | Erase_val_data = DataLoader(Erase_val_data, batch_size=1, shuffle=False, num_workers=0, drop_last=False) 63 | print('==============', len(Erase_val_data)) 64 | print('==============>net use: ', args.net) 65 | if args.net == 'str': 66 | netG = STRnet2(3) 67 | elif args.net == 'idr': 68 | netG = STRAIDR(num_c=96) 69 | 70 | if args.pretrained != '': 71 | print('loaded ') 72 | weights = paddle.load(args.pretrained) 73 | netG.load_dict(weights) 74 | 75 | count = 1 76 | scheduler = paddle.optimizer.lr.StepDecay(learning_rate=args.lr, step_size=args.lr_decay_iters, gamma=args.gamma, verbose=False) 77 | G_optimizer = paddle.optimizer.Adam(scheduler, parameters=netG.parameters(), weight_decay=0.0)#betas=(0.5, 0.9)) 78 | 79 | criterion = LossWithGAN_STE(lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0) 80 | print('OK!') 81 | num_epochs = args.num_epochs 82 | mse = nn.MSELoss() 83 | best_psnr = 0 84 | iters = 0 85 | for epoch in range(1, num_epochs + 1): 86 | netG.train() 87 | 88 | for k, (imgs, gt, masks, path) in enumerate(Erase_data): 89 | iters += 1 90 | #print(imgs.max(), gt.max(), masks.max()) 91 | 92 | x_o1, x_o2, x_o3, fake_images, mm = netG(imgs) 93 | G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, epoch) 94 | G_loss = G_loss.sum() 95 | 
for epoch in range(1, num_epochs + 1):
    netG.train()

    for k, (imgs, gt, masks, path) in enumerate(Erase_data):
        iters += 1
        # print(imgs.max(), gt.max(), masks.max())

        x_o1, x_o2, x_o3, fake_images, mm = netG(imgs)
        G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, epoch)
        G_loss = G_loss.sum()
        G_optimizer.clear_grad()
        G_loss.backward()
        G_optimizer.step()
        scheduler.step()
        if iters % 100 == 0:
            logging.info('[{}/{}] Generator Loss of epoch {} is {:.5f}, {}, {}, Lr:{}'.format(iters, len(Erase_data) * num_epochs, epoch, G_loss.item(), args.net, args.mask_dir, G_optimizer.get_lr()))
        count += 1

        if iters % 5000 == 0:
            netG.eval()
            val_psnr = 0
            for index, (imgs, gt, path) in enumerate(Erase_val_data):
                print(index, imgs.shape, gt.shape, path)
                _, _, h, w = imgs.shape
                rh, rw = h, w
                step = 512
                # pad images smaller than one 512x512 patch up to patch size
                pad_h = step - h if h < step else 0
                pad_w = step - w if w < step else 0
                m = nn.Pad2D((0, pad_w, 0, pad_h))
                imgs = m(imgs)
                _, _, h, w = imgs.shape
                res = paddle.zeros_like(imgs)
                for i in range(0, h, step):
                    for j in range(0, w, step):
                        if h - i < step:
                            i = h - step
                        if w - j < step:
                            j = w - step
                        clip = imgs[:, :, i:i + step, j:j + step]
                        clip = clip.cuda()
                        with paddle.no_grad():
                            _, _, _, g_images_clip, mm = netG(clip)
                        g_images_clip = g_images_clip.cpu()
                        mm = mm.cpu()
                        clip = clip.cpu()
                        # mm == 1 keeps the original pixel, mm == 0 takes the generated one
                        mm = paddle.where(F.sigmoid(mm) > 0.5, paddle.zeros_like(mm), paddle.ones_like(mm))
                        g_image_clip_with_mask = clip * mm + g_images_clip * (1 - mm)
                        res[:, :, i:i + step, j:j + step] = g_image_clip_with_mask
                res = res[:, :, :rh, :rw]
                output = utils.pd_tensor2img(res)
                target = utils.pd_tensor2img(gt)
                del res
                del gt
                psnr = utils.compute_psnr(target, output)
                del target
                del output
                val_psnr += psnr
                logging.info('index:{} psnr: {}'.format(index, psnr))
            ave_psnr = val_psnr / (index + 1)
            paddle.save(netG.state_dict(), args.modelsSavePath + '/STE_{}_{:.4f}.pdparams'.format(epoch, ave_psnr))
            if ave_psnr > best_psnr:
                best_psnr = ave_psnr
                paddle.save(netG.state_dict(), args.modelsSavePath + '/STE_best.pdparams')
            logging.info('epoch: {}, ave_psnr: {}, best_psnr: {}'.format(epoch, ave_psnr, best_psnr))
            netG.train()  # switch back to training mode after validation
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python3 train.py --batchSize 4 \
                                        --dataRoot './dataset/task2/dehw_train_dataset/images' \
                                        --net 'idr' \
                                        --lr 1e-4 \
                                        --modelsSavePath 'ckpts_str_m331_25_idr' \
                                        --logPath 'logs' \
                                        --mask_dir 'mask_331_25'
--------------------------------------------------------------------------------
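For clarity, a small self-contained sketch (NumPy; the names are illustrative) of the mask compositing used in the validation loop of train.py above: the predicted mask logits are squashed with a sigmoid and thresholded at 0.5; where the mask fires the generated pixel is used, elsewhere the original pixel is kept:

import numpy as np

def composite(clip, pred, mask_logits, thresh=0.5):
    """Keep original pixels below the mask threshold, use the prediction elsewhere."""
    prob = 1.0 / (1.0 + np.exp(-mask_logits))  # sigmoid
    keep = np.where(prob > thresh, 0.0, 1.0)   # 1 = keep original, 0 = replace
    return clip * keep + pred * (1.0 - keep)

clip = np.ones((3, 8, 8))             # "original" patch
pred = np.zeros((3, 8, 8))            # "generated" patch
logits = np.full((1, 8, 8), -5.0)
logits[:, 2:6, 2:6] = 5.0             # mask fires in the center
out = composite(clip, pred, logits)
print(out[0, 0, 0], out[0, 4, 4])     # 1.0 (kept) and 0.0 (replaced)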
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import logging
import math
import os
import sys

import numpy as np
import paddle
from paddle.distributed import ParallelEnv
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


def AdjustLearningRate(optimizer, lr):
    # note: uses the PyTorch-style param_groups API; kept for reference but
    # unused with the Paddle optimizers in this repo
    for param_group in optimizer.param_groups:
        print('param_group', param_group['lr'])
        param_group['lr'] = lr


def compute_psnr(im1, im2):
    p = psnr(im1, im2)
    return p


def compute_ssim(im1, im2):
    isRGB = len(im1.shape) == 3 and im1.shape[-1] == 3
    s = ssim(im1, im2, K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5,
             use_sample_covariance=False, multichannel=isRGB)
    return s


def pd_tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    img = tensor.squeeze().cpu().numpy()
    img = img.clip(min_max[0], min_max[1])
    img = (img - min_max[0]) / (min_max[1] - min_max[0])
    if out_type == np.uint8:
        # scale to [0, 255], move channels last and convert RGB -> BGR for cv2.imwrite
        img = img * 255.0
        img = np.transpose(img, (1, 2, 0))
        img = img.round()
        img = img[:, :, ::-1]
    return img.astype(out_type)
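# Illustrative round trip for pd_tensor2img above (not part of the original
# file): a CHW float tensor in [0, 1] becomes an HWC uint8 image with the
# channels reversed from RGB to BGR, the layout cv2.imwrite expects:
#
#     t = paddle.to_tensor(np.ones((1, 3, 4, 4), dtype='float32'))
#     img = pd_tensor2img(t)
#     # img.shape == (4, 4, 3), img.dtype == uint8, all values 255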
# module-level registry so repeated setup_logger calls don't attach duplicate
# handlers (the original re-created this list inside the function, so the
# early-return check below could never fire)
logger_initialized = []


def setup_logger(output=None, name="ppgan"):
    """
    Initialize the ppgan logger and set its verbosity level to "INFO".

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    logger.setLevel(logging.INFO)
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%m/%d %H:%M:%S")
    # stdout logging: master only
    local_rank = ParallelEnv().local_rank
    if local_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if local_rank > 0:
            filename = filename + ".rank{}".format(local_rank)

        # make dir if path does not exist
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
    logger_initialized.append(name)
    return logger


def load_pretrained_model(model, pretrained_model):
    if pretrained_model is not None:
        print('Loading pretrained model from {}'.format(pretrained_model))

        if os.path.exists(pretrained_model):
            para_state_dict = paddle.load(pretrained_model)
            model_state_dict = model.state_dict()
            keys = model_state_dict.keys()
            num_params_loaded = 0
            for k in keys:
                if k not in para_state_dict:
                    print('{} is not in pretrained model'.format(k))
                elif list(para_state_dict[k].shape) != list(model_state_dict[k].shape):
                    print("[SKIP] shape of pretrained params {} doesn't match. (Pretrained: {}, Actual: {})"
                          .format(k, para_state_dict[k].shape, model_state_dict[k].shape))
                else:
                    model_state_dict[k] = para_state_dict[k]
                    num_params_loaded += 1
            model.set_dict(model_state_dict)
            print("There are {}/{} variables loaded into {}."
                  .format(num_params_loaded, len(model_state_dict), model.__class__.__name__))
        else:
            raise ValueError("The pretrained model file is not found: {}".format(pretrained_model))
    else:
        print('No pretrained model to load, {} will be trained from scratch.'
              .format(model.__class__.__name__))
--------------------------------------------------------------------------------
/zip.sh:
--------------------------------------------------------------------------------
cd res/WithMaskOutput
zip -r result.zip *
mv result.zip ../../
cd ../..
--------------------------------------------------------------------------------
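For completeness, a standard-library Python equivalent of zip.sh above (illustrative; the shell script remains the intended entry point):

import os
import zipfile

src = 'res/WithMaskOutput'
with zipfile.ZipFile('result.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
    for root, _, files in os.walk(src):
        for name in files:
            path = os.path.join(root, name)
            # store paths relative to src, matching `zip -r result.zip *` run inside it
            zf.write(path, arcname=os.path.relpath(path, src))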