├── README.md ├── imgs ├── fapnet.png └── results.png ├── lib ├── Res2Net_v1b.py ├── __pycache__ │ ├── Res2Net_v1b.cpython-38.pyc │ └── model.cpython-38.pyc └── model.py ├── requirement.txt ├── test.py ├── train.py └── utils ├── dataloader.py ├── eva_funcs.py └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 | # Feature Aggregation and Propagation Network for Camouflaged Object Detection 2 | 3 | > **Authors:** 4 | > [*Tao Zhou*](https://taozh2017.github.io), 5 | > [*Yi Zhou*](https://cse.seu.edu.cn/2021/0303/c23024a362239/page.htm), 6 | > [*Chen Gong*](https://gcatnjust.github.io/ChenGong/index.html), 7 | > [*Jian Yang*](https://scholar.google.com.hk/citations?user=6CIDtZQAAAAJ&hl=zh-CN), 8 | > and [*Yu Zhang*](https://scholar.google.com.hk/citations?user=oDrTEi0AAAAJ&hl=zh-CN). 9 | 10 | 11 | 12 | ## 1. Preface 13 | 14 | - This repository provides code for "_**Feature Aggregation and Propagation Network for Camouflaged Object Detection**_" IEEE TIP 2022. [Paper](https://ieeexplore.ieee.org/abstract/document/9940173/) [![Arxiv Page](https://img.shields.io/badge/Arxiv-2105.12555-red?style=flat-square)](https://arxiv.org/pdf/2212.00990.pdf) 15 | 16 | 17 | ## 2. Overview 18 | 19 | ### 2.1. Introduction 20 | 21 | Camouflaged object detection (COD) aims to detect/segment camouflaged objects embedded in the environment, which has attracted increasing attention over the past decades. 22 | Although several COD methods have been developed, they still suffer from unsatisfactory performance due to the intrinsic similarities between the foreground objects and background surroundings. In this paper, we propose a novel Feature Aggregation 23 | and Propagation Network (FAP-Net) for camouflaged object detection. Specifically, we propose a Boundary Guidance Module (BGM) to explicitly model the boundary characteristic, which can provide boundary-enhanced features to boost the COD performance. To capture the scale variations of the camouflaged objects, we propose a Multi-scale Feature Aggregation Module (MFAM) to characterize the multi-scale information from 24 | each layer and obtain the aggregated feature representations. Furthermore, we propose a Cross-level Fusion and Propagation Module (CFPM). In the CFPM, the feature fusion part can effectively integrate the features from adjacent layers to exploit the cross-level correlations, and the feature propagation part can transmit valuable context information from the encoder to the decoder network via a gate unit. Finally, we formulate a unified 25 | and end-to-end trainable framework where cross-level features can be effectively fused and propagated for capturing rich context information. Extensive experiments on three benchmark camouflaged datasets demonstrate that our FAP-Net outperforms other state-of-the-art COD models. Moreover, our model can be extended to the polyp segmentation task, and the comparison 26 | results further validate the effectiveness of the proposed model in segmenting polyps. 27 | 28 | 29 | ### 2.2. Framework Overview 30 | 31 |

32 | <img src="./imgs/fapnet.png"/>
33 | 34 | Figure 1: The overall architecture of the proposed FAP-Net, consisting of three key components, i.e., boundary guidance module, multi-scale feature aggregation module, and cross-level fusion and propagation module. 35 | 36 |
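For readers navigating the implementation, the sketch below gives a rough, non-authoritative mapping from the three modules in Figure 1 to the classes and attributes in `lib/model.py` (names taken from `FAPNet.__init__`):

```python
# Approximate correspondence between Figure 1 and lib/model.py (for orientation only):
#   Boundary Guidance Module (BGM)                    -> FAPNet.edge_conv0 ... edge_conv3 (edge_guidance, edge_out)
#   Multi-scale Feature Aggregation Module (MFAM)     -> classes MFAM0 / MFAM (FAPNet.rf1 ... rf5)
#   Cross-level Fusion and Propagation Module (CFPM)  -> FeaFusion (cfusion2-5) + FeaProp (cgate2-5)
```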

37 | 38 | ### 2.3. Qualitative Results 39 | 40 |

41 | <img src="./imgs/results.png"/>
42 | 43 | Figure 2: Qualitative Results. 44 | 45 |

46 | 47 | ## 3. Proposed Method 48 | 49 | ### 3.1. Training/Testing 50 | 51 | The training and testing experiments are conducted using [PyTorch](https://github.com/pytorch/pytorch) on a single NVIDIA Tesla P40 GPU with 24 GB of memory. 52 | 53 | 1. Configuring your environment (Prerequisites): 54 | 55 | + Install the necessary packages: `pip install -r requirement.txt`. 56 | 57 | 1. Downloading necessary data: 58 | 59 | 60 | + Download the training dataset and move it into `./data/`, 61 | which can be found at [Google Drive](https://drive.google.com/file/d/1Kifp7I0n9dlWKXXNIbN7kgyokoRY4Yz7/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1uyQz0b_r_5yCee0orSw7EA) (extraction code: fapn). 62 | 63 | + Download the testing dataset and move it into `./data/`, 64 | which can be found at [Google Drive](https://drive.google.com/file/d/1SLRB5Wg1Hdy7CQ74s3mTQ3ChhjFRSFdZ/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1F3tVEWYzKYp5NBv3cjiaAg) (extraction code: fapn). 65 | 66 | + Download our pretrained weights and move them to `./checkpoints/FAPNet.pth`, 67 | which can be found at [Google Drive](https://drive.google.com/file/d/1qjb70ZGwExei21x6uMQbUAwZ0Ts4Z-xk/view?usp=share_link) or [Baidu Drive](https://pan.baidu.com/s/1BeRx81XNKq_jA7LHut1VZg) (extraction code: fapn). 68 | 69 | + Download the Res2Net weights and move them to `./lib/res2net50_v1b_26w_4s-3cf99910.pth`, 70 | which can be found at [Google Drive](https://drive.google.com/file/d/1_1N-cx1UpRQo7Ybsjno1PAg4KE1T9e5J/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1gDNNh7Cad3jv2l7i4_V33Q) (extraction code: fapn). 71 | 72 | 1. Training Configuration: 73 | 74 | + After you have downloaded the training dataset, just run `train.py` to train our model. 75 | 76 | 77 | 1. Testing Configuration: 78 | 79 | + After you have downloaded the pretrained weights and the testing dataset, just run `test.py` to generate the final prediction maps. 80 | 81 | + You can also download the prediction maps for 'CHAMELEON', 'CAMO', and 'COD10K' from [Google Drive](https://drive.google.com/file/d/1O5gDTBasHWuwPv4hxd04Nt08y16_Y578/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1nltTLWnU3YZpCQO5LewAIw) (extraction code: fapn). 82 | 83 | + You can also download the prediction maps for NC4K from [Google Drive](https://drive.google.com/file/d/139CoLtoQp_9n3T8WXc2G8xI61MJj_bpj/view?usp=share_link) or [Baidu Drive](https://pan.baidu.com/s/1iGcAAnFjbav-HLoG7Dc_FA) (extraction code: fapn). 84 | 85 | ### 3.2. Evaluating your trained model 86 | 87 | The evaluation is implemented in MATLAB ([link](https://github.com/DengPingFan/CODToolbox)); 88 | please follow the instructions in `./eval/main.m` and run it to generate the evaluation results. 89 | 90 | 91 | 92 | ## 4. Citation 93 | 94 | Please cite our paper if you find the work useful, thanks!
95 | 96 | @article{zhou2022feature, 97 | title={Feature Aggregation and Propagation Network for Camouflaged Object Detection}, 98 | author={Zhou, Tao and Zhou, Yi and Gong, Chen and Yang, Jian and Zhang, Yu}, 99 | journal={IEEE Transactions on Image Processing}, 100 | volume={31}, 101 | pages={7036--7047}, 102 | year={2022}, 103 | publisher={IEEE} 104 | } 105 | 106 | 107 | 108 | **[⬆ back to top](#1-preface)** 109 | -------------------------------------------------------------------------------- /imgs/fapnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taozh2017/FAPNet/1c791a1bf210ca88991b52300b420ec0706a553a/imgs/fapnet.png -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taozh2017/FAPNet/1c791a1bf210ca88991b52300b420ec0706a553a/imgs/results.png -------------------------------------------------------------------------------- /lib/Res2Net_v1b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 8 | 9 | model_urls = { 10 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 11 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 12 | } 13 | 14 | 15 | class Bottle2neck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 19 | """ Constructor 20 | Args: 21 | inplanes: input channel dimensionality 22 | planes: output channel dimensionality 23 | stride: conv stride. Replaces pooling layer. 24 | downsample: None when stride = 1 25 | baseWidth: basic width of conv3x3 26 | scale: number of scale. 27 | type: 'normal': normal set. 'stage': first block of a new stage. 
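(Note: this argument is named `stype` in the signature; 'stage' blocks additionally average-pool the last feature split before concatenation.)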
28 | """ 29 | super(Bottle2neck, self).__init__() 30 | 31 | width = int(math.floor(planes * (baseWidth / 64.0))) 32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(width * scale) 34 | 35 | if scale == 1: 36 | self.nums = 1 37 | else: 38 | self.nums = scale - 1 39 | if stype == 'stage': 40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 41 | convs = [] 42 | bns = [] 43 | for i in range(self.nums): 44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 45 | bns.append(nn.BatchNorm2d(width)) 46 | self.convs = nn.ModuleList(convs) 47 | self.bns = nn.ModuleList(bns) 48 | 49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stype = stype 55 | self.scale = scale 56 | self.width = width 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | spx = torch.split(out, self.width, 1) 66 | for i in range(self.nums): 67 | if i == 0 or self.stype == 'stage': 68 | sp = spx[i] 69 | else: 70 | sp = sp + spx[i] 71 | sp = self.convs[i](sp) 72 | sp = self.relu(self.bns[i](sp)) 73 | if i == 0: 74 | out = sp 75 | else: 76 | out = torch.cat((out, sp), 1) 77 | if self.scale != 1 and self.stype == 'normal': 78 | out = torch.cat((out, spx[self.nums]), 1) 79 | elif self.scale != 1 and self.stype == 'stage': 80 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, 
bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b lib. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 174 | return model 175 | 176 | 177 | def res2net101_v1b(pretrained=False, **kwargs): 178 | """Constructs a Res2Net-50_v1b_26w_4s lib. 179 | Args: 180 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 181 | """ 182 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 185 | return model 186 | 187 | 188 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 189 | """Constructs a Res2Net-50_v1b_26w_4s lib. 190 | Args: 191 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 192 | """ 193 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 194 | if pretrained: 195 | model_state = torch.load('./lib/res2net50_v1b_26w_4s-3cf99910.pth') 196 | model.load_state_dict(model_state) 197 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 198 | return model 199 | 200 | 201 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 202 | """Constructs a Res2Net-50_v1b_26w_4s lib. 203 | Args: 204 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 205 | """ 206 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 207 | if pretrained: 208 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 209 | return model 210 | 211 | 212 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 213 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
214 | Args: 215 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 216 | """ 217 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 220 | return model 221 | 222 | 223 | if __name__ == '__main__': 224 | images = torch.rand(1, 3, 352, 352).cuda(0) 225 | model = res2net50_v1b_26w_4s(pretrained=True) 226 | model = model.cuda(0) 227 | print(model(images).size()) 228 | -------------------------------------------------------------------------------- /lib/__pycache__/Res2Net_v1b.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taozh2017/FAPNet/1c791a1bf210ca88991b52300b420ec0706a553a/lib/__pycache__/Res2Net_v1b.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taozh2017/FAPNet/1c791a1bf210ca88991b52300b420ec0706a553a/lib/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /lib/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | 6 | from lib.Res2Net_v1b import res2net50_v1b_26w_4s 7 | 8 | 9 | class BasicConv2d(nn.Module): 10 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_planes, out_planes, 13 | kernel_size=kernel_size, stride=stride, 14 | padding=padding, dilation=dilation, bias=False) 15 | self.bn = nn.BatchNorm2d(out_planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = self.bn(x) 21 | return x 22 | 23 | 24 | ################################################################### 25 | class MFAM0(nn.Module): 26 | 27 | def __init__(self, in_channels, out_channels): 28 | super(MFAM0, self).__init__() 29 | 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | 33 | self.conv_1_1 = BasicConv2d(in_channels, out_channels, 1) 34 | self.conv_1_2 = BasicConv2d(in_channels, out_channels, 1) 35 | self.conv_1_3 = BasicConv2d(in_channels, out_channels, 1) 36 | self.conv_1_4 = BasicConv2d(in_channels, out_channels, 1) 37 | self.conv_1_5 = BasicConv2d(out_channels, out_channels, 3, stride=1, padding=1) 38 | 39 | self.conv_3_1 = nn.Conv2d(out_channels, out_channels , kernel_size=3, stride=1, padding=1) 40 | self.conv_3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 41 | 42 | self.conv_5_1 = nn.Conv2d(out_channels, out_channels , kernel_size=5, stride=1, padding=2) 43 | self.conv_5_2 = nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=1, padding=2) 44 | 45 | 46 | def forward(self, x): 47 | 48 | ###+ 49 | x1 = x # self.conv_1_1(x) 50 | x2 = x # self.conv_1_2(x) 51 | x3 = x # self.conv_1_3(x) 52 | 53 | x_3_1 = self.relu(self.conv_3_1(x2)) ## (BS, 32, ***, ***) 54 | x_5_1 = self.relu(self.conv_5_1(x3)) ## (BS, 32, ***, ***) 55 | 56 | x_3_2 = self.relu(self.conv_3_2(x_3_1 + x_5_1)) ## (BS, 64, ***, ***) 57 | x_5_2 = self.relu(self.conv_5_2(x_5_1 + x_3_1)) ## (BS, 64, ***, ***) 58 | 59 | x_mul = torch. 
mul(x_3_2, x_5_2) 60 | out = self.relu(x1 + self.conv_1_5(x_mul + x_3_1 + x_5_1)) 61 | 62 | return out 63 | 64 | class MFAM(nn.Module): 65 | 66 | def __init__(self, in_channels, out_channels): 67 | super(MFAM, self).__init__() 68 | 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | self.conv_1_1 = BasicConv2d(in_channels, out_channels, 1) 72 | self.conv_1_2 = BasicConv2d(in_channels, out_channels, 1) 73 | self.conv_1_3 = BasicConv2d(in_channels, out_channels, 1) 74 | self.conv_1_4 = BasicConv2d(in_channels, out_channels, 1) 75 | self.conv_1_5 = BasicConv2d(out_channels, out_channels, 3, stride=1, padding=1) 76 | 77 | self.conv_3_1 = nn.Conv2d(out_channels, out_channels , kernel_size=3, stride=1, padding=1) 78 | self.conv_3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 79 | 80 | self.conv_5_1 = nn.Conv2d(out_channels, out_channels , kernel_size=5, stride=1, padding=2) 81 | self.conv_5_2 = nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=1, padding=2) 82 | 83 | 84 | def forward(self, x): 85 | 86 | ###+ 87 | x1 = self.conv_1_1(x) 88 | x2 = self.conv_1_2(x) 89 | x3 = self.conv_1_3(x) 90 | 91 | x_3_1 = self.relu(self.conv_3_1(x2)) ## (BS, 32, ***, ***) 92 | x_5_1 = self.relu(self.conv_5_1(x3)) ## (BS, 32, ***, ***) 93 | 94 | x_3_2 = self.relu(self.conv_3_2(x_3_1 + x_5_1)) ## (BS, 64, ***, ***) 95 | x_5_2 = self.relu(self.conv_5_2(x_5_1 + x_3_1)) ## (BS, 64, ***, ***) 96 | 97 | x_mul = torch.mul(x_3_2, x_5_2) 98 | 99 | out = self.relu(x1 + self.conv_1_5(x_mul + x_3_1 + x_5_1)) 100 | 101 | return out 102 | 103 | 104 | ################################################################### 105 | class FeaFusion(nn.Module): 106 | def __init__(self, channels): 107 | self.init__ = super(FeaFusion, self).__init__() 108 | 109 | self.relu = nn.ReLU() 110 | self.layer1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) 111 | 112 | self.layer2_1 = nn.Conv2d(channels, channels //4, kernel_size=3, stride=1, padding=1) 113 | self.layer2_2 = nn.Conv2d(channels, channels //4, kernel_size=3, stride=1, padding=1) 114 | 115 | self.layer_fu = nn.Conv2d(channels//4, channels, kernel_size=3, stride=1, padding=1) 116 | 117 | def forward(self, x1, x2): 118 | 119 | ### 120 | wweight = nn.Sigmoid()(self.layer1(x1+x2)) 121 | 122 | ### 123 | xw_resid_1 = x1+ x1.mul(wweight) 124 | xw_resid_2 = x2+ x2.mul(wweight) 125 | 126 | ### 127 | x1_2 = self.layer2_1(xw_resid_1) 128 | x2_2 = self.layer2_2(xw_resid_2) 129 | 130 | out = self.relu(self.layer_fu(x1_2 + x2_2)) 131 | 132 | return out 133 | 134 | ################################################################### 135 | class FeaProp(nn.Module): 136 | def __init__(self, in_planes): 137 | self.init__ = super(FeaProp, self).__init__() 138 | 139 | 140 | act_fn = nn.ReLU(inplace=True) 141 | 142 | self.layer_1 = nn.Sequential(nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(in_planes),act_fn) 143 | self.layer_2 = nn.Sequential(nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(in_planes),act_fn) 144 | 145 | self.gate_1 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True) 146 | self.gate_2 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True) 147 | 148 | self.softmax = nn.Softmax(dim=1) 149 | 150 | 151 | def forward(self, x10, x20): 152 | 153 | ### 154 | x1 = self.layer_1(x10) 155 | x2 = self.layer_2(x20) 156 | 157 | cat_fea = torch.cat([x1,x2], dim=1) 158 | 159 | ### 160 | att_vec_1 = self.gate_1(cat_fea) 161 | att_vec_2 = self.gate_2(cat_fea) 162 | 163 | 
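# The two 1-channel gate maps are stacked and softmax-normalised over the channel dimension,
# giving per-pixel weights that sum to one; the encoder feature (x1) and the decoder feature (x2)
# are then blended with these weights. This is the gate unit of the CFPM described in the paper.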
att_vec_cat = torch.cat([att_vec_1, att_vec_2], dim=1) 164 | att_vec_soft = self.softmax(att_vec_cat) 165 | 166 | att_soft_1, att_soft_2 = att_vec_soft[:, 0:1, :, :], att_vec_soft[:, 1:2, :, :] 167 | x_fusion = x1 * att_soft_1 + x2 * att_soft_2 168 | 169 | return x_fusion 170 | 171 | ################################################################### 172 | 173 | class FAPNet(nn.Module): 174 | # resnet based encoder decoder 175 | def __init__(self, channel=32, opt=None): 176 | super(FAPNet, self).__init__() 177 | 178 | 179 | act_fn = nn.ReLU(inplace=True) 180 | self.nf = channel 181 | 182 | self.resnet = res2net50_v1b_26w_4s(pretrained=True) 183 | self.downSample = nn.MaxPool2d(2, stride=2) 184 | 185 | ## 186 | self.rf1 = MFAM0(64, self.nf) 187 | self.rf2 = MFAM(256, self.nf) 188 | self.rf3 = MFAM(512, self.nf) 189 | self.rf4 = MFAM(1024, self.nf) 190 | self.rf5 = MFAM(2048, self.nf) 191 | 192 | 193 | ## 194 | self.cfusion2 = FeaFusion(self.nf) 195 | self.cfusion3 = FeaFusion(self.nf) 196 | self.cfusion4 = FeaFusion(self.nf) 197 | self.cfusion5 = FeaFusion(self.nf) 198 | 199 | ## 200 | self.cgate5 = FeaProp(self.nf) 201 | self.cgate4 = FeaProp(self.nf) 202 | self.cgate3 = FeaProp(self.nf) 203 | self.cgate2 = FeaProp(self.nf) 204 | 205 | 206 | self.de_5 = nn.Sequential(nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 207 | self.de_4 = nn.Sequential(nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 208 | self.de_3 = nn.Sequential(nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 209 | self.de_2 = nn.Sequential(nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 210 | 211 | 212 | 213 | ## 214 | self.edge_conv0 = nn.Sequential(nn.Conv2d(64, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 215 | self.edge_conv1 = nn.Sequential(nn.Conv2d(256, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 216 | self.edge_conv2 = nn.Sequential(nn.Conv2d(self.nf, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 217 | self.edge_conv3 = BasicConv2d(self.nf, 1, kernel_size=3, padding=1) 218 | 219 | 220 | self.fu_5 = nn.Sequential(nn.Conv2d(self.nf*2, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 221 | self.fu_4 = nn.Sequential(nn.Conv2d(self.nf*2, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 222 | self.fu_3 = nn.Sequential(nn.Conv2d(self.nf*2, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 223 | self.fu_2 = nn.Sequential(nn.Conv2d(self.nf*2, self.nf, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(self.nf),act_fn) 224 | 225 | 226 | ## 227 | self.layer_out5 = nn.Sequential(nn.Conv2d(self.nf, 1, kernel_size=3, stride=1, padding=1)) 228 | self.layer_out4 = nn.Sequential(nn.Conv2d(self.nf, 1, kernel_size=3, stride=1, padding=1)) 229 | self.layer_out3 = nn.Sequential(nn.Conv2d(self.nf, 1, kernel_size=3, stride=1, padding=1)) 230 | self.layer_out2 = nn.Sequential(nn.Conv2d(self.nf, 1, kernel_size=3, stride=1, padding=1)) 231 | 232 | 233 | 234 | ## 235 | self.up_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 236 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 237 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 238 | self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', 
align_corners=True) 239 | 240 | 241 | 242 | def forward(self, xx): 243 | 244 | # ---- feature abstraction ----- 245 | x = self.resnet.conv1(xx) 246 | x = self.resnet.bn1(x) 247 | x = self.resnet.relu(x) 248 | 249 | # - low-level features 250 | x1 = self.resnet.maxpool(x) # (BS, 64, 88, 88) 251 | x2 = self.resnet.layer1(x1) # (BS, 256, 88, 88) 252 | x3 = self.resnet.layer2(x2) # (BS, 512, 44, 44) 253 | x4 = self.resnet.layer3(x3) # (BS, 1024, 22, 22) 254 | x5 = self.resnet.layer4(x4) # (BS, 2048, 11, 11) 255 | 256 | ## -------------------------------------- ## 257 | xf1 = self.rf1(x1) 258 | xf2 = self.rf2(x2) 259 | xf3 = self.rf3(x3) 260 | xf4 = self.rf4(x4) 261 | xf5 = self.rf5(x5) 262 | 263 | 264 | ## edge 265 | x21 = self.edge_conv1(x2) 266 | edge_guidance = self.edge_conv2(self.edge_conv0(x1) + x21) 267 | edge_out = self.up_4(self.edge_conv3(edge_guidance)) 268 | 269 | 270 | ### layer 5 271 | en_fusion5 = self.cfusion5(self.up_2(xf5), xf4) ## (BS, 64, 22, 22) 272 | out_gate_fu5 = self.fu_5(torch.cat((en_fusion5, F.interpolate(edge_guidance, scale_factor=1/4, mode='bilinear')),dim=1)) 273 | out5 = self.up_16(self.layer_out5(out_gate_fu5)) 274 | 275 | 276 | de_feature4 = self.de_4(self.up_2(en_fusion5)) ## (BS, 64, 22, 22) 277 | en_fusion4 = self.cfusion4(self.up_2(xf4), xf3) ## (BS, 64, 44, 44) 278 | out_gate4 = self.cgate4(en_fusion4, de_feature4) ## (BS, 64, 44, 44) 279 | out_gate_fu4 = self.fu_4(torch.cat((out_gate4, F.interpolate(edge_guidance, scale_factor=1/2, mode='bilinear')),dim=1)) 280 | out4 = self.up_8(self.layer_out4(out_gate_fu4)) 281 | 282 | 283 | de_feature3 = self.de_3(self.up_2(out_gate4)) ## (BS, 64, 88, 88) 284 | en_fusion3 = self.cfusion3(self.up_2(xf3), xf2) ## (BS, 64, 88, 88) 285 | out_gate3 = self.cgate3(en_fusion3, de_feature3) ## (BS, 64, 88, 88) 286 | out_gate_fu3 = self.fu_3(torch.cat((out_gate3, edge_guidance),dim=1)) 287 | out3 = self.up_4(self.layer_out3(out_gate_fu3)) 288 | 289 | 290 | de_feature2 = self.de_2(self.up_2(out_gate3)) ## (BS, 64, 176, 176) 291 | en_fusion2 = self.cfusion2(self.up_2(xf2), self.up_2(xf1)) ## (BS, 64, 176, 176) 292 | out_gate2 = self.cgate2(en_fusion2, de_feature2) ## (BS, 64, 176, 176) 293 | out_gate_fu2 = self.fu_2(torch.cat((out_gate2, self.up_2(edge_guidance)), dim=1)) 294 | out2 = self.up_2(self.layer_out2(out_gate_fu2)) 295 | 296 | 297 | # ---- output ---- 298 | return out5, out4, out3, out2, edge_out 299 | 300 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | python==3.6.6 3 | opencv-python==4.5.2 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | import argparse 6 | 7 | 8 | from lib.model import FAPNet 9 | from utils.dataloader import tt_dataset 10 | from utils.eva_funcs import eval_Smeasure,eval_mae,numpy2tensor 11 | import scipy.io as scio 12 | import cv2 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--testsize', type=int, default=352, help='the snapshot input size') 17 | parser.add_argument('--model_path', type=str, default='./checkpoints/') 18 | parser.add_argument('--save_path', type=str, default='./results/') 19 | 20 | opt = parser.parse_args() 21 | model = FAPNet(channel=64).cuda() 22 | 23 | 24 | cur_model_path = 
opt.model_path+'FAPNet.pth' 25 | model.load_state_dict(torch.load(cur_model_path)) 26 | model.eval() 27 | 28 | 29 | ################################################################ 30 | 31 | for dataset in ['CHAMELEON', 'CAMO', 'COD10K']: 32 | 33 | save_path = opt.save_path + dataset + '/' 34 | os.makedirs(save_path, exist_ok=True) 35 | 36 | 37 | test_loader = tt_dataset('/test/CamouflagedObjectDection/Dataset/TestDataset/{}/Imgs/'.format(dataset), 38 | '/test/CamouflagedObjectDection/Dataset/TestDataset/{}/GT/'.format(dataset), opt.testsize) 39 | 40 | 41 | 42 | for iteration in range(test_loader.size): 43 | 44 | 45 | image, gt, name = test_loader.load_data() 46 | 47 | gt = np.asarray(gt, np.float32) 48 | gt /= (gt.max() + 1e-8) 49 | image = image.cuda() 50 | 51 | 52 | _,_,_, cam,_ = model(image) 53 | 54 | res = F.upsample(cam, size=gt.shape, mode='bilinear', align_corners=False) 55 | res = res.sigmoid().data.cpu().numpy().squeeze() 56 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 57 | 58 | 59 | ################################################################ 60 | cv2.imwrite(save_path+name, res*255) 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import argparse 7 | import logging 8 | 9 | from lib.model import FAPNet 10 | from utils.dataloader import get_loader,test_dataset 11 | from utils.trainer import adjust_lr 12 | from datetime import datetime 13 | 14 | 15 | best_mae = 1 16 | best_epoch = 0 17 | 18 | 19 | def structure_loss(pred, mask): 20 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 21 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 22 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 23 | 24 | pred = torch.sigmoid(pred) 25 | inter = ((pred*mask)*weit).sum(dim=(2,3)) 26 | union = ((pred+mask)*weit).sum(dim=(2,3)) 27 | wiou = 1-(inter+1)/(union-inter+1) 28 | return (wbce+wiou).mean() 29 | 30 | 31 | def train(train_loader, model, optimizer, epoch, opt, loss_func, total_step): 32 | """ 33 | Training iteration 34 | :param train_loader: 35 | :param model: 36 | :param optimizer: 37 | :param epoch: 38 | :param opt: 39 | :param loss_func: 40 | :param total_step: 41 | :return: 42 | """ 43 | model.train() 44 | 45 | size_rates = [0.75, 1, 1.25] 46 | 47 | for step, data_pack in enumerate(train_loader): 48 | for rate in size_rates: 49 | 50 | optimizer.zero_grad() 51 | images, gts, egs = data_pack 52 | 53 | 54 | images = images.cuda() 55 | gts = gts.cuda() 56 | egs = egs.cuda() 57 | 58 | 59 | # ---- rescale ---- 60 | trainsize = int(round(opt.trainsize*rate/32)*32) 61 | if rate != 1: 62 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 63 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 64 | egs = F.upsample(egs, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 65 | 66 | sal1,sal2,sal3,sal4,edge_out = model(images) 67 | 68 | loss_edge = loss_func(edge_out, egs) 69 | 70 | loss1 = structure_loss(sal1, gts) 71 | loss2 = structure_loss(sal2, gts) 72 | loss3 = structure_loss(sal3, gts) 73 | loss4 = structure_loss(sal4, gts) 74 | 75 | loss_obj = loss1 + loss2 + loss3 + loss4 76 | 77 | 78 | loss_total = loss_obj + loss_edge 79 | 80 | 81 | loss_total.backward() 82 | optimizer.step() 
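# Note: opt.clip (gradient clipping margin) is parsed in __main__ but never applied here;
# if training becomes unstable, clip_gradient(optimizer, opt.clip) from utils.trainer
# (not imported in this file by default) could be applied just before optimizer.step().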
83 | 84 | if step % 20 == 0 or step == total_step: 85 | print('[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_obj: {:.4f} Loss_edge: {:0.4f} Loss_all: {:0.4f}]'. 86 | format(datetime.now(), epoch, opt.epoch, step, total_step, loss_obj.data,loss_edge.data, loss_total.data)) 87 | 88 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss_obj: {:.4f} Loss_edge: {:0.4f} Loss_all: {:0.4f}'. 89 | format( epoch, opt.epoch, step, total_step, loss_obj.data,loss_edge.data, loss_total.data)) 90 | 91 | 92 | if (epoch) % opt.save_epoch == 0: 93 | torch.save(model.state_dict(), save_path + 'FAPNet_%d.pth' % (epoch)) 94 | 95 | 96 | def test(test_loader,model,epoch,save_path): 97 | 98 | global best_mae,best_epoch 99 | model.eval() 100 | 101 | with torch.no_grad(): 102 | mae_sum=0 103 | for i in range(test_loader.size): 104 | image, gt, name,_ = test_loader.load_data() 105 | gt = np.asarray(gt, np.float32) 106 | 107 | gt /= (gt.max() + 1e-8) 108 | 109 | image = image.cuda() 110 | 111 | _,_,_,res,_ = model(image) 112 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 113 | res = res.sigmoid().data.cpu().numpy().squeeze() 114 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 115 | mae_sum +=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1]) 116 | 117 | mae = mae_sum / test_loader.size 118 | 119 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch)) 120 | if epoch == 1: 121 | best_mae = mae 122 | else: 123 | if mae < best_mae: 124 | best_mae = mae 125 | best_epoch = epoch 126 | 127 | torch.save(model.state_dict(), save_path+'/FAPNet_best.pth') 128 | print('best epoch:{}'.format(epoch)) 129 | 130 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch,mae,best_epoch,best_mae)) 131 | 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('--epoch', type=int, default=200, 137 | help='epoch number, default=30') 138 | parser.add_argument('--lr', type=float, default=1e-4, 139 | help='init learning rate, try `lr=1e-4`') 140 | parser.add_argument('--batchsize', type=int, default=32, 141 | help='training batch size (Note: ~500MB per img in GPU)') 142 | parser.add_argument('--trainsize', type=int, default=352, 143 | help='the size of training image, try small resolutions for speed (like 256)') 144 | parser.add_argument('--clip', type=float, default=0.5, 145 | help='gradient clipping margin') 146 | parser.add_argument('--decay_rate', type=float, default=0.1, 147 | help='decay rate of learning rate per decay step') 148 | parser.add_argument('--decay_epoch', type=int, default=30, 149 | help='every N epochs decay lr') 150 | parser.add_argument('--gpu', type=int, default=0, 151 | help='choose which gpu you use') 152 | parser.add_argument('--save_epoch', type=int, default=5, 153 | help='every N epochs save your trained snapshot') 154 | parser.add_argument('--save_model', type=str, default='./Snapshot/FAPNet/') 155 | 156 | parser.add_argument('--train_img_dir', type=str, default='./data/TrainDataset/Imgs/') 157 | parser.add_argument('--train_gt_dir', type=str, default='./data/TrainDataset/GT/') 158 | parser.add_argument('--train_eg_dir', type=str, default='./data/TrainDataset/Edge/') 159 | 160 | parser.add_argument('--test_img_dir', type=str, default='./data/TestDataset/CAMO/Imgs/') 161 | parser.add_argument('--test_gt_dir', type=str, default='./data/TestDataset/CAMO/GT/') 162 | parser.add_argument('--test_eg_dir', type=str, 
default='./data/TestDataset/CAMO/Edge/') 163 | 164 | 165 | opt = parser.parse_args() 166 | 167 | torch.cuda.set_device(opt.gpu) 168 | 169 | ## log 170 | 171 | save_path = opt.save_model 172 | os.makedirs(save_path, exist_ok=True) 173 | 174 | 175 | logging.basicConfig(filename=opt.save_model+'/log.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d %I:%M:%S %p') 176 | logging.info("COD-Train") 177 | logging.info("Config") 178 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.decay_rate,opt.save_model,opt.decay_epoch)) 179 | 180 | 181 | 182 | # 183 | model = FAPNet(channel=64).cuda() 184 | optimizer = torch.optim.Adam(model.parameters(), opt.lr) 185 | LogitsBCE = torch.nn.BCEWithLogitsLoss() 186 | 187 | 188 | #net, optimizer = amp.initialize(model_SINet, optimizer, opt_level='O1') # NOTES: Ox not 0x 189 | 190 | train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, opt.train_eg_dir, batchsize=opt.batchsize,trainsize=opt.trainsize) 191 | test_loader = test_dataset(opt.test_img_dir, opt.test_gt_dir, testsize=opt.trainsize) 192 | total_step = len(train_loader) 193 | 194 | print('--------------------starting-------------------') 195 | 196 | print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n" 197 | "Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr, 198 | opt.batchsize, opt.save_model, total_step), '-' * 30) 199 | 200 | for epoch_iter in range(1, opt.epoch): 201 | 202 | adjust_lr(optimizer, epoch_iter, opt.decay_rate, opt.decay_epoch) 203 | 204 | train(train_loader, model, optimizer, epoch_iter, opt, LogitsBCE, total_step) 205 | test(test_loader, model, epoch_iter, opt.save_model) 206 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label,edge): 12 | # left right flip 13 | flip_flag = random.randint(0, 1) 14 | if flip_flag == 1: 15 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 16 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 17 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT) 18 | return img, label,edge 19 | 20 | 21 | def randomCrop(image, label,edge): 22 | border = 30 23 | image_width = image.size[0] 24 | image_height = image.size[1] 25 | crop_win_width = np.random.randint(image_width - border, image_width) 26 | crop_win_height = np.random.randint(image_height - border, image_height) 27 | random_region = ( 28 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 29 | (image_height + crop_win_height) >> 1) 30 | return image.crop(random_region), label.crop(random_region), edge.crop(random_region) 31 | 32 | 33 | def randomRotation(image, label, edge): 34 | mode = Image.BICUBIC 35 | if random.random() > 0.8: 36 | random_angle = np.random.randint(-15, 15) 37 | image = image.rotate(random_angle, mode) 38 | label = label.rotate(random_angle, mode) 39 | edge = edge.rotate(random_angle, mode) 40 | return image, label, edge 41 | 42 | 43 | def 
colorEnhance(image): 44 | bright_intensity = random.randint(5, 15) / 10.0 45 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 46 | contrast_intensity = random.randint(5, 15) / 10.0 47 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 48 | color_intensity = random.randint(0, 20) / 10.0 49 | image = ImageEnhance.Color(image).enhance(color_intensity) 50 | sharp_intensity = random.randint(0, 30) / 10.0 51 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 52 | return image 53 | 54 | 55 | def randomGaussian(image, mean=0.1, sigma=0.35): 56 | def gaussianNoisy(im, mean=mean, sigma=sigma): 57 | for _i in range(len(im)): 58 | im[_i] += random.gauss(mean, sigma) 59 | return im 60 | 61 | img = np.asarray(image) 62 | width, height = img.shape 63 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 64 | img = img.reshape([width, height]) 65 | return Image.fromarray(np.uint8(img)) 66 | 67 | 68 | def randomPeper(img): 69 | img = np.array(img) 70 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 71 | for i in range(noiseNum): 72 | 73 | randX = random.randint(0, img.shape[0] - 1) 74 | 75 | randY = random.randint(0, img.shape[1] - 1) 76 | 77 | if random.randint(0, 1) == 0: 78 | 79 | img[randX, randY] = 0 80 | 81 | else: 82 | 83 | img[randX, randY] = 255 84 | return Image.fromarray(img) 85 | 86 | 87 | def randomPeper_eg(img, edge): 88 | 89 | img = np.array(img) 90 | edge = np.array(edge) 91 | 92 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 93 | for i in range(noiseNum): 94 | 95 | randX = random.randint(0, img.shape[0] - 1) 96 | 97 | randY = random.randint(0, img.shape[1] - 1) 98 | 99 | if random.randint(0, 1) == 0: 100 | 101 | img[randX, randY] = 0 102 | edge[randX, randY] = 0 103 | 104 | else: 105 | 106 | img[randX, randY] = 255 107 | edge[randX, randY] = 255 108 | 109 | return Image.fromarray(img), Image.fromarray(edge) 110 | 111 | 112 | 113 | # dataset for training 114 | class PolypObjDataset(data.Dataset): 115 | def __init__(self, image_root, gt_root, edge_root, trainsize): 116 | self.trainsize = trainsize 117 | # get filenames 118 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 119 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 120 | or f.endswith('.png')] 121 | 122 | 123 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 124 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')] 125 | self.egs = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')] 126 | 127 | 128 | 129 | # self.grads = [grad_root + f for f in os.listdir(grad_root) if f.endswith('.jpg') 130 | # or f.endswith('.png')] 131 | # self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 132 | # or f.endswith('.png')] 133 | # sorted files 134 | self.images = sorted(self.images) 135 | self.gts = sorted(self.gts) 136 | self.egs = sorted(self.egs) 137 | 138 | 139 | # self.grads = sorted(self.grads) 140 | # self.depths = sorted(self.depths) 141 | # filter mathcing degrees of files 142 | self.filter_files() 143 | # transforms 144 | self.img_transform = transforms.Compose([ 145 | transforms.Resize((self.trainsize, self.trainsize)), 146 | transforms.ToTensor(), 147 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 148 | self.gt_transform = transforms.Compose([ 149 | transforms.Resize((self.trainsize, self.trainsize)), 150 | transforms.ToTensor()]) 151 | 
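# the edge maps below go through the same resize + ToTensor pipeline as the GT masks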
152 | self.eg_transform = transforms.Compose([ 153 | transforms.Resize((self.trainsize, self.trainsize)), 154 | transforms.ToTensor()]) 155 | 156 | 157 | # get size of dataset 158 | self.size = len(self.images) 159 | 160 | def __getitem__(self, index): 161 | # read imgs/gts/grads/depths 162 | image = self.rgb_loader(self.images[index]) 163 | gt = self.binary_loader(self.gts[index]) 164 | eg = self.binary_loader(self.egs[index]) 165 | 166 | # data augumentation 167 | image, gt, eg = cv_random_flip(image, gt, eg) 168 | image, gt, eg = randomCrop(image, gt, eg) 169 | image, gt, eg = randomRotation(image, gt, eg) 170 | 171 | image = colorEnhance(image) 172 | gt,eg = randomPeper_eg(gt,eg) 173 | 174 | image = self.img_transform(image) 175 | gt = self.gt_transform(gt) 176 | eg = self.eg_transform(eg) 177 | 178 | return image, gt, eg 179 | 180 | def filter_files(self): 181 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 182 | images = [] 183 | gts = [] 184 | for img_path, gt_path in zip(self.images, self.gts): 185 | img = Image.open(img_path) 186 | gt = Image.open(gt_path) 187 | if img.size == gt.size: 188 | images.append(img_path) 189 | gts.append(gt_path) 190 | self.images = images 191 | self.gts = gts 192 | 193 | def rgb_loader(self, path): 194 | with open(path, 'rb') as f: 195 | img = Image.open(f) 196 | return img.convert('RGB') 197 | 198 | def binary_loader(self, path): 199 | with open(path, 'rb') as f: 200 | img = Image.open(f) 201 | return img.convert('L') 202 | 203 | def __len__(self): 204 | return self.size 205 | 206 | 207 | # dataloader for training 208 | def get_loader(image_root, gt_root, eg_root, batchsize, trainsize, 209 | shuffle=True, num_workers=12, pin_memory=True): 210 | dataset = PolypObjDataset(image_root, gt_root, eg_root,trainsize) 211 | data_loader = data.DataLoader(dataset=dataset, 212 | batch_size=batchsize, 213 | shuffle=shuffle, 214 | num_workers=num_workers, 215 | pin_memory=pin_memory) 216 | return data_loader 217 | 218 | 219 | # test dataset and loader 220 | class test_dataset: 221 | def __init__(self, image_root, gt_root, testsize): 222 | self.testsize = testsize 223 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 224 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 225 | self.images = sorted(self.images) 226 | self.gts = sorted(self.gts) 227 | self.transform = transforms.Compose([ 228 | transforms.Resize((self.testsize, self.testsize)), 229 | transforms.ToTensor(), 230 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 231 | self.gt_transform = transforms.ToTensor() 232 | self.size = len(self.images) 233 | self.index = 0 234 | 235 | def load_data(self): 236 | image = self.rgb_loader(self.images[self.index]) 237 | image = self.transform(image).unsqueeze(0) 238 | 239 | gt = self.binary_loader(self.gts[self.index]) 240 | 241 | name = self.images[self.index].split('/')[-1] 242 | 243 | image_for_post = self.rgb_loader(self.images[self.index]) 244 | image_for_post = image_for_post.resize(gt.size) 245 | 246 | if name.endswith('.jpg'): 247 | name = name.split('.jpg')[0] + '.png' 248 | 249 | self.index += 1 250 | self.index = self.index % self.size 251 | 252 | return image, gt, name, np.array(image_for_post) 253 | 254 | def rgb_loader(self, path): 255 | with open(path, 'rb') as f: 256 | img = Image.open(f) 257 | return img.convert('RGB') 258 | 259 | def binary_loader(self, path): 260 | with open(path, 'rb') as 
f: 261 | img = Image.open(f) 262 | return img.convert('L') 263 | 264 | def __len__(self): 265 | return self.size 266 | 267 | class tt_dataset: 268 | """load test dataset (batchsize=1)""" 269 | def __init__(self, image_root, gt_root, testsize): 270 | self.testsize = testsize 271 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 272 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 273 | or f.endswith('.png')] 274 | self.images = sorted(self.images) 275 | self.gts = sorted(self.gts) 276 | self.transform = transforms.Compose([ 277 | transforms.Resize((self.testsize, self.testsize)), 278 | transforms.ToTensor(), 279 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 280 | self.gt_transform = transforms.ToTensor() 281 | self.size = len(self.images) 282 | self.index = 0 283 | 284 | def load_data(self): 285 | image = self.rgb_loader(self.images[self.index]) 286 | image = self.transform(image).unsqueeze(0) 287 | gt = self.binary_loader(self.gts[self.index]) 288 | name = self.images[self.index].split('/')[-1] 289 | 290 | if name.endswith('.jpg'): 291 | name = name.split('.jpg')[0] + '.png' 292 | self.index += 1 293 | return image, gt, name 294 | 295 | def rgb_loader(self, path): 296 | with open(path, 'rb') as f: 297 | img = Image.open(f) 298 | return img.convert('RGB') 299 | 300 | def binary_loader(self, path): 301 | with open(path, 'rb') as f: 302 | img = Image.open(f) 303 | return img.convert('L') 304 | 305 | 306 | -------------------------------------------------------------------------------- /utils/eva_funcs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Sep 29 17:21:18 2020 5 | 6 | @author: taozhou 7 | """ 8 | 9 | import os 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | from torchvision import transforms 15 | 16 | 17 | ############################################################################### 18 | ## basic funcs 19 | ############################################################################### 20 | 21 | def numpy2tensor(numpy): 22 | """ 23 | convert numpy_array in cpu to tensor in gpu 24 | :param numpy: 25 | :return: torch.from_numpy(numpy).cuda() 26 | """ 27 | return torch.from_numpy(numpy).cuda() 28 | 29 | def fun_eval_e(y_pred, y, num, cuda=True): 30 | 31 | if cuda: 32 | score = torch.zeros(num).cuda() 33 | else: 34 | score = torch.zeros(num) 35 | 36 | for i in range(num): 37 | 38 | fm = y_pred - y_pred.mean() 39 | gt = y - y.mean() 40 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20) 41 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4 42 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20) 43 | return score.max() 44 | 45 | 46 | def fun_eval_pr(y_pred, y, num, cuda=True): 47 | 48 | if cuda: 49 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 50 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 51 | else: 52 | prec, recall = torch.zeros(num), torch.zeros(num) 53 | thlist = torch.linspace(0, 1 - 1e-10, num) 54 | 55 | for i in range(num): 56 | y_temp = (y_pred >= thlist[i]).float() 57 | tp = (y_temp * y).sum() 58 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 59 | return prec, recall 60 | 61 | 62 | def fun_S_object(pred, gt): 63 | 64 | fg = torch.where(gt==0, torch.zeros_like(pred), pred) 65 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred) 66 | o_fg = fun_object(fg, gt) 67 | o_bg = 
fun_object(bg, 1-gt) 68 | u = gt.mean() 69 | Q = u * o_fg + (1-u) * o_bg 70 | return Q 71 | 72 | 73 | def fun_object(pred, gt): 74 | 75 | temp = pred[gt == 1] 76 | x = temp.mean() 77 | sigma_x = temp.std() 78 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20) 79 | 80 | return score 81 | 82 | 83 | def fun_S_region(pred, gt): 84 | 85 | X, Y = fun_centroid(gt) 86 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = fun_divideGT(gt, X, Y) 87 | p1, p2, p3, p4 = fun_dividePrediction(pred, X, Y) 88 | Q1 = fun_ssim(p1, gt1) 89 | Q2 = fun_ssim(p2, gt2) 90 | Q3 = fun_ssim(p3, gt3) 91 | Q4 = fun_ssim(p4, gt4) 92 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4 93 | 94 | return Q 95 | 96 | def fun_centroid(gt, cuda=True): 97 | 98 | rows, cols = gt.size()[-2:] 99 | gt = gt.view(rows, cols) 100 | 101 | if gt.sum() == 0: 102 | 103 | if cuda: 104 | X = torch.eye(1).cuda() * round(cols / 2) 105 | Y = torch.eye(1).cuda() * round(rows / 2) 106 | else: 107 | X = torch.eye(1) * round(cols / 2) 108 | Y = torch.eye(1) * round(rows / 2) 109 | 110 | else: 111 | total = gt.sum() 112 | 113 | if cuda: 114 | i = torch.from_numpy(np.arange(0,cols)).cuda().float() 115 | j = torch.from_numpy(np.arange(0,rows)).cuda().float() 116 | else: 117 | i = torch.from_numpy(np.arange(0,cols)).float() 118 | j = torch.from_numpy(np.arange(0,rows)).float() 119 | 120 | X = torch.round((gt.sum(dim=0)*i).sum() / total) 121 | Y = torch.round((gt.sum(dim=1)*j).sum() / total) 122 | 123 | return X.long(), Y.long() 124 | 125 | 126 | def fun_divideGT(gt, X, Y): 127 | 128 | h, w = gt.size()[-2:] 129 | area = h*w 130 | gt = gt.view(h, w) 131 | LT = gt[:Y, :X] 132 | RT = gt[:Y, X:w] 133 | LB = gt[Y:h, :X] 134 | RB = gt[Y:h, X:w] 135 | X = X.float() 136 | Y = Y.float() 137 | w1 = X * Y / area 138 | w2 = (w - X) * Y / area 139 | w3 = X * (h - Y) / area 140 | w4 = 1 - w1 - w2 - w3 141 | 142 | return LT, RT, LB, RB, w1, w2, w3, w4 143 | 144 | def fun_dividePrediction(pred, X, Y): 145 | 146 | h, w = pred.size()[-2:] 147 | pred = pred.view(h, w) 148 | LT = pred[:Y, :X] 149 | RT = pred[:Y, X:w] 150 | LB = pred[Y:h, :X] 151 | RB = pred[Y:h, X:w] 152 | 153 | return LT, RT, LB, RB 154 | 155 | 156 | def fun_ssim(pred, gt): 157 | 158 | gt = gt.float() 159 | h, w = pred.size()[-2:] 160 | N = h*w 161 | x = pred.mean() 162 | y = gt.mean() 163 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20) 164 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20) 165 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20) 166 | 167 | aplha = 4 * x * y *sigma_xy 168 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2) 169 | 170 | if aplha != 0: 171 | Q = aplha / (beta + 1e-20) 172 | elif aplha == 0 and beta == 0: 173 | Q = 1.0 174 | else: 175 | Q = 0 176 | 177 | return Q 178 | 179 | ############################################################################### 180 | ## metric funcs 181 | ############################################################################### 182 | def eval_mae(pred,gt,cuda=True): 183 | 184 | with torch.no_grad(): 185 | 186 | trans = transforms.Compose([transforms.ToTensor()]) 187 | 188 | if cuda: 189 | pred = pred.cuda() 190 | gt = gt.cuda() 191 | # else: 192 | # pred = trans(pred) 193 | # gt = trans(gt) 194 | 195 | mae = torch.abs(pred - gt).mean() 196 | 197 | return mae.cpu().detach().numpy() 198 | 199 | 200 | def eval_Smeasure(pred,gt,cuda=True): 201 | 202 | alpha, avg_q, img_num = 0.5, 0.0, 0.0 203 | 204 | with torch.no_grad(): 205 | 206 | trans = transforms.Compose([transforms.ToTensor()]) 207 | 208 | y = gt.mean() 209 | 210 | ## 211 | if y == 0: 212 | x = 
pred.mean() 213 | Q = 1.0 - x 214 | elif y == 1: 215 | x = pred.mean() 216 | Q = x 217 | else: 218 | Q = alpha * fun_S_object(pred, gt) + (1-alpha) * fun_S_region(pred, gt) 219 | if Q.item() < 0: 220 | Q = torch.FLoatTensor([0.0]) 221 | 222 | return Q.item() 223 | 224 | 225 | def eval_fmeasure(pred, gt, cuda=True): 226 | print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method)) 227 | 228 | beta2 = 0.3 229 | avg_p, avg_r, img_num = 0.0, 0.0, 0.0 230 | 231 | ## 232 | with torch.no_grad(): 233 | trans = transforms.Compose([transforms.ToTensor()]) 234 | if cuda: 235 | pred = trans(pred).cuda() 236 | gt = trans(gt).cuda() 237 | else: 238 | pred = trans(pred) 239 | gt = trans(gt) 240 | 241 | prec, recall = fun_eval_pr(pred, gt, 255) 242 | 243 | return prec, recall 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | class Eval_thread(): 256 | def __init__(self, loader, method, dataset, output_dir, cuda): 257 | self.loader = loader 258 | self.method = method 259 | self.dataset = dataset 260 | self.cuda = cuda 261 | self.logfile = os.path.join(output_dir, 'result.txt') 262 | def run(self): 263 | start_time = time.time() 264 | mae = self.Eval_mae() 265 | s = self.Eval_Smeasure() 266 | 267 | return mae,s 268 | 269 | #max_f = self.Eval_fmeasure() 270 | #max_e = self.Eval_Emeasure() 271 | 272 | #self.LOG('{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..\n'.format(self.dataset, self.method, mae, max_f, max_e, s)) 273 | #return '[cost:{:.4f}s]{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..'.format(time.time()-start_time, self.dataset, self.method, mae, max_f, max_e, s) 274 | 275 | def Eval_mae(self): 276 | 277 | with torch.no_grad(): 278 | trans = transforms.Compose([transforms.ToTensor()]) 279 | for pred, gt in self.loader: 280 | if self.cuda: 281 | 282 | pred = trans(pred).cuda() 283 | gt = trans(gt).cuda() 284 | else: 285 | pred = trans(pred) 286 | gt = trans(gt) 287 | mea = torch.abs(pred - gt).mean() 288 | if mea == mea: # for Nan 289 | avg_mae += mea 290 | img_num += 1.0 291 | avg_mae /= img_num 292 | 293 | return avg_mae.item() 294 | 295 | def Eval_fmeasure(self): 296 | print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method)) 297 | beta2 = 0.3 298 | avg_p, avg_r, img_num = 0.0, 0.0, 0.0 299 | with torch.no_grad(): 300 | trans = transforms.Compose([transforms.ToTensor()]) 301 | for pred, gt in self.loader: 302 | if self.cuda: 303 | pred = trans(pred).cuda() 304 | gt = trans(gt).cuda() 305 | else: 306 | pred = trans(pred) 307 | gt = trans(gt) 308 | prec, recall = self._eval_pr(pred, gt, 255) 309 | avg_p += prec 310 | avg_r += recall 311 | img_num += 1.0 312 | avg_p /= img_num 313 | avg_r /= img_num 314 | score = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r) 315 | score[score != score] = 0 # for Nan 316 | 317 | return score.max().item() 318 | def Eval_Emeasure(self): 319 | print('eval[EMeasure]:{} dataset with {} method.'.format(self.dataset, self.method)) 320 | avg_e, img_num = 0.0, 0.0 321 | with torch.no_grad(): 322 | trans = transforms.Compose([transforms.ToTensor()]) 323 | for pred, gt in self.loader: 324 | if self.cuda: 325 | pred = trans(pred).cuda() 326 | gt = trans(gt).cuda() 327 | else: 328 | pred = trans(pred) 329 | gt = trans(gt) 330 | max_e = self._eval_e(pred, gt, 255) 331 | if max_e == max_e: 332 | avg_e += max_e 333 | img_num += 1.0 334 | 335 | avg_e /= img_num 336 | return avg_e 337 | def 
Eval_Smeasure(self): 338 | #print('eval[SMeasure]:{} dataset with {} method.'.format(self.dataset, self.method)) 339 | alpha, avg_q, img_num = 0.5, 0.0, 0.0 340 | with torch.no_grad(): 341 | trans = transforms.Compose([transforms.ToTensor()]) 342 | for pred, gt in self.loader: 343 | if self.cuda: 344 | pred = trans(pred).cuda() 345 | gt = trans(gt).cuda() 346 | else: 347 | pred = trans(pred) 348 | gt = trans(gt) 349 | y = gt.mean() 350 | if y == 0: 351 | x = pred.mean() 352 | Q = 1.0 - x 353 | elif y == 1: 354 | x = pred.mean() 355 | Q = x 356 | else: 357 | Q = alpha * self._S_object(pred, gt) + (1-alpha) * self._S_region(pred, gt) 358 | if Q.item() < 0: 359 | Q = torch.FLoatTensor([0.0]) 360 | img_num += 1.0 361 | avg_q += Q.item() 362 | avg_q /= img_num 363 | 364 | return avg_q 365 | def LOG(self, output): 366 | with open(self.logfile, 'a') as f: 367 | f.write(output) 368 | 369 | def _eval_e(self, y_pred, y, num): 370 | if self.cuda: 371 | score = torch.zeros(num).cuda() 372 | else: 373 | score = torch.zeros(num) 374 | for i in range(num): 375 | fm = y_pred - y_pred.mean() 376 | gt = y - y.mean() 377 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20) 378 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4 379 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20) 380 | return score.max() 381 | 382 | def _eval_pr(self, y_pred, y, num): 383 | if self.cuda: 384 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 385 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 386 | else: 387 | prec, recall = torch.zeros(num), torch.zeros(num) 388 | thlist = torch.linspace(0, 1 - 1e-10, num) 389 | for i in range(num): 390 | y_temp = (y_pred >= thlist[i]).float() 391 | tp = (y_temp * y).sum() 392 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 393 | return prec, recall 394 | 395 | def _S_object(self, pred, gt): 396 | fg = torch.where(gt==0, torch.zeros_like(pred), pred) 397 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred) 398 | o_fg = self._object(fg, gt) 399 | o_bg = self._object(bg, 1-gt) 400 | u = gt.mean() 401 | Q = u * o_fg + (1-u) * o_bg 402 | return Q 403 | 404 | def _object(self, pred, gt): 405 | temp = pred[gt == 1] 406 | x = temp.mean() 407 | sigma_x = temp.std() 408 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20) 409 | 410 | return score 411 | 412 | def _S_region(self, pred, gt): 413 | X, Y = self._centroid(gt) 414 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y) 415 | p1, p2, p3, p4 = self._dividePrediction(pred, X, Y) 416 | Q1 = self._ssim(p1, gt1) 417 | Q2 = self._ssim(p2, gt2) 418 | Q3 = self._ssim(p3, gt3) 419 | Q4 = self._ssim(p4, gt4) 420 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4 421 | # print(Q) 422 | return Q 423 | 424 | def _centroid(self, gt): 425 | rows, cols = gt.size()[-2:] 426 | gt = gt.view(rows, cols) 427 | if gt.sum() == 0: 428 | if self.cuda: 429 | X = torch.eye(1).cuda() * round(cols / 2) 430 | Y = torch.eye(1).cuda() * round(rows / 2) 431 | else: 432 | X = torch.eye(1) * round(cols / 2) 433 | Y = torch.eye(1) * round(rows / 2) 434 | else: 435 | total = gt.sum() 436 | if self.cuda: 437 | i = torch.from_numpy(np.arange(0,cols)).cuda().float() 438 | j = torch.from_numpy(np.arange(0,rows)).cuda().float() 439 | else: 440 | i = torch.from_numpy(np.arange(0,cols)).float() 441 | j = torch.from_numpy(np.arange(0,rows)).float() 442 | X = torch.round((gt.sum(dim=0)*i).sum() / total) 443 | Y = torch.round((gt.sum(dim=1)*j).sum() / total) 444 | return X.long(), Y.long() 445 | 446 | def _divideGT(self, 
gt, X, Y): 447 | h, w = gt.size()[-2:] 448 | area = h*w 449 | gt = gt.view(h, w) 450 | LT = gt[:Y, :X] 451 | RT = gt[:Y, X:w] 452 | LB = gt[Y:h, :X] 453 | RB = gt[Y:h, X:w] 454 | X = X.float() 455 | Y = Y.float() 456 | w1 = X * Y / area 457 | w2 = (w - X) * Y / area 458 | w3 = X * (h - Y) / area 459 | w4 = 1 - w1 - w2 - w3 460 | return LT, RT, LB, RB, w1, w2, w3, w4 461 | 462 | def _dividePrediction(self, pred, X, Y): 463 | h, w = pred.size()[-2:] 464 | pred = pred.view(h, w) 465 | LT = pred[:Y, :X] 466 | RT = pred[:Y, X:w] 467 | LB = pred[Y:h, :X] 468 | RB = pred[Y:h, X:w] 469 | return LT, RT, LB, RB 470 | 471 | def _ssim(self, pred, gt): 472 | gt = gt.float() 473 | h, w = pred.size()[-2:] 474 | N = h*w 475 | x = pred.mean() 476 | y = gt.mean() 477 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20) 478 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20) 479 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20) 480 | 481 | aplha = 4 * x * y *sigma_xy 482 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2) 483 | 484 | if aplha != 0: 485 | Q = aplha / (beta + 1e-20) 486 | elif aplha == 0 and beta == 0: 487 | Q = 1.0 488 | else: 489 | Q = 0 490 | return Q 491 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from datetime import datetime 4 | import os 5 | #from apex import amp 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | def eval_mae(y_pred, y): 11 | """ 12 | evaluate MAE (for test or validation phase) 13 | :param y_pred: 14 | :param y: 15 | :return: Mean Absolute Error 16 | """ 17 | return np.abs(y_pred - y).mean() 18 | 19 | 20 | 21 | 22 | def numpy2tensor(numpy): 23 | """ 24 | convert numpy_array in cpu to tensor in gpu 25 | :param numpy: 26 | :return: torch.from_numpy(numpy).cuda() 27 | """ 28 | return torch.from_numpy(numpy).cuda() 29 | 30 | 31 | def clip_gradient(optimizer, grad_clip): 32 | """ 33 | recalibrate the misdirection in the training 34 | :param optimizer: 35 | :param grad_clip: 36 | :return: 37 | """ 38 | for group in optimizer.param_groups: 39 | for param in group['params']: 40 | if param.grad is not None: 41 | param.grad.data.clamp_(-grad_clip, grad_clip) 42 | 43 | 44 | def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30): 45 | decay = decay_rate ** (epoch // decay_epoch) 46 | for param_group in optimizer.param_groups: 47 | param_group['lr'] *= decay 48 | 49 | --------------------------------------------------------------------------------
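For quick experimentation outside the provided dataloaders, the following is a minimal single-image inference sketch distilled from `test.py`. It assumes a CUDA device, the Res2Net backbone weights under `./lib/` and the released checkpoint at `./checkpoints/FAPNet.pth` (both obtained as described in the README), and a placeholder input path `example.jpg` of your own; it is an illustration, not part of the original codebase.

```python
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from lib.model import FAPNet

# Build the model exactly as in test.py (channel=64) and load the released checkpoint.
model = FAPNet(channel=64).cuda()
model.load_state_dict(torch.load('./checkpoints/FAPNet.pth'))
model.eval()

# Same preprocessing as the test dataloaders: resize to 352x352, ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize((352, 352)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img_path = 'example.jpg'                      # hypothetical input image path
image = Image.open(img_path).convert('RGB')
w, h = image.size                             # PIL gives (width, height)
inp = transform(image).unsqueeze(0).cuda()

with torch.no_grad():
    # FAPNet returns (out5, out4, out3, out2, edge_out); out2 is the finest-scale map.
    _, _, _, cam, _ = model(inp)

res = F.interpolate(cam, size=(h, w), mode='bilinear', align_corners=False)
res = res.sigmoid().squeeze().cpu().numpy()
res = (res - res.min()) / (res.max() - res.min() + 1e-8)   # min-max normalise, as in test.py
cv2.imwrite('prediction.png', (res * 255).astype(np.uint8))
```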