├── README.md
├── imgs
│   ├── framework.png
│   └── qualitative_results.png
├── lib
│   ├── model.py
│   └── res2net_v1b_base.py
├── mindspore
│   ├── lib
│   │   ├── model.py
│   │   └── res2net_v1b_base.py
│   ├── test.py
│   ├── train.py
│   └── utils
│       └── dataloader.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── CODNet.cpython-36.pyc
    │   ├── CODNet.cpython-38.pyc
    │   ├── SINet.cpython-36.pyc
    │   ├── SearchAttention.cpython-36.pyc
    │   ├── SearchAttention.cpython-38.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-38.pyc
    │   └── res2net_v1b_base.cpython-38.pyc
    ├── dataloader.py
    ├── eva_funcs.py
    ├── evaluator.py
    └── trainer.py

/README.md:
--------------------------------------------------------------------------------

# Cross-level Feature Aggregation Network for Polyp Segmentation

> **Authors:**
> [Tao Zhou](https://taozh2017.github.io/),
> [Yi Zhou](https://cse.seu.edu.cn/2021/0303/c23024a362239/page.htm),
> [Kelei He](https://scholar.google.com/citations?user=0Do_BMIAAAAJ&hl=en),
> [Chen Gong](https://gcatnjust.github.io/ChenGong/index.html),
> [Jian Yang](https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en),
> [Huazhu Fu](http://hzfu.github.io/), and
> [Dinggang Shen](https://scholar.google.com/citations?user=v6VYQC8AAAAJ&hl=en).

## 1. Preface

- This repository provides code for "_**Cross-level Feature Aggregation Network for Polyp Segmentation (CFANet)**_"
  ([paper](https://www.sciencedirect.com/science/article/pii/S0031320323002558)).

- If you have any questions about our paper, feel free to contact me. If you are using CFANet for your research, please cite this paper ([BibTeX](#5-citation)).

### 1.1. :fire: NEWS :fire:

- [2023/05/20] Release training/testing code.

- [2020/05/10] Create repository.

### 1.2. Table of Contents

- [Cross-level Feature Aggregation Network for Polyp Segmentation](#cross-level-feature-aggregation-network-for-polyp-segmentation)
  - [2. Overview](#2-overview)
    - [2.1. Introduction](#21-introduction)
    - [2.2. Framework Overview](#22-framework-overview)
    - [2.3. Qualitative Results](#23-qualitative-results)
  - [3. Proposed Baseline](#3-proposed-baseline)
    - [3.1. Training/Testing](#31-trainingtesting)
    - [3.2. Evaluating your trained model](#32-evaluating-your-trained-model)
    - [3.3. Pre-computed maps](#33-pre-computed-maps)
  - [4. MindSpore](#4-mindspore)
  - [5. Citation](#5-citation)
  - [6. License](#6-license)

Table of contents generated with markdown-toc

## 2. Overview

### 2.1. Introduction

Accurate segmentation of polyps from colonoscopy images plays a critical role in the diagnosis and treatment of colorectal cancer. Although much progress has been made in polyp segmentation, several challenges remain: polyps vary widely in size and shape, and there is often no sharp boundary between a polyp and its surroundings. To address these challenges, we propose a novel Cross-level Feature Aggregation Network (CFA-Net) for polyp segmentation. Specifically, we first propose a boundary prediction network to generate boundary-aware features, which are incorporated into the segmentation network using a layer-wise strategy. In particular, we design a two-stream segmentation network to exploit hierarchical semantic information from cross-level features.
Furthermore, a Cross-level Feature Fusion (CFF) module is proposed to integrate adjacent features from different levels, which can characterize cross-level and multi-scale information to handle scale variations of polyps. Further, a Boundary Aggregated Module (BAM) is proposed to incorporate boundary information into the segmentation network, which enhances these hierarchical features to generate finer segmentation maps. Quantitative and qualitative experiments on five public datasets demonstrate the effectiveness of our CFA-Net against other state-of-the-art polyp segmentation methods.
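To make the aggregation step concrete, below is a minimal PyTorch sketch of the computation BAM performs (see `lib/model.py` later in this repository for the actual implementation; the class name `BAMSketch` and argument names here are illustrative only): decoder features are fused with boundary-aware features, re-weighted by a global channel gate, and added back residually. CFF follows a similar design, with parallel 3x3 and 5x5 branches exchanging cross-level features.

```python
import torch
import torch.nn as nn


class BAMSketch(nn.Module):
    """Illustrative sketch of the Boundary Aggregated Module (BAM)."""

    def __init__(self, channel=64, r=4):
        super().__init__()
        # fuse decoder features with boundary-aware features (3x3 conv + BN)
        self.fuse = nn.Sequential(
            nn.Conv2d(channel * 2, channel, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channel))
        # global channel gate: squeeze to 1x1, bottleneck by r, sigmoid weights
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, channel // r, kernel_size=1),
            nn.BatchNorm2d(channel // r),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // r, channel, kernel_size=1),
            nn.BatchNorm2d(channel),
            nn.Sigmoid())

    def forward(self, x, boundary_feat):
        fused = self.fuse(torch.cat((x, boundary_feat), dim=1))
        return x + fused * self.gate(fused)  # gated residual injection


# e.g. BAMSketch(64)(torch.randn(2, 64, 44, 44), torch.randn(2, 64, 44, 44))
```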

### 2.2. Framework Overview

![Framework overview](imgs/framework.png)

*Figure 1: Overview of the proposed CFANet.*

### 2.3. Qualitative Results

![Qualitative results](imgs/qualitative_results.png)

*Figure 2: Qualitative Results.*
## 3. Proposed Baseline

### 3.1. Training/Testing

The training and testing experiments are conducted using [PyTorch](https://github.com/pytorch/pytorch) with a single NVIDIA Tesla P40 with 24 GB of memory.

> Note that our model also runs on low-memory GPUs; you can simply lower the batch size.

1. Configuring your environment (Prerequisites):

    Note that CFANet is only tested on Ubuntu OS with the following environments. It may work on other operating systems as well, but we do not guarantee that it will.

    + Creating a virtual environment in terminal: `conda create -n CFANet python=3.6`.

    + Installing necessary packages: PyTorch 1.1.

2. Downloading necessary data:

    + Download the testing dataset and move it into `./data/TestDataset/`; it is available at this [download link (Google Drive)](https://drive.google.com/file/d/1hwirZO201i_08fFgqmeqMuPuhPboHdVH/view?usp=sharing). It contains five sub-datasets: CVC-300 (60 test samples), CVC-ClinicDB (62 test samples), CVC-ColonDB (380 test samples), ETIS-LaribPolypDB (196 test samples), and Kvasir (100 test samples).

    + Download the training dataset and move it into `./data/TrainDataset/`; it is available at this [download link (Google Drive)](https://drive.google.com/file/d/1hzS21idjQlXnX9oxAgJI8KZzOBaz-OWj/view?usp=sharing). It contains two sub-datasets: Kvasir-SEG (900 train samples) and CVC-ClinicDB (550 train samples).

    + Download the pretrained weights and move them to `checkpoint/CFANet.pth`; they are available at this [download link (Google Drive)](https://drive.google.com/file/d/1pgvgYebjVVm-QZN-VbGdtYmAyccQmKxZ/view?usp=sharing).

    + Download the Res2Net weights and move them into `./lib/`; they are available at this [download link (Google Drive)](https://drive.google.com/file/d/1_1N-cx1UpRQo7Ybsjno1PAg4KE1T9e5J/view?usp=sharing).

3. Training Configuration:

    + Assign your customized paths, like `--save_model` and `--train_path`, in `train.py`.

    + Just enjoy it!

4. Testing Configuration:

    + After you have downloaded the pre-trained model and the testing dataset, just run `test.py` to generate the final prediction maps; point `--pth_path` at your trained model directory. A minimal inference sketch follows this list.

    + Just enjoy it!
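For a quick sanity check outside the full `test.py` pipeline, the following minimal inference sketch (not part of this repository) shows the intended flow. It assumes the released checkpoint stores a plain PyTorch `state_dict`, uses a random tensor in place of a normalized 352x352 test image, and picks 512x512 as an arbitrary output resolution:

```python
import torch
import torch.nn.functional as F

from lib.model import CFANet

model = CFANet(channel=64)
# assumption: './checkpoint/CFANet.pth' holds a plain state_dict
model.load_state_dict(torch.load('./checkpoint/CFANet.pth', map_location='cpu'))
model.eval()

x = torch.randn(1, 3, 352, 352)  # stand-in for a preprocessed test image
with torch.no_grad():
    edge_map, sal1, sal2, sal3 = model(x)  # boundary map + three predictions
    pred = torch.sigmoid(sal3)             # sal3 fuses the two decoder streams
    # resize to the original image resolution before thresholding/saving
    pred = F.interpolate(pred, size=(512, 512), mode='bilinear',
                         align_corners=False)
print(pred.shape)  # torch.Size([1, 1, 512, 512])
```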
### 3.2. Evaluating your trained model

MATLAB: One-key evaluation is written in MATLAB code ([link](https://drive.google.com/file/d/1_h4_CjD5GKEf7B1MRuzye97H0MXf2GE9/view?usp=sharing));
please follow the instructions in `./eval/main.m` and just run it to generate the evaluation results in `./res/`.
The complete evaluation toolbox (including data, maps, evaluation code, and results): [new link](https://drive.google.com/file/d/1bnlz7nfJ9hhYsMLFSBr9smcI7k7p0pVy/view?usp=sharing).

### 3.3. Pre-computed maps

They can be found at this [download link](https://drive.google.com/file/d/1FY2FFDw-VLwmZ-JbJ-h4uAizcpgiY5vg/view?usp=drive_link).

## 4. MindSpore

You need to run `cd mindspore` first.

1. Environment Configuration:

    + MindSpore: 2.0.0-alpha

    + Python: 3.8.0

2. Training Configuration:

    + Assign your customized paths, like `--save_model`, `--train_img_dir`, and so on in `train.py`.

    + Just enjoy it!

3. Testing Configuration:

    + After you have downloaded the pre-trained model and the testing dataset, just run `test.py` to generate the final prediction maps; point `--pth_path` at your trained model directory.

    + Just enjoy it!

## 5. Citation

Please cite our paper if you find the work useful:

    @article{zhou2023cross,
      title={Cross-level Feature Aggregation Network for Polyp Segmentation},
      author={Zhou, Tao and Zhou, Yi and He, Kelei and Gong, Chen and Yang, Jian and Fu, Huazhu and Shen, Dinggang},
      journal={Pattern Recognition},
      volume={140},
      pages={109555},
      year={2023},
      publisher={Elsevier}
    }

## 6. License

The source code is free for research and education use only. Any commercial use should get formal permission first.

---

**[⬆ back to top](#1-preface)**

--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/imgs/framework.png
--------------------------------------------------------------------------------
/imgs/qualitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/imgs/qualitative_results.png
--------------------------------------------------------------------------------
/lib/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.res2net_v1b_base import Res2Net_model
4 | 
5 | 
6 | class global_module(nn.Module):
7 |     def __init__(self, channels=64, r=4):
8 |         super(global_module, self).__init__()
9 |         out_channels = int(channels // r)
10 |         # global_att
11 | 
12 |         self.global_att = nn.Sequential(
13 |             nn.AdaptiveAvgPool2d(1),
14 |             nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, padding=0),
15 |             nn.BatchNorm2d(out_channels),
16 |             nn.ReLU(inplace=True),
17 |             nn.Conv2d(out_channels, channels, kernel_size=1, stride=1, padding=0),
18 |             nn.BatchNorm2d(channels)
19 |         )
20 | 
21 |         self.sig = nn.Sigmoid()
22 | 
23 |     def forward(self, x):
24 | 
25 |         xg = self.global_att(x)
26 |         out = self.sig(xg)
27 | 
28 |         return out
29 | 
30 | 
31 | class BasicConv2d(nn.Module):
32 |     def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
33 |         super(BasicConv2d, self).__init__()
34 |         self.conv = nn.Conv2d(in_planes, out_planes,
35 |                               kernel_size=kernel_size, stride=stride,
36 |                               padding=padding, dilation=dilation, bias=False)
37 |         self.bn = nn.BatchNorm2d(out_planes)
38 |         self.relu = nn.ReLU(inplace=True)
39 | 
40 |     def forward(self, x):
41 |         x = self.conv(x)
42 |         x = self.bn(x)
43 |         return x
44 | 
45 | 
46 | class ChannelAttention(nn.Module):
47 |     def __init__(self, in_planes, ratio=16):
48 |         super(ChannelAttention, self).__init__()
49 | 
50 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
51 | 
52 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
53 |         self.relu1 = nn.ReLU()
54 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
55 | 
56 |         self.sigmoid = nn.Sigmoid()
57 |     def forward(self, x):
58 |         max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
59 |         out = max_out
60 |         return self.sigmoid(out)
61 | 
62 | 
63 | class GateFusion(nn.Module):
64 |     def
__init__(self, in_planes): 65 | self.init__ = super(GateFusion, self).__init__() 66 | 67 | self.gate_1 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True) 68 | self.gate_2 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True) 69 | 70 | 71 | self.softmax = nn.Softmax(dim=1) 72 | 73 | def forward(self, x1, x2): 74 | 75 | ### 76 | cat_fea = torch.cat([x1,x2], dim=1) 77 | 78 | ### 79 | att_vec_1 = self.gate_1(cat_fea) 80 | att_vec_2 = self.gate_2(cat_fea) 81 | 82 | att_vec_cat = torch.cat([att_vec_1, att_vec_2], dim=1) 83 | att_vec_soft = self.softmax(att_vec_cat) 84 | 85 | att_soft_1, att_soft_2 = att_vec_soft[:, 0:1, :, :], att_vec_soft[:, 1:2, :, :] 86 | x_fusion = x1 * att_soft_1 + x2 * att_soft_2 87 | 88 | return x_fusion 89 | 90 | 91 | class BAM(nn.Module): 92 | # Partial Decoder Component (Identification Module) 93 | def __init__(self, channel): 94 | super(BAM, self).__init__() 95 | 96 | 97 | self.relu = nn.ReLU(True) 98 | 99 | self.global_att = global_module(channel) 100 | 101 | self.conv_layer = BasicConv2d(channel*2, channel, 3, padding=1) 102 | 103 | 104 | def forward(self, x, x_boun_atten): 105 | 106 | out1 = self.conv_layer(torch.cat((x, x_boun_atten), dim=1)) 107 | out2 = self.global_att(out1) 108 | out3 = out1.mul(out2) 109 | 110 | out = x + out3 111 | 112 | return out 113 | 114 | 115 | class CFF(nn.Module): 116 | def __init__(self, in_channel1, in_channel2, out_channel): 117 | self.init__ = super(CFF, self).__init__() 118 | 119 | 120 | act_fn = nn.ReLU(inplace=True) 121 | 122 | ## ---------------------------------------- ## 123 | self.layer0 = BasicConv2d(in_channel1, out_channel // 2, 1) 124 | self.layer1 = BasicConv2d(in_channel2, out_channel // 2, 1) 125 | 126 | self.layer3_1 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel // 2),act_fn) 127 | self.layer3_2 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel // 2),act_fn) 128 | 129 | self.layer5_1 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(out_channel // 2),act_fn) 130 | self.layer5_2 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(out_channel // 2),act_fn) 131 | 132 | self.layer_out = nn.Sequential(nn.Conv2d(out_channel // 2, out_channel, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel),act_fn) 133 | 134 | 135 | def forward(self, x0, x1): 136 | 137 | ## ------------------------------------------------------------------ ## 138 | x0_1 = self.layer0(x0) 139 | x1_1 = self.layer1(x1) 140 | 141 | x_3_1 = self.layer3_1(torch.cat((x0_1, x1_1), dim=1)) 142 | x_5_1 = self.layer5_1(torch.cat((x1_1, x0_1), dim=1)) 143 | 144 | x_3_2 = self.layer3_2(torch.cat((x_3_1, x_5_1), dim=1)) 145 | x_5_2 = self.layer5_2(torch.cat((x_5_1, x_3_1), dim=1)) 146 | 147 | out = self.layer_out(x0_1 + x1_1 + torch.mul(x_3_2, x_5_2)) 148 | 149 | return out 150 | 151 | 152 | 153 | ############################################################################### 154 | ## 2022/01/03 155 | ############################################################################### 156 | class CFANet(nn.Module): 157 | # resnet based encoder decoder 158 | def __init__(self, channel=64, opt=None): 159 | super(CFANet, self).__init__() 160 | 161 | act_fn = nn.ReLU(inplace=True) 162 | 163 | self.resnet = Res2Net_model(50) 164 | self.downSample = nn.MaxPool2d(2, stride=2) 165 | 166 | ## 
---------------------------------------- ## 167 | 168 | self.layer0 = nn.Sequential(nn.Conv2d(64, channel, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(channel),act_fn) 169 | self.layer1 = nn.Sequential(nn.Conv2d(256, channel, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(channel),act_fn) 170 | 171 | self.low_fusion = GateFusion(channel) 172 | 173 | self.high_fusion1 = CFF(256, 512, channel) 174 | self.high_fusion2 = CFF(1024, 2048, channel) 175 | 176 | ## ---------------------------------------- ## 177 | self.layer_edge0 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 178 | self.layer_edge1 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 179 | self.layer_edge2 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn) 180 | self.layer_edge3 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1)) 181 | 182 | ## ---------------------------------------- ## 183 | self.layer_cat_ori1 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 184 | self.layer_hig01 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 185 | 186 | self.layer_cat11 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 187 | self.layer_hig11 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 188 | 189 | self.layer_cat21 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 190 | self.layer_hig21 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn) 191 | 192 | self.layer_cat31 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn) 193 | self.layer_hig31 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1)) 194 | 195 | self.layer_cat_ori2 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 196 | self.layer_hig02 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 197 | 198 | self.layer_cat12 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 199 | self.layer_hig12 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 200 | 201 | self.layer_cat22 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn) 202 | self.layer_hig22 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn) 203 | 204 | self.layer_cat32 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn) 205 | self.layer_hig32 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1)) 206 | 207 | self.layer_fil = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1)) 208 | 209 | ## ---------------------------------------- ## 210 | 211 | self.atten_edge_0 = ChannelAttention(channel) 212 | self.atten_edge_1 = ChannelAttention(channel) 213 | self.atten_edge_2 = ChannelAttention(channel) 214 | self.atten_edge_ori = ChannelAttention(channel) 215 | 216 | 217 | self.cat_01 = BAM(channel) 218 | self.cat_11 = BAM(channel) 219 | self.cat_21 = BAM(channel) 
220 | self.cat_31 = BAM(channel) 221 | 222 | self.cat_02 = BAM(channel) 223 | self.cat_12 = BAM(channel) 224 | self.cat_22 = BAM(channel) 225 | self.cat_32 = BAM(channel) 226 | 227 | 228 | ## ---------------------------------------- ## 229 | self.downSample = nn.MaxPool2d(2, stride=2) 230 | self.up_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 231 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 232 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 233 | 234 | 235 | 236 | def forward(self, xx): 237 | 238 | # ---- feature abstraction ----- 239 | 240 | x0, x1, x2, x3, x4 = self.resnet(xx) 241 | 242 | 243 | 244 | ## -------------------------------------- ## 245 | 246 | x0_1 = self.layer0(x0) 247 | x1_1 = self.layer1(x1) 248 | 249 | low_x = self.low_fusion(x0_1, x1_1) # 64*44 250 | 251 | 252 | edge_out0 = self.layer_edge0(self.up_2(low_x)) # 64*88 253 | edge_out1 = self.layer_edge1(self.up_2(edge_out0)) # 64*176 254 | edge_out2 = self.layer_edge2(self.up_2(edge_out1)) # 64*352 255 | edge_out3 = self.layer_edge3(edge_out2) 256 | 257 | 258 | etten_edge_ori = self.atten_edge_ori(low_x) 259 | etten_edge_0 = self.atten_edge_0(edge_out0) 260 | etten_edge_1 = self.atten_edge_1(edge_out1) 261 | etten_edge_2 = self.atten_edge_2(edge_out2) 262 | 263 | 264 | ## -------------------------------------- ## 265 | high_x01 = self.high_fusion1(self.downSample(x1), x2) 266 | high_x02 = self.high_fusion2(self.up_2(x3), self.up_4(x4)) 267 | 268 | ## --------------- high 1 ----------------------- # 269 | cat_out_01 = self.cat_01(high_x01,low_x.mul(etten_edge_ori)) 270 | hig_out01 = self.layer_hig01(self.up_2(cat_out_01)) 271 | 272 | cat_out11 = self.cat_11(hig_out01,edge_out0.mul(etten_edge_0)) 273 | hig_out11 = self.layer_hig11(self.up_2(cat_out11)) 274 | 275 | cat_out21 = self.cat_21(hig_out11,edge_out1.mul(etten_edge_1)) 276 | hig_out21 = self.layer_hig21(self.up_2(cat_out21)) 277 | 278 | cat_out31 = self.cat_31(hig_out21,edge_out2.mul(etten_edge_2)) 279 | sal_out1 = self.layer_hig31(cat_out31) 280 | 281 | ## ---------------- high 2 ---------------------- ## 282 | cat_out_02 = self.cat_02(high_x02,low_x.mul(etten_edge_ori)) 283 | hig_out02 = self.layer_hig02(self.up_2(cat_out_02)) 284 | 285 | cat_out12 = self.cat_12(hig_out02,edge_out0.mul(etten_edge_0)) 286 | hig_out12 = self.layer_hig12(self.up_2(cat_out12)) 287 | 288 | cat_out22 = self.cat_22(hig_out12,edge_out1.mul(etten_edge_1)) 289 | hig_out22 = self.layer_hig22(self.up_2(cat_out22)) 290 | 291 | cat_out32 = self.cat_32(hig_out22,edge_out2.mul(etten_edge_2)) 292 | sal_out2 = self.layer_hig32(cat_out32) 293 | 294 | ## --------------------------------------------- ## 295 | sal_out3 = self.layer_fil(cat_out31+cat_out32) 296 | 297 | # ---- output ---- 298 | return edge_out3, sal_out1, sal_out2, sal_out3 299 | 300 | 301 | -------------------------------------------------------------------------------- /lib/res2net_v1b_base.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torch.nn.functional as F 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b'] 8 | 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | 
} 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth/64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width*scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale -1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i==0 or self.stype=='stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i==0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype=='normal': 79 | out = torch.cat((out, spx[self.nums]),1) 80 | elif self.scale != 1 and self.stype=='stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])),1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Res2Net(nn.Module): 96 | 97 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 98 | self.inplanes = 64 99 | super(Res2Net, self).__init__() 100 | self.baseWidth = baseWidth 101 | self.scale = scale 102 | self.conv1 = nn.Sequential( 103 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 104 | nn.BatchNorm2d(32), 105 | nn.ReLU(inplace=True), 106 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 107 | nn.BatchNorm2d(32), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 110 | ) 111 | self.bn1 = nn.BatchNorm2d(64) 112 | self.relu = nn.ReLU() 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(block, 64, layers[0]) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 118 | self.avgpool = nn.AdaptiveAvgPool2d(1) 119 | self.fc = nn.Linear(512 * block.expansion, num_classes) 120 | 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 124 | elif isinstance(m, nn.BatchNorm2d): 125 | nn.init.constant_(m.weight, 1) 126 | nn.init.constant_(m.bias, 0) 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.AvgPool2d(kernel_size=stride, stride=stride, 133 | ceil_mode=True, count_include_pad=False), 134 | nn.Conv2d(self.inplanes, planes * block.expansion, 135 | kernel_size=1, stride=1, bias=False), 136 | nn.BatchNorm2d(planes * block.expansion), 137 | ) 138 | 139 | layers = [] 140 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 141 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x0 = self.maxpool(x) 154 | 155 | 156 | x1 = self.layer1(x0) 157 | x2 = self.layer2(x1) 158 | x3 = self.layer3(x2) 159 | x4 = self.layer4(x3) 160 | 161 | x5 = self.avgpool(x4) 162 | x6 = x5.view(x5.size(0), -1) 163 | x7 = self.fc(x6) 164 | 165 | return x7 166 | 167 | 168 | 169 | class Res2Net_Ours(nn.Module): 170 | 171 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 172 | self.inplanes = 64 173 | super(Res2Net_Ours, self).__init__() 174 | 175 | self.baseWidth = baseWidth 176 | self.scale = scale 177 | self.conv1 = nn.Sequential( 178 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 179 | nn.BatchNorm2d(32), 180 | nn.ReLU(inplace=True), 181 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 182 | nn.BatchNorm2d(32), 183 | nn.ReLU(inplace=True), 184 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 185 | ) 186 | self.bn1 = nn.BatchNorm2d(64) 187 | self.relu = nn.ReLU() 188 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 189 | self.layer1 = self._make_layer(block, 64, layers[0]) 190 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 191 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 192 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 193 | 194 | 195 | for m in self.modules(): 196 | if isinstance(m, nn.Conv2d): 197 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 198 | elif isinstance(m, nn.BatchNorm2d): 199 | nn.init.constant_(m.weight, 1) 200 | nn.init.constant_(m.bias, 0) 201 | 202 | def _make_layer(self, block, planes, blocks, stride=1): 203 | downsample = None 204 | if stride != 1 or self.inplanes != planes * block.expansion: 205 | downsample = nn.Sequential( 206 | nn.AvgPool2d(kernel_size=stride, stride=stride, 207 | ceil_mode=True, count_include_pad=False), 208 | nn.Conv2d(self.inplanes, planes * block.expansion, 209 | kernel_size=1, stride=1, bias=False), 210 | nn.BatchNorm2d(planes * block.expansion), 211 | ) 212 | 213 | layers = [] 214 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 215 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 216 | self.inplanes = planes * block.expansion 217 | for i in range(1, blocks): 218 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 219 | 220 | return nn.Sequential(*layers) 221 | 222 | def forward(self, x): 223 | 224 | x = self.conv1(x) 225 | x = self.bn1(x) 226 | x = self.relu(x) 227 | x0 = 
self.maxpool(x) 228 | 229 | 230 | x1 = self.layer1(x0) 231 | x2 = self.layer2(x1) 232 | x3 = self.layer3(x2) 233 | x4 = self.layer4(x3) 234 | 235 | 236 | return x0,x1,x2,x3,x4 237 | 238 | 239 | 240 | def res2net50_v1b(pretrained=False, **kwargs): 241 | """Constructs a Res2Net-50_v1b model. 242 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 247 | if pretrained: 248 | #model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu')) 249 | 250 | model_state = torch.load('./lib/res2net50_v1b_26w_4s-3cf99910.pth') 251 | model.load_state_dict(model_state) 252 | 253 | return model 254 | 255 | def res2net101_v1b(pretrained=False, **kwargs): 256 | """Constructs a Res2Net-50_v1b_26w_4s model. 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | """ 260 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 261 | if pretrained: 262 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 263 | return model 264 | 265 | 266 | 267 | def res2net50_v1b_Ours(pretrained=False, **kwargs): 268 | """Constructs a Res2Net-50_v1b model. 269 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = Res2Net_Ours(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 274 | if pretrained: 275 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 276 | return model 277 | 278 | def res2net101_v1b_Ours(pretrained=False, **kwargs): 279 | """Constructs a Res2Net-50_v1b_26w_4s model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | """ 283 | model = Res2Net_Ours(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 284 | if pretrained: 285 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 286 | return model 287 | 288 | 289 | 290 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 291 | """Constructs a Res2Net-50_v1b_26w_4s model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | """ 295 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) 296 | if pretrained: 297 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu')) 298 | return model 299 | 300 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 301 | """Constructs a Res2Net-50_v1b_26w_4s model. 302 | Args: 303 | pretrained (bool): If True, returns a model pre-trained on ImageNet 304 | """ 305 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) 306 | if pretrained: 307 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 308 | return model 309 | 310 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 311 | """Constructs a Res2Net-50_v1b_26w_4s model. 
312 | Args: 313 | pretrained (bool): If True, returns a model pre-trained on ImageNet 314 | """ 315 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs) 316 | if pretrained: 317 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 318 | return model 319 | 320 | 321 | 322 | 323 | def Res2Net_model(ind=50): 324 | 325 | if ind == 50: 326 | model_base = res2net50_v1b(pretrained=False) 327 | model = res2net50_v1b_Ours() 328 | 329 | if ind == 101: 330 | model_base = res2net101_v1b(pretrained=True) 331 | model = res2net101_v1b_Ours() 332 | 333 | 334 | pretrained_dict = model_base.state_dict() 335 | model_dict = model.state_dict() 336 | 337 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 338 | 339 | model_dict.update(pretrained_dict) 340 | model.load_state_dict(model_dict) 341 | 342 | return model 343 | 344 | 345 | 346 | 347 | 348 | if __name__ == '__main__': 349 | images = torch.rand(1, 3, 352, 352) 350 | model = res2net50_v1b_26w_4s(pretrained=False) 351 | model = model 352 | print(model(images).size()) 353 | -------------------------------------------------------------------------------- /mindspore/lib/model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | from mindspore import Tensor, ops, Parameter, nn, common 4 | import mindspore as ms 5 | from lib.res2net_v1b_base import Res2Net_model 6 | 7 | 8 | class global_module(nn.Cell): 9 | def __init__(self, channels=64, r=4): 10 | super(global_module, self).__init__() 11 | out_channels = int(channels // r) 12 | # local_att 13 | 14 | self.global_att = nn.SequentialCell( 15 | nn.AdaptiveAvgPool2d(1), 16 | nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, pad_mode='valid', has_bias=True), 17 | nn.BatchNorm2d(out_channels, use_batch_statistics=True), 18 | nn.ReLU(), 19 | nn.Conv2d(out_channels, channels, kernel_size=1, stride=1, pad_mode='valid', has_bias=True), 20 | nn.BatchNorm2d(channels, use_batch_statistics=True) 21 | ) 22 | 23 | self.sig = nn.Sigmoid() 24 | 25 | def construct(self, x): 26 | xg = self.global_att(x) 27 | out = self.sig(xg) 28 | 29 | return out 30 | 31 | 32 | class BasicConv2d(nn.Cell): 33 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, pad_mode='valid', dilation=1): 34 | super(BasicConv2d, self).__init__() 35 | self.conv = nn.Conv2d(in_planes, out_planes, 36 | kernel_size=kernel_size, stride=stride, 37 | pad_mode=pad_mode, dilation=dilation, has_bias=False) 38 | self.bn = nn.BatchNorm2d(out_planes) 39 | self.relu = nn.ReLU() 40 | 41 | def construct(self, x): 42 | x = self.conv(x) 43 | x = self.bn(x) 44 | return x 45 | 46 | 47 | class ChannelAttention(nn.Cell): 48 | def __init__(self, in_planes, ratio=16): 49 | super(ChannelAttention, self).__init__() 50 | 51 | # self.max_pool = nn.AdaptiveAvgPool2d(1) 52 | 53 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, pad_mode='valid', has_bias=False) 54 | self.relu1 = nn.ReLU() 55 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, pad_mode='valid', has_bias=False) 56 | 57 | self.sigmoid = nn.Sigmoid() 58 | 59 | def construct(self, x): 60 | _, x = ops.max(x, axis=2, keep_dims=True) 61 | _, x = ops.max(x, axis=3, keep_dims=True) 62 | # x = self.max_pool(x) 63 | max_out = self.fc2(self.relu1(self.fc1(x))) 64 | out = max_out 65 | return self.sigmoid(out) 66 | 67 | 68 | class GateFusion(nn.Cell): 69 | def __init__(self, in_planes): 70 | self.init__ = super(GateFusion, self).__init__() 71 | 72 | self.gate_1 = 
nn.Conv2d(in_planes * 2, 1, kernel_size=1, pad_mode='valid', has_bias=True) 73 | self.gate_2 = nn.Conv2d(in_planes * 2, 1, kernel_size=1, pad_mode='valid', has_bias=True) 74 | 75 | self.softmax = nn.Softmax(axis=1) 76 | 77 | def construct(self, x1, x2): 78 | ### 79 | cat_fea = ops.concat((x1, x2), axis=1) 80 | 81 | ### 82 | att_vec_1 = self.gate_1(cat_fea) 83 | att_vec_2 = self.gate_2(cat_fea) 84 | 85 | att_vec_cat = ops.concat((att_vec_1, att_vec_2), axis=1) 86 | att_vec_soft = self.softmax(att_vec_cat) 87 | 88 | att_soft_1, att_soft_2 = att_vec_soft[:, 0:1, :, :], att_vec_soft[:, 1:2, :, :] 89 | x_fusion = x1 * att_soft_1 + x2 * att_soft_2 90 | 91 | return x_fusion 92 | 93 | 94 | class BAM(nn.Cell): 95 | # Partial Decoder Component (Identification Module) 96 | def __init__(self, channel): 97 | super(BAM, self).__init__() 98 | 99 | self.relu = nn.ReLU() 100 | 101 | self.global_att = global_module(channel) 102 | 103 | self.conv_layer = BasicConv2d(channel * 2, channel, 3, pad_mode='same') 104 | 105 | def construct(self, x, x_boun_atten): 106 | out1 = self.conv_layer(ops.concat((x, x_boun_atten), axis=1)) 107 | out2 = self.global_att(out1) 108 | out3 = out1 * out2 109 | 110 | out = x + out3 111 | 112 | return out 113 | 114 | 115 | class CFF(nn.Cell): 116 | def __init__(self, in_channel1, in_channel2, out_channel): 117 | self.init__ = super(CFF, self).__init__() 118 | 119 | act_fn = nn.ReLU() 120 | 121 | ## ---------------------------------------- ## 122 | self.layer0 = BasicConv2d(in_channel1, out_channel // 2, 1) 123 | self.layer1 = BasicConv2d(in_channel2, out_channel // 2, 1) 124 | 125 | self.layer3_1 = nn.SequentialCell( 126 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 127 | nn.BatchNorm2d(out_channel // 2), act_fn) 128 | self.layer3_2 = nn.SequentialCell( 129 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 130 | nn.BatchNorm2d(out_channel // 2), act_fn) 131 | 132 | self.layer5_1 = nn.SequentialCell( 133 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, pad_mode='same', has_bias=True), 134 | nn.BatchNorm2d(out_channel // 2), act_fn) 135 | self.layer5_2 = nn.SequentialCell( 136 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, pad_mode='same', has_bias=True), 137 | nn.BatchNorm2d(out_channel // 2), act_fn) 138 | 139 | self.layer_out = nn.SequentialCell( 140 | nn.Conv2d(out_channel // 2, out_channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 141 | nn.BatchNorm2d(out_channel), act_fn) 142 | 143 | def construct(self, x0, x1): 144 | ## ------------------------------------------------------------------ ## 145 | x0_1 = self.layer0(x0) 146 | x1_1 = self.layer1(x1) 147 | 148 | x_3_1 = self.layer3_1(ops.concat((x0_1, x1_1), axis=1)) 149 | x_5_1 = self.layer5_1(ops.concat((x1_1, x0_1), axis=1)) 150 | 151 | x_3_2 = self.layer3_2(ops.concat((x_3_1, x_5_1), axis=1)) 152 | x_5_2 = self.layer5_2(ops.concat((x_5_1, x_3_1), axis=1)) 153 | 154 | out = self.layer_out(x0_1 + x1_1 + x_3_2 * x_5_2) 155 | 156 | return out 157 | 158 | 159 | ############################################################################### 160 | ## 2022/01/03 161 | ############################################################################### 162 | class CFANet(nn.Cell): 163 | # resnet based encoder decoder 164 | def __init__(self, channel=64, opt=None): 165 | super(CFANet, self).__init__() 166 | 167 | act_fn = nn.ReLU() 168 | 169 | self.resnet = Res2Net_model(50) 170 | 
self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 171 | 172 | ## ---------------------------------------- ## 173 | 174 | self.layer0 = nn.SequentialCell( 175 | nn.Conv2d(64, channel, kernel_size=3, stride=2, pad_mode='same', has_bias=True), 176 | nn.BatchNorm2d(channel ), 177 | act_fn) 178 | self.layer1 = nn.SequentialCell( 179 | nn.Conv2d(256, channel, kernel_size=3, stride=2, pad_mode='same', has_bias=True), 180 | nn.BatchNorm2d(channel ), 181 | act_fn) 182 | 183 | self.low_fusion = GateFusion(channel) 184 | 185 | self.high_fusion1 = CFF(256, 512, channel) 186 | self.high_fusion2 = CFF(1024, 2048, channel) 187 | 188 | ## ---------------------------------------- ## 189 | self.layer_edge0 = nn.SequentialCell( 190 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 191 | nn.BatchNorm2d(channel ), 192 | act_fn) 193 | self.layer_edge1 = nn.SequentialCell( 194 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 195 | nn.BatchNorm2d(channel ), 196 | act_fn) 197 | self.layer_edge2 = nn.SequentialCell( 198 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 199 | nn.BatchNorm2d(64 ), 200 | act_fn) 201 | self.layer_edge3 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True) 202 | 203 | ## ---------------------------------------- ## 204 | self.layer_cat_ori1 = nn.SequentialCell( 205 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 206 | nn.BatchNorm2d(channel ), 207 | act_fn) 208 | self.layer_hig01 = nn.SequentialCell( 209 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 210 | nn.BatchNorm2d(channel ), 211 | act_fn) 212 | 213 | self.layer_cat11 = nn.SequentialCell( 214 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 215 | nn.BatchNorm2d(channel ), 216 | act_fn) 217 | self.layer_hig11 = nn.SequentialCell( 218 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 219 | nn.BatchNorm2d(channel ), 220 | act_fn) 221 | 222 | self.layer_cat21 = nn.SequentialCell( 223 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 224 | nn.BatchNorm2d(channel ), 225 | act_fn) 226 | self.layer_hig21 = nn.SequentialCell( 227 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 228 | nn.BatchNorm2d(64 ), 229 | act_fn) 230 | 231 | self.layer_cat31 = nn.SequentialCell( 232 | nn.Conv2d(64 * 2, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 233 | nn.BatchNorm2d(64 ), 234 | act_fn) 235 | self.layer_hig31 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True) 236 | 237 | self.layer_cat_ori2 = nn.SequentialCell( 238 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 239 | nn.BatchNorm2d(channel ), 240 | act_fn) 241 | self.layer_hig02 = nn.SequentialCell( 242 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 243 | nn.BatchNorm2d(channel ), 244 | act_fn) 245 | 246 | self.layer_cat12 = nn.SequentialCell( 247 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 248 | nn.BatchNorm2d(channel ), 249 | act_fn) 250 | self.layer_hig12 = nn.SequentialCell( 251 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 252 | nn.BatchNorm2d(channel ), 253 | act_fn) 254 | 255 | self.layer_cat22 = nn.SequentialCell( 256 | 
nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 257 | nn.BatchNorm2d(channel ), 258 | act_fn) 259 | self.layer_hig22 = nn.SequentialCell( 260 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 261 | nn.BatchNorm2d(64 ), 262 | act_fn) 263 | 264 | self.layer_cat32 = nn.SequentialCell( 265 | nn.Conv2d(64 * 2, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True), 266 | nn.BatchNorm2d(64), 267 | act_fn) 268 | self.layer_hig32 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True) 269 | 270 | self.layer_fil = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True) 271 | 272 | ## ---------------------------------------- ## 273 | 274 | self.atten_edge_0 = ChannelAttention(channel) 275 | self.atten_edge_1 = ChannelAttention(channel) 276 | self.atten_edge_2 = ChannelAttention(channel) 277 | self.atten_edge_ori = ChannelAttention(channel) 278 | 279 | self.cat_01 = BAM(channel) 280 | self.cat_11 = BAM(channel) 281 | self.cat_21 = BAM(channel) 282 | self.cat_31 = BAM(channel) 283 | 284 | self.cat_02 = BAM(channel) 285 | self.cat_12 = BAM(channel) 286 | self.cat_22 = BAM(channel) 287 | self.cat_32 = BAM(channel) 288 | 289 | ## ---------------------------------------- ## 290 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 291 | self.up = nn.ResizeBilinear() 292 | 293 | def construct(self, xx): 294 | # ---- feature abstraction ----- 295 | 296 | x0, x1, x2, x3, x4 = self.resnet(xx) 297 | 298 | ## -------------------------------------- ## 299 | 300 | x0_1 = self.layer0(x0) 301 | x1_1 = self.layer1(x1) 302 | 303 | low_x = self.low_fusion(x0_1, x1_1) # 64*44 304 | 305 | edge_out0 = self.layer_edge0(self.up(low_x, scale_factor=2, align_corners=True)) # 64*88 306 | edge_out1 = self.layer_edge1(self.up(edge_out0, scale_factor=2, align_corners=True)) # 64*176 307 | edge_out2 = self.layer_edge2(self.up(edge_out1, scale_factor=2, align_corners=True)) # 64*352 308 | edge_out3 = self.layer_edge3(edge_out2) 309 | 310 | etten_edge_ori = self.atten_edge_ori(low_x) 311 | etten_edge_0 = self.atten_edge_0(edge_out0) 312 | etten_edge_1 = self.atten_edge_1(edge_out1) 313 | etten_edge_2 = self.atten_edge_2(edge_out2) 314 | 315 | ## -------------------------------------- ## 316 | high_x01 = self.high_fusion1(self.downSample(x1), x2) 317 | high_x02 = self.high_fusion2(self.up(x3, scale_factor=2, align_corners=True), 318 | self.up(x4, scale_factor=4, align_corners=True)) 319 | 320 | ## --------------- high 1 ----------------------- # 321 | cat_out_01 = self.cat_01(high_x01, low_x * etten_edge_ori) 322 | hig_out01 = self.layer_hig01(self.up(cat_out_01, scale_factor=2, align_corners=True)) 323 | 324 | cat_out11 = self.cat_11(hig_out01, edge_out0 * (etten_edge_0)) 325 | hig_out11 = self.layer_hig11(self.up(cat_out11, scale_factor=2, align_corners=True)) 326 | 327 | cat_out21 = self.cat_21(hig_out11, edge_out1 * (etten_edge_1)) 328 | hig_out21 = self.layer_hig21(self.up(cat_out21, scale_factor=2, align_corners=True)) 329 | 330 | cat_out31 = self.cat_31(hig_out21, edge_out2 * (etten_edge_2)) 331 | sal_out1 = self.layer_hig31(cat_out31) 332 | 333 | ## ---------------- high 2 ---------------------- ## 334 | cat_out_02 = self.cat_02(high_x02, low_x * (etten_edge_ori)) 335 | hig_out02 = self.layer_hig02(self.up(cat_out_02, scale_factor=2, align_corners=True)) 336 | 337 | cat_out12 = self.cat_12(hig_out02, edge_out0 * (etten_edge_0)) 338 | hig_out12 = self.layer_hig12(self.up(cat_out12, scale_factor=2, align_corners=True)) 339 
| 340 | cat_out22 = self.cat_22(hig_out12, edge_out1 * (etten_edge_1)) 341 | hig_out22 = self.layer_hig22(self.up(cat_out22, scale_factor=2, align_corners=True)) 342 | 343 | cat_out32 = self.cat_32(hig_out22, edge_out2 * (etten_edge_2)) 344 | sal_out2 = self.layer_hig32(cat_out32) 345 | 346 | ## --------------------------------------------- ## 347 | sal_out3 = self.layer_fil(cat_out31 + cat_out32) 348 | 349 | # ---- output ---- 350 | return edge_out3, sal_out1, sal_out2, sal_out3 351 | -------------------------------------------------------------------------------- /mindspore/lib/res2net_v1b_base.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | from mindspore import Tensor, ops, Parameter, nn, common 4 | import mindspore as ms 5 | import math 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b'] 8 | 9 | # model_urls = { 10 | # 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 11 | # 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 12 | # } 13 | 14 | 15 | class Bottle2neck(nn.Cell): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 19 | """ Constructor 20 | Args: 21 | inplanes: input channel dimensionality 22 | planes: output channel dimensionality 23 | stride: conv stride. Replaces pooling layer. 24 | downsample: None when stride = 1 25 | baseWidth: basic width of conv3x3 26 | scale: number of scale. 27 | type: 'normal': normal set. 'stage': first block of a new stage. 28 | """ 29 | super(Bottle2neck, self).__init__() 30 | 31 | width = int(math.floor(planes * (baseWidth / 64.0))) 32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, pad_mode='valid', has_bias=False) 33 | self.bn1 = nn.BatchNorm2d(width * scale) 34 | 35 | if scale == 1: 36 | self.nums = 1 37 | else: 38 | self.nums = scale - 1 39 | if stype == 'stage': 40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, pad_mode="same") 41 | convs = [] 42 | bns = [] 43 | for i in range(self.nums): 44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, pad_mode='same', has_bias=False)) 45 | bns.append(nn.BatchNorm2d(width)) 46 | self.convs = nn.CellList(convs) 47 | self.bns = nn.CellList(bns) 48 | 49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, pad_mode='valid', has_bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU() 53 | self.downsample = downsample 54 | self.stype = stype 55 | self.scale = scale 56 | self.width = width 57 | 58 | def construct(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | sp = None 66 | spx = ops.Split(1, self.scale)(out) 67 | for i in range(self.nums): 68 | if i == 0 or self.stype == 'stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i == 0: 75 | out = sp 76 | else: 77 | out = ops.Concat(1)((out, sp)) 78 | if self.scale != 1 and self.stype == 'normal': 79 | out = ops.Concat(1)((out, spx[self.nums])) # torch.cat((out, spx[self.nums]), 1) 80 | elif self.scale != 1 and self.stype == 'stage': 81 | out = ops.Concat(1)((out, self.pool(spx[self.nums]))) # torch.cat((out, self.pool(spx[self.nums])), 1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not 
None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Res2Net(nn.Cell): 96 | 97 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 98 | self.inplanes = 64 99 | super(Res2Net, self).__init__() 100 | self.baseWidth = baseWidth 101 | self.scale = scale 102 | self.conv1 = nn.SequentialCell( 103 | nn.Conv2d(3, 32, kernel_size=3, stride=2, pad_mode='same', has_bias=False), 104 | nn.BatchNorm2d(32), 105 | nn.ReLU(), 106 | nn.Conv2d(32, 32, kernel_size=3, stride=1, pad_mode='same', has_bias=False), 107 | nn.BatchNorm2d(32), 108 | nn.ReLU(), 109 | nn.Conv2d(32, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=False) 110 | ) 111 | self.bn1 = nn.BatchNorm2d(64) 112 | self.relu = nn.ReLU() 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') 114 | self.layer1 = self._make_layer(block, 64, layers[0]) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 118 | self.avgpool = nn.AdaptiveAvgPool2d(1) 119 | self.fc = nn.Dense(512 * block.expansion, num_classes) 120 | 121 | # for m in self.cells(): 122 | # if isinstance(m, nn.Conv2d): 123 | # common.initializer.HeNormal(mode='fan_out', nonlinearity='relu')(m.weight) 124 | # elif isinstance(m, nn.BatchNorm2d): 125 | # print(m) 126 | # common.initializer.Constant(1)(m.weight) 127 | # common.initializer.Constant(0)(m.bias) 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | downsample = None 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.SequentialCell( 133 | nn.AvgPool2d(kernel_size=stride, stride=stride), 134 | nn.Conv2d(self.inplanes, planes * block.expansion, 135 | kernel_size=1, stride=1, pad_mode='valid', has_bias=False), 136 | nn.BatchNorm2d(planes * block.expansion), 137 | ) 138 | 139 | layers = [] 140 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 141 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 145 | 146 | return nn.SequentialCell(*layers) 147 | 148 | def construct(self, x): 149 | 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | x0 = self.maxpool(x) 154 | 155 | x1 = self.layer1(x0) 156 | x2 = self.layer2(x1) 157 | x3 = self.layer3(x2) 158 | x4 = self.layer4(x3) 159 | 160 | x5 = self.avgpool(x4) 161 | x6 = x5.view(x5.size(0), -1) 162 | x7 = self.fc(x6) 163 | 164 | return x7 165 | 166 | 167 | class Res2Net_Ours(nn.Cell): 168 | 169 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 170 | self.inplanes = 64 171 | super(Res2Net_Ours, self).__init__() 172 | 173 | self.baseWidth = baseWidth 174 | self.scale = scale 175 | self.conv1 = nn.SequentialCell( 176 | nn.Conv2d(3, 32, kernel_size=3, stride=2, pad_mode='same', has_bias=False), 177 | nn.BatchNorm2d(32), 178 | nn.ReLU(), 179 | nn.Conv2d(32, 32, kernel_size=3, stride=1, pad_mode='same', has_bias=False), 180 | nn.BatchNorm2d(32), 181 | nn.ReLU(), 182 | nn.Conv2d(32, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=False) 183 | ) 184 | self.bn1 = nn.BatchNorm2d(64) 185 | self.relu = nn.ReLU() 186 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') 187 | 
self.layer1 = self._make_layer(block, 64, layers[0]) 188 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 189 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 191 | 192 | # for m in self.cells(): 193 | # if isinstance(m, nn.Conv2d): 194 | # common.initializer.HeNormal(mode='fan_out', nonlinearity='relu')(m.weight) 195 | # elif isinstance(m, nn.BatchNorm2d): 196 | # print(m) 197 | # common.initializer.Constant(1)(m.weight) 198 | # common.initializer.Constant(0)(m.bias) 199 | 200 | def _make_layer(self, block, planes, blocks, stride=1): 201 | downsample = None 202 | if stride != 1 or self.inplanes != planes * block.expansion: 203 | downsample = nn.SequentialCell( 204 | nn.AvgPool2d(kernel_size=stride, stride=stride), 205 | nn.Conv2d(self.inplanes, planes * block.expansion, 206 | kernel_size=1, stride=1, pad_mode='valid', has_bias=False), 207 | nn.BatchNorm2d(planes * block.expansion), 208 | ) 209 | 210 | layers = [] 211 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 212 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 213 | self.inplanes = planes * block.expansion 214 | for i in range(1, blocks): 215 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 216 | 217 | return nn.SequentialCell(*layers) 218 | 219 | def construct(self, x): 220 | 221 | x = self.conv1(x) 222 | x = self.bn1(x) 223 | x = self.relu(x) 224 | x0 = self.maxpool(x) 225 | 226 | x1 = self.layer1(x0) 227 | x2 = self.layer2(x1) 228 | x3 = self.layer3(x2) 229 | x4 = self.layer4(x3) 230 | 231 | return x0, x1, x2, x3, x4 232 | 233 | 234 | def res2net50_v1b(pretrained=False, **kwargs): 235 | """Constructs a Res2Net-50_v1b model. 236 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 237 | Args: 238 | pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | """ 240 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 241 | if pretrained: 242 | # model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu')) 243 | param_dict = ms.load_checkpoint( 244 | './lib/res2net50_v1b_26w_4s-3cf99910.ckpt') 245 | ms.load_param_into_net(model, param_dict) 246 | 247 | return model 248 | 249 | 250 | def res2net50_v1b_Ours(pretrained=False, **kwargs): 251 | """Constructs a Res2Net-50_v1b model. 252 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 
253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | """ 256 | model = Res2Net_Ours(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 257 | if pretrained: 258 | param_dict = ms.load_checkpoint( 259 | './lib/res2net50_v1b_26w_4s-3cf99910.ckpt') 260 | ms.load_param_into_net(model, param_dict) 261 | return model 262 | 263 | 264 | def Res2Net_model(ind=50): 265 | if ind == 50: 266 | model_base = res2net50_v1b(pretrained=False) 267 | model = res2net50_v1b_Ours() 268 | 269 | pretrained_dict = model_base.parameters_dict() 270 | model_dict = model.parameters_dict() 271 | 272 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 273 | 274 | ms.load_param_into_net(model, pretrained_dict) 275 | 276 | return model 277 | -------------------------------------------------------------------------------- /mindspore/test.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import mindspore.numpy as np 4 | import os, argparse 5 | import mindspore as ms 6 | from mindspore import ops, nn 7 | from lib.model import CFANet 8 | from utils.dataloader import get_loader, get_loader_test 9 | 10 | import cv2 11 | 12 | ms.set_context(device_target='GPU') 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 16 | parser.add_argument('--pth_path', type=str, default='./checkpoint/CFANet.pth') 17 | 18 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 19 | 20 | print('-----------starting -------------') 21 | 22 | # data_path = '/test/Polpy/Dataset/TestDataset/{}/'.format(_data_name) 23 | data_path = '../data/TestDataset/{}/'.format(_data_name) 24 | save_path = './Snapshot/seg_maps/{}/'.format(_data_name) 25 | 26 | opt = parser.parse_args() 27 | model = CFANet(channel=64) 28 | 29 | param_dict = ms.load_checkpoint(opt.pth_path) 30 | ms.load_param_into_net(model, param_dict) 31 | model.set_train(False) 32 | 33 | os.makedirs(save_path, exist_ok=True) 34 | 35 | image_root = '{}/images/'.format(data_path) 36 | gt_root = '{}/masks/'.format(data_path) 37 | 38 | test_loader = get_loader_test(image_root, gt_root, testsize=opt.testsize) 39 | 40 | for i, pack in enumerate(test_loader.create_tuple_iterator()): 41 | print(['--------------processing-------------', i]) 42 | 43 | image, gt, name = pack 44 | image = ops.ExpandDims()(image, 0) 45 | gt = np.asarray(gt, np.float32) 46 | gt /= (gt.max() + 1e-8) 47 | 48 | _, _, _, res = model(image) 49 | 50 | res = nn.ResizeBilinear()(res, size=gt.shape, align_corners=False) 51 | res = res.sigmoid()[0] 52 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 53 | cv2.imwrite(save_path + str(name.asnumpy()), ops.Transpose()(res, (1, 2, 0)).asnumpy() * 255) 54 | -------------------------------------------------------------------------------- /mindspore/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pdb 4 | 5 | import numpy as np 6 | import logging 7 | from datetime import datetime 8 | 9 | import mindspore as ms 10 | from mindspore import context 11 | from mindspore import nn, ops 12 | 13 | from lib.model import CFANet 14 | from utils.dataloader import get_loader, get_loader_test 15 | 16 | ms.set_context(device_target='GPU', mode=context.GRAPH_MODE) 17 | 18 | import cv2 19 | 20 | 21 | class Trainer: 22 | def __init__(self, train_loader, test_loader, model, 
                 optimizer, opt):
23 |         self.train_loader = train_loader
24 |         self.test_loader = test_loader
25 |         self.model = model
26 |         self.optimizer = optimizer
27 |         self.opt = opt
28 |         self.total_step = self.train_loader.get_dataset_size()  # TODO
29 | 
30 |         self.grad_fn = ops.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True)
31 |         self.loss_func = nn.BCEWithLogitsLoss()
32 |         self.size_rates = [0.75, 1, 1.25]
33 |         self.resize = nn.ResizeBilinear()
34 |         self.eval_loss_func = nn.L1Loss()
35 |         self.best_mae = 1
36 |         self.best_epoch = 0
37 |         self.decay_rate = 0.1
38 |         self.decay_epoch = 30
39 | 
40 |     def forward_fn(self, images, gts, egs):
41 |         cam_edge, sal_out1, sal_out2, sal_out3 = self.model(images)
42 |         loss_edge = self.loss_func(cam_edge, egs)
43 |         loss_sal1 = self.structure_loss(sal_out1, gts)
44 |         loss_sal2 = self.structure_loss(sal_out2, gts)
45 |         loss_sal3 = self.structure_loss(sal_out3, gts)
46 | 
47 |         loss_total = loss_edge + loss_sal1 + loss_sal2 + loss_sal3
48 |         return loss_total, loss_edge, loss_sal1, loss_sal2, loss_sal3
49 | 
50 |     def train_step(self, images, gts, egs):
51 |         (loss, loss_edge, loss_sal1, loss_sal2, loss_sal3), grads = self.grad_fn(images, gts, egs)
52 |         self.optimizer(grads)
53 |         return loss, loss_edge, loss_sal1, loss_sal2, loss_sal3
54 | 
55 |     def train(self, epochs):
56 |         for epoch in range(1, epochs + 1):
57 |             self.model.set_train(True)
58 |             self.adjust_lr(epoch)
59 |             for step, data_pack in enumerate(self.train_loader.create_tuple_iterator(), start=1):
60 |                 images, gts, egs = data_pack
61 |                 for rate in self.size_rates:
62 |                     # ---- rescale ----
63 |                     trainsize = int(round(self.opt.trainsize * rate / 32) * 32)
64 |                     if rate != 1:
65 |                         images = self.resize(images, size=(trainsize, trainsize), align_corners=True)
66 |                         gts = self.resize(gts, size=(trainsize, trainsize), align_corners=True)
67 |                         egs = self.resize(egs, size=(trainsize, trainsize), align_corners=True)
68 |                     loss, loss_edge, loss_sal1, loss_sal2, loss_sal3 = self.train_step(images, gts, egs)
69 | 
70 |                 # -- output loss -- #
71 |                 if step % 10 == 0 or step == self.total_step:
72 |                     print(
73 |                         '[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}]'.
74 |                         format(datetime.now(), epoch, epochs, step, self.total_step, loss_edge.asnumpy(),
75 |                                loss_sal1.asnumpy(), loss_sal2.asnumpy(), loss_sal3.asnumpy(), loss.asnumpy()))
76 | 
77 |                     logging.info(
78 |                         '#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}'.
79 |                         format(epoch, epochs, step, self.total_step, loss_edge.asnumpy(), loss_sal1.asnumpy(),
80 |                                loss_sal2.asnumpy(), loss_sal3.asnumpy(), loss.asnumpy()))
81 | 
82 |             self.test(epoch)
83 | 
84 |             if epoch % self.opt.save_epoch == 0:
85 |                 ms.save_checkpoint(self.model, os.path.join(self.opt.save_model, 'CODNet_%d.ckpt' % epoch))
86 | 
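    # A minimal sketch of the functional training-step pattern used above,
    # assuming MindSpore >= 1.9 (where ops.value_and_grad exists with this
    # signature and has_aux=True differentiates only the first output):
    #
    #   grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
    #   (loss, *aux_losses), grads = grad_fn(images, gts, egs)
    #   optimizer(grads)  # the optimizer cell applies the gradients in-place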
87 |     def test(self, epoch):
88 |         self.model.set_train(False)
89 |         mae_sum = 0
90 |         for i, pack in enumerate(self.test_loader.create_tuple_iterator(), start=1):
91 |             # ---- data prepare ----
92 |             image, gt, name = pack
93 |             image = ops.ExpandDims()(image, 0)
94 | 
95 |             # ---- forward ----
96 |             _, _, _, res = self.model(image)
97 |             res = self.resize(res, size=gt.shape, align_corners=False)
98 |             # pdb.set_trace()
99 |             res = nn.Sigmoid()(res[0][0])  # TODO
100 |             # pdb.set_trace()
101 |             res = (res - res.min()) / (res.max() - res.min() + 1e-8)
102 |             # pdb.set_trace()
103 |             mae_sum += self.eval_loss_func(res, gt)
104 |             # pdb.set_trace()
105 |             # mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1])
106 | 
107 |         # ---- recording loss ----
108 |         mae = mae_sum / self.test_loader.get_dataset_size()
109 |         print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, self.best_mae, self.best_epoch))
110 |         if epoch == 1:
111 |             self.best_mae = mae
112 |             self.best_epoch = epoch
113 |         else:
114 |             if mae < self.best_mae:
115 |                 self.best_mae = mae
116 |                 self.best_epoch = epoch
117 | 
118 |                 ms.save_checkpoint(self.model, os.path.join(self.opt.save_model, 'Cod_best.ckpt'))
119 |                 print('best epoch:{}'.format(epoch))
120 | 
121 | 
122 |     def structure_loss(self, pred, mask):
123 |         pred = nn.Sigmoid()(pred)
124 |         weit = 1 + 5 * ops.Abs()(ops.AvgPool(kernel_size=31, strides=1, pad_mode='same')(mask) - mask)
125 |         wbce = nn.BCELoss(reduction='none')(pred, mask)
126 |         wbce = (weit * wbce).sum(axis=(2, 3)) / weit.sum(axis=(2, 3))
127 | 
128 |         inter = ((pred * mask) * weit).sum(axis=(2, 3))
129 |         union = ((pred + mask) * weit).sum(axis=(2, 3))
130 |         wiou = 1 - (inter + 1) / (union - inter + 1)
131 | 
132 |         return (wbce + wiou).mean()
133 | 
134 |     def adjust_lr(self, epoch):
135 |         decay = self.decay_rate ** (epoch // self.decay_epoch)  # TODO
136 |         self.optimizer.get_lr().set_data(self.opt.lr * decay)
137 | 
138 | 
139 | if __name__ == "__main__":
140 |     parser = argparse.ArgumentParser()
141 |     parser.add_argument('--epoch', type=int, default=200, help='epoch number')
142 |     parser.add_argument('--lr', type=float, default=1e-4, help='init learning rate, try `lr=1e-4`')
143 |     parser.add_argument('--batchsize', type=int, default=10, help='training batch size (Note: ~500MB per img in GPU)')
144 |     parser.add_argument('--trainsize', type=int, default=352,
145 |                         help='the size of training image, try small resolutions for speed (like 256)')
146 |     parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
147 |     parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate per decay step')
148 |     parser.add_argument('--decay_epoch', type=int, default=30, help='every N epochs decay lr')
149 |     parser.add_argument('--gpu', type=int, default=0, help='choose which gpu you use')
150 |     parser.add_argument('--save_epoch', type=int, default=5, help='every N epochs save your trained snapshot')
151 |     parser.add_argument('--save_model', type=str, default='./Snapshot/CFANet/')
152 | 
153 |     parser.add_argument('--train_img_dir', type=str, default='./data/TrainDataset/images/')
154 |     parser.add_argument('--train_gt_dir', type=str, default='./data/TrainDataset/masks/')
155 |     parser.add_argument('--train_eg_dir', type=str, default='./data/TrainDataset/edges/')
156 | 
157 |     parser.add_argument('--test_img_dir', type=str, default='./data/TestDataset/CVC-300/images/')
158 |     parser.add_argument('--test_gt_dir', type=str, default='./data/TestDataset/CVC-300/masks/')
159 |     parser.add_argument('--test_eg_dir', type=str, default='./data/TestDataset/CVC-300/edges/')
160 | 
161 |     opt = parser.parse_args()
162 | 
163 |     ms.set_context(device_id=opt.gpu)
164 | 
165 |     save_path = opt.save_model
166 |     os.makedirs(save_path, exist_ok=True)
167 | 
168 |     logging.basicConfig(filename=opt.save_model + '/log.log',
169 |                         format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO,
170 |                         filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
171 |     logging.info("COD-Train")
172 |     logging.info("Config")
173 |     logging.info(
174 |         'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(
175 |             opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.save_model, opt.decay_epoch))
176 | 
177 |     # TIPS: you can also use a wider network for better performance, e.g. channel=64
178 |     model = CFANet(channel=64)
179 |     # print('-' * 30, model, '-' * 30)
180 | 
181 |     total = sum([param.nelement() for param in model.get_parameters()])
182 |     print('Number of parameter:%.2fM' % (total / 1e6))
183 | 
184 |     optimizer = nn.Adam(model.trainable_params(), learning_rate=opt.lr)
185 | 
186 |     train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, opt.train_eg_dir, batchsize=opt.batchsize,
187 |                               trainsize=opt.trainsize, num_workers=12)
188 |     test_loader = get_loader_test(opt.test_img_dir, opt.test_gt_dir, testsize=opt.trainsize)
189 | 
190 |     total_step = train_loader.get_dataset_size()
191 |     print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n"
192 |                     "Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr,
193 |                                                                 opt.batchsize, opt.save_model, total_step), '-' * 30)
194 | 
195 |     train = Trainer(train_loader, test_loader, model, optimizer, opt)
196 |     train.train(opt.epoch)
197 | 
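# Example invocation (values are the argparse defaults above; adjust the data
# paths to your local dataset layout):
#
#   python train.py --epoch 200 --batchsize 10 --trainsize 352 \
#       --train_img_dir ./data/TrainDataset/images/ --save_model ./Snapshot/CFANet/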
--------------------------------------------------------------------------------
/mindspore/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import random
4 | import numpy as np
5 | from PIL import ImageEnhance
6 | 
7 | from mindspore.dataset import transforms, vision, text
8 | from mindspore.dataset import GeneratorDataset
9 | from mindspore import ops
10 | 
11 | 
12 | # several data augmentation strategies
13 | def cv_random_flip(img, label, edge):
14 |     # left right flip
15 |     flip_flag = random.randint(0, 1)
16 |     if flip_flag == 1:
17 |         img = img.transpose(Image.FLIP_LEFT_RIGHT)
18 |         label = label.transpose(Image.FLIP_LEFT_RIGHT)
19 |         edge = edge.transpose(Image.FLIP_LEFT_RIGHT)
20 |     return img, label, edge
21 | 
22 | 
23 | def randomCrop(image, label, edge):
24 |     border = 30
25 |     image_width = image.size[0]
26 |     image_height = image.size[1]
27 |     crop_win_width = np.random.randint(image_width - border, image_width)
28 |     crop_win_height = np.random.randint(image_height - border, image_height)
29 |     random_region = (
30 |         (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1,
31 |         (image_width + crop_win_width) >> 1, (image_height + crop_win_height) >> 1)
32 |     return image.crop(random_region), label.crop(random_region), edge.crop(random_region)
33 | 
34 | 
35 | def randomRotation(image, label, edge):
36 |     mode = Image.BICUBIC
37 |     if random.random() > 0.8:
38 |         random_angle = np.random.randint(-15, 15)
39 |         image = image.rotate(random_angle, mode)
40 |         label = label.rotate(random_angle, mode)
41 |         edge = edge.rotate(random_angle, mode)
42 |     return image, label, edge
43 | 
44 | 
45 | def colorEnhance(image):
46 |     bright_intensity = random.randint(5, 15) / 10.0
47 |     image = ImageEnhance.Brightness(image).enhance(bright_intensity)
48 |     contrast_intensity = random.randint(5, 15) / 10.0
49 |     image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
50 |     color_intensity = random.randint(0, 20) / 10.0
51 |     image = ImageEnhance.Color(image).enhance(color_intensity)
52 |     sharp_intensity = random.randint(0, 30) / 10.0
53 |     image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
54 |     return image
55 | 
56 | 
57 | def randomGaussian(image, mean=0.1, sigma=0.35):
58 |     def gaussianNoisy(im, mean=mean, sigma=sigma):
59 |         for _i in range(len(im)):
60 |             im[_i] += random.gauss(mean, sigma)
61 |         return im
62 | 
63 |     img = np.asarray(image)
64 |     width, height = img.shape
65 |     img = gaussianNoisy(img[:].flatten(), mean, sigma)
66 |     img = img.reshape([width, height])
67 |     return Image.fromarray(np.uint8(img))
68 | 
69 | 
70 | def randomPeper(img):
71 |     img = np.array(img)
72 |     noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
73 |     for i in range(noiseNum):
74 |         randX = random.randint(0, img.shape[0] - 1)
75 |         randY = random.randint(0, img.shape[1] - 1)
76 |         if random.randint(0, 1) == 0:
77 |             img[randX, randY] = 0
78 |         else:
79 |             img[randX, randY] = 255
80 |     return Image.fromarray(img)
81 | 
82 | 
83 | def randomPeper_eg(img, edge):
84 |     img = np.array(img)
85 |     edge = np.array(edge)
86 |     noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
87 |     for i in range(noiseNum):
88 |         randX = random.randint(0, img.shape[0] - 1)
89 |         randY = random.randint(0, img.shape[1] - 1)
90 |         if random.randint(0, 1) == 0:
91 |             img[randX, randY] = 0
92 |             edge[randX, randY] = 0
93 |         else:
94 |             img[randX, randY] = 255
95 |             edge[randX, randY] = 255
96 |     return Image.fromarray(img), Image.fromarray(edge)
97 | 
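# A vectorized sketch of the same salt-and-pepper noise (equivalent in
# distribution to the per-pixel loop above, just faster; not used by the
# trainer):
#
#   def randomPeper_np(img, ratio=0.0015):
#       img = np.array(img)
#       n = int(ratio * img.shape[0] * img.shape[1])
#       xs = np.random.randint(0, img.shape[0], n)
#       ys = np.random.randint(0, img.shape[1], n)
#       img[xs, ys] = np.random.choice([0, 255], n)
#       return Image.fromarray(img)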
98 | 
99 | # dataset for training
100 | class PolypObjDataset:
101 |     def __init__(self, image_root, gt_root, edge_root, trainsize):
102 |         self.trainsize = trainsize
103 |         # get filenames
104 |         self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
105 |         self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
106 |         self.egs = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')]
107 | 
108 |         # sorted files
109 |         self.images = sorted(self.images)
110 |         self.gts = sorted(self.gts)
111 |         self.egs = sorted(self.egs)
112 | 
113 |         # filter out image/gt pairs whose sizes do not match
114 |         self.filter_files()
115 | 
116 |         # get size of dataset
117 |         self.size = len(self.images)
118 | 
119 |     def __getitem__(self, index):
120 |         # read imgs/gts/egs
121 |         image = self.rgb_loader(self.images[index])
122 |         gt = self.binary_loader(self.gts[index])
123 |         eg = self.binary_loader(self.egs[index])
124 | 
125 |         # data augmentation
126 |         image, gt, eg = cv_random_flip(image, gt, eg)
127 |         image, gt, eg = randomCrop(image, gt, eg)
128 |         image, gt, eg = randomRotation(image, gt, eg)
129 | 
130 |         image = colorEnhance(image)
131 |         gt, eg = randomPeper_eg(gt, eg)
132 |         return image, gt, eg
133 | 
134 |     def filter_files(self):
135 |         assert len(self.images) == len(self.gts) == len(self.egs)
136 |         images = []
137 |         gts = []
138 |         for img_path, gt_path in zip(self.images, self.gts):
139 |             img = Image.open(img_path)
140 |             gt = Image.open(gt_path)
141 |             if img.size == gt.size:
142 |                 images.append(img_path)
143 |                 gts.append(gt_path)
144 |         self.images = images
145 |         self.gts = gts
146 | 
147 |     def rgb_loader(self, path):
148 |         with open(path, 'rb') as f:
149 |             img = Image.open(f)
150 |             return img.convert('RGB')
151 | 
152 |     def binary_loader(self, path):
153 |         with open(path, 'rb') as f:
154 |             img = Image.open(f)
155 |             return img.convert('L')
156 | 
157 |     def __len__(self):
158 |         return self.size
159 | 
160 | 
161 | # dataloader for training
162 | def get_loader(image_root, gt_root, eg_root, batchsize, trainsize,
163 |                shuffle=True, num_workers=12, pin_memory=True):
164 |     # transforms
165 |     img_transform = transforms.Compose([
166 |         vision.Resize((trainsize, trainsize)),
167 |         vision.ToTensor(),
168 |         vision.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], is_hwc=False)])
169 |     gt_transform = transforms.Compose([
170 |         vision.Resize((trainsize, trainsize)),
171 |         vision.ToTensor()])
172 | 
173 |     eg_transform = transforms.Compose([
174 |         vision.Resize((trainsize, trainsize)),
175 |         vision.ToTensor()])
176 | 
177 |     dataset = PolypObjDataset(image_root, gt_root, eg_root, trainsize)
178 |     data_loader = GeneratorDataset(source=dataset, column_names=["images", "gts", "egs"], shuffle=shuffle)
179 |     data_loader = data_loader.map(img_transform, ["images"])
180 |     data_loader = data_loader.map(gt_transform, ["gts"])
181 |     data_loader = data_loader.map(eg_transform, ["egs"])
182 | 
183 |     data_loader = data_loader.batch(batch_size=batchsize, num_parallel_workers=num_workers)
184 |     return data_loader
185 | 
186 | 
187 | def get_loader_test(image_root, gt_root, testsize):
188 |     testdata = test_dataset(image_root, gt_root, testsize)
189 |     test_loader = GeneratorDataset(source=testdata, column_names=["image", "gt", "name"])
190 |     transform = transforms.Compose([
191 |         vision.Resize((testsize, testsize)),
192 |         vision.ToTensor(),
193 |         vision.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], is_hwc=False)])
194 |     test_loader = test_loader.map(transform, "image")
195 |     return test_loader
196 | 
197 | 
198 | class test_dataset:
199 |     """load test dataset (batchsize=1)"""
200 | 
201 |     def __init__(self, image_root, gt_root, testsize):
202 |         self.testsize = testsize
203 |         self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
204 |         self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
205 |         self.images = sorted(self.images)
206 |         self.gts = sorted(self.gts)
207 |         self.size = len(self.images)
208 | 
209 |     def __getitem__(self, index):
210 |         image = self.rgb_loader(self.images[index])
211 |         gt = self.binary_loader(self.gts[index])
212 |         name = self.images[index].split('/')[-1]
213 |         if name.endswith('.jpg'):
214 |             name = name.split('.jpg')[0] + '.png'
215 |         return image, gt, name
216 | 
217 |     def rgb_loader(self, path):
218 |         with open(path, 'rb') as f:
219 |             img = Image.open(f)
220 |             return img.convert('RGB')
221 | 
222 |     def binary_loader(self, path):
223 |         with open(path, 'rb') as f:
224 |             img = Image.open(f)
225 |             return img.convert('L')
226 | 
227 |     def __len__(self):
228 |         return self.size
229 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import os
5 | import argparse
6 | 
7 | from lib.model import CFANet
8 | from utils.dataloader import get_loader, test_dataset
9 | from utils.eva_funcs import eval_Smeasure, eval_mae, numpy2tensor
10 | 
11 | import scipy.io as scio
12 | import cv2
13 | 
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
16 | parser.add_argument('--pth_path', type=str, default='./checkpoint/CFANet.pth')
17 | 
18 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
19 | 
20 |     print('----------- starting -------------')
21 | 
22 |     data_path = '/test/Polpy/Dataset/TestDataset/{}/'.format(_data_name)
23 |     save_path = './Snapshot/seg_maps/{}/'.format(_data_name)
24 | 
25 |     opt = parser.parse_args()
26 |     model = CFANet(channel=64).cuda()
27 |     model.load_state_dict(torch.load(opt.pth_path))
28 |     model.eval()
29 | 
30 |     os.makedirs(save_path, exist_ok=True)
31 | 
32 |     image_root = '{}/images/'.format(data_path)
33 |     gt_root = '{}/masks/'.format(data_path)
34 | 
35 |     test_loader = test_dataset(image_root, gt_root, opt.testsize)
36 | 
37 |     for i in range(test_loader.size):
38 |         print(['--------------processing-------------', i])
39 | 
40 |         image, gt, name = test_loader.load_data()
41 |         gt = np.asarray(gt, np.float32)
42 |         gt /= (gt.max() + 1e-8)
43 |         image = image.cuda()
44 | 
45 |         _, _, _, res = model(image)
46 | 
47 |         res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
48 |         res = res.sigmoid().data.cpu().numpy().squeeze()
49 |         res = (res - res.min()) / (res.max() - res.min() + 1e-8)
50 | 
51 |         cv2.imwrite(save_path + name, res * 255)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import torch
4 | import argparse
5 | 
6 | from lib.model import CFANet
7 | from utils.dataloader import get_loader, test_dataset
8 | from utils.trainer import adjust_lr
9 | from datetime import datetime
10 | 
11 | import torch.nn.functional as F
12 | import numpy as np
13 | import logging
14 | 
15 | best_mae = 1
16 | best_epoch = 0
17 | 
18 | 
19 | def structure_loss(pred, mask):
20 |     weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
21 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
22 |     wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
23 | 
24 |     pred = torch.sigmoid(pred)
25 |     inter = ((pred * mask) * weit).sum(dim=(2, 3))
26 |     union = ((pred + mask) * weit).sum(dim=(2, 3))
27 |     wiou = 1 - (inter + 1) / (union - inter + 1)
28 |     return (wbce + wiou).mean()
29 | 
30 | 
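# structure_loss combines a boundary-weighted BCE with a weighted IoU: weit is
# largest where the 31x31 neighborhood of a pixel disagrees with the pixel
# itself (i.e. near mask boundaries), so both terms up-weight boundary errors.
# A quick shape check (a sketch with dummy tensors, not part of training):
#
#   pred = torch.randn(2, 1, 352, 352)                 # logits
#   mask = (torch.rand(2, 1, 352, 352) > 0.5).float()  # binary ground truth
#   loss = structure_loss(pred, mask)                  # scalar, differentiable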
31 | def train(train_loader, model, optimizer, epoch, opt, loss_func, total_step):
32 |     """
33 |     Training iteration
34 |     :param train_loader:
35 |     :param model:
36 |     :param optimizer:
37 |     :param epoch:
38 |     :param opt:
39 |     :param loss_func:
40 |     :param total_step:
41 |     :return:
42 |     """
43 |     model.train()
44 | 
45 |     size_rates = [0.75, 1, 1.25]
46 | 
47 |     for step, data_pack in enumerate(train_loader):
48 | 
49 |         images, gts, egs = data_pack
50 | 
51 |         for rate in size_rates:
52 | 
53 |             optimizer.zero_grad()
54 | 
55 |             images = images.cuda()
56 |             gts = gts.cuda()
57 |             egs = egs.cuda()
58 | 
59 |             # ---- rescale ----
60 |             trainsize = int(round(opt.trainsize * rate / 32) * 32)
61 |             if rate != 1:
62 |                 images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
63 |                 gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
64 |                 egs = F.interpolate(egs, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
65 | 
66 |             cam_edge, sal_out1, sal_out2, sal_out3 = model(images)
67 |             loss_edge = loss_func(cam_edge, egs)
68 |             loss_sal1 = structure_loss(sal_out1, gts)
69 |             loss_sal2 = structure_loss(sal_out2, gts)
70 |             loss_sal3 = structure_loss(sal_out3, gts)
71 | 
72 |             loss_total = loss_edge + loss_sal1 + loss_sal2 + loss_sal3
73 | 
74 |             loss_total.backward()
75 |             optimizer.step()
76 | 
77 |             if step % 10 == 0 or step == total_step:
78 |                 print('[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}]'.
79 |                       format(datetime.now(), epoch, opt.epoch, step, total_step, loss_edge.data, loss_sal1.data, loss_sal2.data, loss_sal3.data, loss_total.data))
80 | 
81 |                 logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}'.
82 |                              format(epoch, opt.epoch, step, total_step, loss_edge.data, loss_sal1.data, loss_sal2.data, loss_sal3.data, loss_total.data))
83 | 
84 |     if epoch % opt.save_epoch == 0:
85 |         torch.save(model.state_dict(), save_path + 'CODNet_%d.pth' % epoch)
86 | 
87 | 
88 | def test(test_loader, model, epoch, save_path):
89 | 
90 |     global best_mae, best_epoch
91 |     model.eval()
92 | 
93 |     with torch.no_grad():
94 |         mae_sum = 0
95 |         for i in range(test_loader.size):
96 |             image, gt, name = test_loader.load_data()
97 |             gt = np.asarray(gt, np.float32)
98 |             gt /= (gt.max() + 1e-8)
99 |             image = image.cuda()
100 | 
101 |             _, _, _, res = model(image)
102 |             res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
103 |             res = res.sigmoid().data.cpu().numpy().squeeze()
104 |             res = (res - res.min()) / (res.max() - res.min() + 1e-8)
105 |             mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1])
106 | 
107 |         mae = mae_sum / test_loader.size
108 | 
109 |     print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch))
110 |     if epoch == 1:
111 |         best_mae = mae
112 |         best_epoch = epoch
113 |     else:
114 |         if mae < best_mae:
115 |             best_mae = mae
116 |             best_epoch = epoch
117 | 
118 |             torch.save(model.state_dict(), save_path + '/Cod_best.pth')
119 |             print('best epoch:{}'.format(epoch))
120 | 
121 | 
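# Minimal standalone inference sketch mirroring test() above (the checkpoint
# path is a placeholder; image is a normalized 1x3xHxW tensor):
#
#   model = CFANet(channel=64).cuda()
#   model.load_state_dict(torch.load('./Snapshot/CFANet/Cod_best.pth'))
#   model.eval()
#   with torch.no_grad():
#       _, _, _, res = model(image.cuda())
#       prob = torch.sigmoid(res)   # 1x1xHxW polyp probability map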
122 | if __name__ == "__main__":
123 |     parser = argparse.ArgumentParser()
124 |     parser.add_argument('--epoch', type=int, default=200, help='epoch number')
125 |     parser.add_argument('--lr', type=float, default=1e-4, help='init learning rate, try `lr=1e-4`')
126 |     parser.add_argument('--batchsize', type=int, default=10, help='training batch size (Note: ~500MB per img in GPU)')
127 |     parser.add_argument('--trainsize', type=int, default=352,
128 |                         help='the size of training image, try small resolutions for speed (like 256)')
129 |     parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
130 |     parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate per decay step')
131 |     parser.add_argument('--decay_epoch', type=int, default=30, help='every N epochs decay lr')
132 |     parser.add_argument('--gpu', type=int, default=0, help='choose which gpu you use')
133 |     parser.add_argument('--save_epoch', type=int, default=5, help='every N epochs save your trained snapshot')
134 |     parser.add_argument('--save_model', type=str, default='./Snapshot/CFANet/')
135 | 
136 |     parser.add_argument('--train_img_dir', type=str, default='./data/TrainDataset/images/')
137 |     parser.add_argument('--train_gt_dir', type=str, default='./data/TrainDataset/masks/')
138 |     parser.add_argument('--train_eg_dir', type=str, default='./data/TrainDataset/edges/')
139 | 
140 |     parser.add_argument('--test_img_dir', type=str, default='./data/TestDataset/CVC-300/images/')
141 |     parser.add_argument('--test_gt_dir', type=str, default='./data/TestDataset/CVC-300/masks/')
142 |     parser.add_argument('--test_eg_dir', type=str, default='./data/TestDataset/CVC-300/edges/')
143 | 
144 |     opt = parser.parse_args()
145 | 
146 |     torch.cuda.set_device(opt.gpu)
147 | 
148 |     save_path = opt.save_model
149 |     os.makedirs(save_path, exist_ok=True)
150 | 
151 |     logging.basicConfig(filename=opt.save_model + '/log.log',
152 |                         format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO,
153 |                         filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
154 |     logging.info("COD-Train")
155 |     logging.info("Config")
156 |     logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(
157 |         opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.save_model, opt.decay_epoch))
158 | 
159 |     # TIPS: you can also use a wider network for better performance, e.g. channel=64
160 |     model = CFANet(channel=64).cuda()
161 |     # print('-' * 30, model, '-' * 30)
162 | 
163 |     total = sum([param.nelement() for param in model.parameters()])
164 |     print('Number of parameter:%.2fM' % (total / 1e6))
165 | 
166 |     optimizer = torch.optim.Adam(model.parameters(), opt.lr)
167 |     LogitsBCE = torch.nn.BCEWithLogitsLoss()
168 | 
169 |     # net, optimizer = amp.initialize(model_SINet, optimizer, opt_level='O1')  # NOTES: Ox not 0x
170 | 
171 |     train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, opt.train_eg_dir, batchsize=opt.batchsize,
172 |                               trainsize=opt.trainsize, num_workers=12)
173 |     test_loader = test_dataset(opt.test_img_dir, opt.test_gt_dir, testsize=opt.trainsize)
174 | 
175 |     total_step = len(train_loader)
176 | 
177 |     print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n"
178 |                     "Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr,
179 |                                                                 opt.batchsize, opt.save_model, total_step), '-' * 30)
180 | 
181 |     for epoch_iter in range(1, opt.epoch + 1):
182 | 
183 |         adjust_lr(optimizer, epoch_iter, opt.decay_rate, opt.decay_epoch)
184 | 
185 |         train(train_loader, model, optimizer, epoch_iter, opt, LogitsBCE, total_step)
186 |         # test(test_loader, model, epoch_iter, opt.save_model)
187 | 
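# With the defaults above (lr=1e-4, decay_rate=0.1, decay_epoch=30) and the
# step decay in utils/trainer.py, the learning rate is 1e-4 for epochs 1-29,
# 1e-5 for epochs 30-59, 1e-6 for epochs 60-89, and so on.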
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 | 
9 | 
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label, edge):
12 |     # left right flip
13 |     flip_flag = random.randint(0, 1)
14 |     if flip_flag == 1:
15 |         img = img.transpose(Image.FLIP_LEFT_RIGHT)
16 |         label = label.transpose(Image.FLIP_LEFT_RIGHT)
17 |         edge = 
edge.transpose(Image.FLIP_LEFT_RIGHT) 18 | return img, label,edge 19 | 20 | 21 | def randomCrop(image, label,edge): 22 | border = 30 23 | image_width = image.size[0] 24 | image_height = image.size[1] 25 | crop_win_width = np.random.randint(image_width - border, image_width) 26 | crop_win_height = np.random.randint(image_height - border, image_height) 27 | random_region = ( 28 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 29 | (image_height + crop_win_height) >> 1) 30 | return image.crop(random_region), label.crop(random_region), edge.crop(random_region) 31 | 32 | 33 | def randomRotation(image, label, edge): 34 | mode = Image.BICUBIC 35 | if random.random() > 0.8: 36 | random_angle = np.random.randint(-15, 15) 37 | image = image.rotate(random_angle, mode) 38 | label = label.rotate(random_angle, mode) 39 | edge = edge.rotate(random_angle, mode) 40 | return image, label, edge 41 | 42 | 43 | def colorEnhance(image): 44 | bright_intensity = random.randint(5, 15) / 10.0 45 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 46 | contrast_intensity = random.randint(5, 15) / 10.0 47 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 48 | color_intensity = random.randint(0, 20) / 10.0 49 | image = ImageEnhance.Color(image).enhance(color_intensity) 50 | sharp_intensity = random.randint(0, 30) / 10.0 51 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 52 | return image 53 | 54 | 55 | def randomGaussian(image, mean=0.1, sigma=0.35): 56 | def gaussianNoisy(im, mean=mean, sigma=sigma): 57 | for _i in range(len(im)): 58 | im[_i] += random.gauss(mean, sigma) 59 | return im 60 | 61 | img = np.asarray(image) 62 | width, height = img.shape 63 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 64 | img = img.reshape([width, height]) 65 | return Image.fromarray(np.uint8(img)) 66 | 67 | 68 | def randomPeper(img): 69 | img = np.array(img) 70 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 71 | for i in range(noiseNum): 72 | 73 | randX = random.randint(0, img.shape[0] - 1) 74 | 75 | randY = random.randint(0, img.shape[1] - 1) 76 | 77 | if random.randint(0, 1) == 0: 78 | 79 | img[randX, randY] = 0 80 | 81 | else: 82 | 83 | img[randX, randY] = 255 84 | return Image.fromarray(img) 85 | 86 | 87 | def randomPeper_eg(img, edge): 88 | 89 | img = np.array(img) 90 | edge = np.array(edge) 91 | 92 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 93 | for i in range(noiseNum): 94 | 95 | randX = random.randint(0, img.shape[0] - 1) 96 | 97 | randY = random.randint(0, img.shape[1] - 1) 98 | 99 | if random.randint(0, 1) == 0: 100 | 101 | img[randX, randY] = 0 102 | edge[randX, randY] = 0 103 | 104 | else: 105 | 106 | img[randX, randY] = 255 107 | edge[randX, randY] = 255 108 | 109 | return Image.fromarray(img), Image.fromarray(edge) 110 | 111 | 112 | 113 | # dataset for training 114 | class PolypObjDataset(data.Dataset): 115 | def __init__(self, image_root, gt_root, edge_root, trainsize): 116 | self.trainsize = trainsize 117 | # get filenames 118 | 119 | 120 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 121 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')] 122 | self.egs = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')] 123 | 124 | 125 | 126 | # self.grads = [grad_root + f for f in os.listdir(grad_root) if f.endswith('.jpg') 127 | # or 
f.endswith('.png')]
128 |         # self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
129 |         #                or f.endswith('.png')]
130 |         # sorted files
131 |         self.images = sorted(self.images)
132 |         self.gts = sorted(self.gts)
133 |         self.egs = sorted(self.egs)
134 | 
135 | 
136 |         # self.grads = sorted(self.grads)
137 |         # self.depths = sorted(self.depths)
138 |         # filter out image/gt pairs whose sizes do not match
139 |         self.filter_files()
140 |         # transforms
141 |         self.img_transform = transforms.Compose([
142 |             transforms.Resize((self.trainsize, self.trainsize)),
143 |             transforms.ToTensor(),
144 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
145 |         self.gt_transform = transforms.Compose([
146 |             transforms.Resize((self.trainsize, self.trainsize)),
147 |             transforms.ToTensor()])
148 | 
149 |         self.eg_transform = transforms.Compose([
150 |             transforms.Resize((self.trainsize, self.trainsize)),
151 |             transforms.ToTensor()])
152 | 
153 | 
154 |         # get size of dataset
155 |         self.size = len(self.images)
156 | 
157 |     def __getitem__(self, index):
158 |         # read imgs/gts/egs
159 |         image = self.rgb_loader(self.images[index])
160 |         gt = self.binary_loader(self.gts[index])
161 |         eg = self.binary_loader(self.egs[index])
162 | 
163 |         # data augmentation
164 |         image, gt, eg = cv_random_flip(image, gt, eg)
165 |         image, gt, eg = randomCrop(image, gt, eg)
166 |         image, gt, eg = randomRotation(image, gt, eg)
167 | 
168 |         image = colorEnhance(image)
169 |         gt, eg = randomPeper_eg(gt, eg)
170 | 
171 |         image = self.img_transform(image)
172 |         gt = self.gt_transform(gt)
173 |         eg = self.eg_transform(eg)
174 | 
175 |         return image, gt, eg
176 | 
177 |     def filter_files(self):
178 |         assert len(self.images) == len(self.gts) == len(self.egs)
179 |         images = []
180 |         gts = []
181 |         for img_path, gt_path in zip(self.images, self.gts):
182 |             img = Image.open(img_path)
183 |             gt = Image.open(gt_path)
184 |             if img.size == gt.size:
185 |                 images.append(img_path)
186 |                 gts.append(gt_path)
187 |         self.images = images
188 |         self.gts = gts
189 | 
190 |     def rgb_loader(self, path):
191 |         with open(path, 'rb') as f:
192 |             img = Image.open(f)
193 |             return img.convert('RGB')
194 | 
195 |     def binary_loader(self, path):
196 |         with open(path, 'rb') as f:
197 |             img = Image.open(f)
198 |             return img.convert('L')
199 | 
200 |     def __len__(self):
201 |         return self.size
202 | 
203 | 
204 | # dataloader for training
205 | def get_loader(image_root, gt_root, eg_root, batchsize, trainsize,
206 |                shuffle=True, num_workers=12, pin_memory=True):
207 |     dataset = PolypObjDataset(image_root, gt_root, eg_root, trainsize)
208 |     data_loader = data.DataLoader(dataset=dataset,
209 |                                   batch_size=batchsize,
210 |                                   shuffle=shuffle,
211 |                                   num_workers=num_workers,
212 |                                   pin_memory=pin_memory)
213 |     return data_loader
214 | 
215 | 
216 | # test dataset and loader
217 | class test_dataset_ori:
218 |     def __init__(self, image_root, gt_root, testsize):
219 |         self.testsize = testsize
220 |         self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
221 |         self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
222 |         self.images = sorted(self.images)
223 |         self.gts = sorted(self.gts)
224 |         self.transform = transforms.Compose([
225 |             transforms.Resize((self.testsize, self.testsize)),
226 |             transforms.ToTensor(),
227 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
228 |         self.gt_transform = transforms.ToTensor()
229 |         self.size = len(self.images)
230 | 
self.index = 0 231 | 232 | def load_data(self): 233 | image = self.rgb_loader(self.images[self.index]) 234 | image = self.transform(image).unsqueeze(0) 235 | 236 | gt = self.binary_loader(self.gts[self.index]) 237 | 238 | name = self.images[self.index].split('/')[-1] 239 | 240 | image_for_post = self.rgb_loader(self.images[self.index]) 241 | image_for_post = image_for_post.resize(gt.size) 242 | 243 | if name.endswith('.jpg'): 244 | name = name.split('.jpg')[0] + '.png' 245 | 246 | self.index += 1 247 | self.index = self.index % self.size 248 | 249 | return image, gt, name, np.array(image_for_post) 250 | 251 | def rgb_loader(self, path): 252 | with open(path, 'rb') as f: 253 | img = Image.open(f) 254 | return img.convert('RGB') 255 | 256 | def binary_loader(self, path): 257 | with open(path, 'rb') as f: 258 | img = Image.open(f) 259 | return img.convert('L') 260 | 261 | def __len__(self): 262 | return self.size 263 | 264 | 265 | class test_dataset: 266 | """load test dataset (batchsize=1)""" 267 | def __init__(self, image_root, gt_root, testsize): 268 | self.testsize = testsize 269 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 270 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')] 271 | self.images = sorted(self.images) 272 | self.gts = sorted(self.gts) 273 | self.transform = transforms.Compose([ 274 | transforms.Resize((self.testsize, self.testsize)), 275 | transforms.ToTensor(), 276 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 277 | self.gt_transform = transforms.ToTensor() 278 | self.size = len(self.images) 279 | self.index = 0 280 | 281 | def load_data(self): 282 | image = self.rgb_loader(self.images[self.index]) 283 | image = self.transform(image).unsqueeze(0) 284 | gt = self.binary_loader(self.gts[self.index]) 285 | name = self.images[self.index].split('/')[-1] 286 | if name.endswith('.jpg'): 287 | name = name.split('.jpg')[0] + '.png' 288 | self.index += 1 289 | return image, gt, name 290 | 291 | def rgb_loader(self, path): 292 | with open(path, 'rb') as f: 293 | img = Image.open(f) 294 | return img.convert('RGB') 295 | 296 | def binary_loader(self, path): 297 | with open(path, 'rb') as f: 298 | img = Image.open(f) 299 | return img.convert('L') 300 | 301 | 302 | -------------------------------------------------------------------------------- /utils/eva_funcs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Sep 29 17:21:18 2020 5 | 6 | @author: taozhou 7 | """ 8 | 9 | import os 10 | import time 11 | 12 | import numpy as np 13 | import torch 14 | from torchvision import transforms 15 | 16 | 17 | ############################################################################### 18 | ## basic funcs 19 | ############################################################################### 20 | 21 | def numpy2tensor(numpy): 22 | """ 23 | convert numpy_array in cpu to tensor in gpu 24 | :param numpy: 25 | :return: torch.from_numpy(numpy).cuda() 26 | """ 27 | return torch.from_numpy(numpy).cuda() 28 | 29 | def fun_eval_e(y_pred, y, num, cuda=True): 30 | 31 | if cuda: 32 | score = torch.zeros(num).cuda() 33 | else: 34 | score = torch.zeros(num) 35 | 36 | for i in range(num): 37 | 38 | fm = y_pred - y_pred.mean() 39 | gt = y - y.mean() 40 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20) 41 | enhanced = ((align_matrix + 1) * (align_matrix + 
1)) / 4 42 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20) 43 | return score.max() 44 | 45 | 46 | def fun_eval_pr(y_pred, y, num, cuda=True): 47 | 48 | if cuda: 49 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 50 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 51 | else: 52 | prec, recall = torch.zeros(num), torch.zeros(num) 53 | thlist = torch.linspace(0, 1 - 1e-10, num) 54 | 55 | for i in range(num): 56 | y_temp = (y_pred >= thlist[i]).float() 57 | tp = (y_temp * y).sum() 58 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 59 | return prec, recall 60 | 61 | 62 | def fun_S_object(pred, gt): 63 | 64 | fg = torch.where(gt==0, torch.zeros_like(pred), pred) 65 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred) 66 | o_fg = fun_object(fg, gt) 67 | o_bg = fun_object(bg, 1-gt) 68 | u = gt.mean() 69 | Q = u * o_fg + (1-u) * o_bg 70 | return Q 71 | 72 | 73 | def fun_object(pred, gt): 74 | 75 | temp = pred[gt == 1] 76 | x = temp.mean() 77 | sigma_x = temp.std() 78 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20) 79 | 80 | return score 81 | 82 | 83 | def fun_S_region(pred, gt): 84 | 85 | X, Y = fun_centroid(gt) 86 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = fun_divideGT(gt, X, Y) 87 | p1, p2, p3, p4 = fun_dividePrediction(pred, X, Y) 88 | Q1 = fun_ssim(p1, gt1) 89 | Q2 = fun_ssim(p2, gt2) 90 | Q3 = fun_ssim(p3, gt3) 91 | Q4 = fun_ssim(p4, gt4) 92 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4 93 | 94 | return Q 95 | 96 | def fun_centroid(gt, cuda=True): 97 | 98 | rows, cols = gt.size()[-2:] 99 | gt = gt.view(rows, cols) 100 | 101 | if gt.sum() == 0: 102 | 103 | if cuda: 104 | X = torch.eye(1).cuda() * round(cols / 2) 105 | Y = torch.eye(1).cuda() * round(rows / 2) 106 | else: 107 | X = torch.eye(1) * round(cols / 2) 108 | Y = torch.eye(1) * round(rows / 2) 109 | 110 | else: 111 | total = gt.sum() 112 | 113 | if cuda: 114 | i = torch.from_numpy(np.arange(0,cols)).cuda().float() 115 | j = torch.from_numpy(np.arange(0,rows)).cuda().float() 116 | else: 117 | i = torch.from_numpy(np.arange(0,cols)).float() 118 | j = torch.from_numpy(np.arange(0,rows)).float() 119 | 120 | X = torch.round((gt.sum(dim=0)*i).sum() / total) 121 | Y = torch.round((gt.sum(dim=1)*j).sum() / total) 122 | 123 | return X.long(), Y.long() 124 | 125 | 126 | def fun_divideGT(gt, X, Y): 127 | 128 | h, w = gt.size()[-2:] 129 | area = h*w 130 | gt = gt.view(h, w) 131 | LT = gt[:Y, :X] 132 | RT = gt[:Y, X:w] 133 | LB = gt[Y:h, :X] 134 | RB = gt[Y:h, X:w] 135 | X = X.float() 136 | Y = Y.float() 137 | w1 = X * Y / area 138 | w2 = (w - X) * Y / area 139 | w3 = X * (h - Y) / area 140 | w4 = 1 - w1 - w2 - w3 141 | 142 | return LT, RT, LB, RB, w1, w2, w3, w4 143 | 144 | def fun_dividePrediction(pred, X, Y): 145 | 146 | h, w = pred.size()[-2:] 147 | pred = pred.view(h, w) 148 | LT = pred[:Y, :X] 149 | RT = pred[:Y, X:w] 150 | LB = pred[Y:h, :X] 151 | RB = pred[Y:h, X:w] 152 | 153 | return LT, RT, LB, RB 154 | 155 | 156 | def fun_ssim(pred, gt): 157 | 158 | gt = gt.float() 159 | h, w = pred.size()[-2:] 160 | N = h*w 161 | x = pred.mean() 162 | y = gt.mean() 163 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20) 164 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20) 165 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20) 166 | 167 | aplha = 4 * x * y *sigma_xy 168 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2) 169 | 170 | if aplha != 0: 171 | Q = aplha / (beta + 1e-20) 172 | elif aplha == 0 and beta == 0: 173 | Q = 1.0 174 | else: 175 | Q = 0 176 | 177 | return Q 178 | 179 
| ###############################################################################
180 | ## metric funcs
181 | ###############################################################################
182 | def eval_mae(pred, gt, cuda=True):
183 | 
184 |     with torch.no_grad():
185 | 
186 |         trans = transforms.Compose([transforms.ToTensor()])
187 | 
188 |         if cuda:
189 |             pred = pred.cuda()
190 |             gt = gt.cuda()
191 |         # else:
192 |         #     pred = trans(pred)
193 |         #     gt = trans(gt)
194 | 
195 |         mae = torch.abs(pred - gt).mean()
196 | 
197 |         return mae.cpu().detach().numpy()
198 | 
199 | 
200 | def eval_Smeasure(pred, gt, cuda=True):
201 | 
202 |     alpha, avg_q, img_num = 0.5, 0.0, 0.0
203 | 
204 |     with torch.no_grad():
205 | 
206 |         trans = transforms.Compose([transforms.ToTensor()])
207 | 
208 |         y = gt.mean()
209 | 
210 |         ##
211 |         if y == 0:
212 |             x = pred.mean()
213 |             Q = 1.0 - x
214 |         elif y == 1:
215 |             x = pred.mean()
216 |             Q = x
217 |         else:
218 |             Q = alpha * fun_S_object(pred, gt) + (1-alpha) * fun_S_region(pred, gt)
219 |             if Q.item() < 0:
220 |                 Q = torch.FloatTensor([0.0])
221 | 
222 |         return Q.item()
223 | 
224 | 
225 | def eval_fmeasure(pred, gt, cuda=True):
226 |     print('eval[FMeasure]')
227 | 
228 |     beta2 = 0.3
229 |     avg_p, avg_r, img_num = 0.0, 0.0, 0.0
230 | 
231 |     ##
232 |     with torch.no_grad():
233 |         trans = transforms.Compose([transforms.ToTensor()])
234 |         if cuda:
235 |             pred = trans(pred).cuda()
236 |             gt = trans(gt).cuda()
237 |         else:
238 |             pred = trans(pred)
239 |             gt = trans(gt)
240 | 
241 |         prec, recall = fun_eval_pr(pred, gt, 255)
242 | 
243 |         return prec, recall
244 | 
245 | 
246 | 
247 | 
248 | 
249 | 
250 | 
251 | 
252 | 
253 | 
254 | 
255 | class Eval_thread():
256 |     def __init__(self, loader, method, dataset, output_dir, cuda):
257 |         self.loader = loader
258 |         self.method = method
259 |         self.dataset = dataset
260 |         self.cuda = cuda
261 |         self.logfile = os.path.join(output_dir, 'result.txt')
262 |     def run(self):
263 |         start_time = time.time()
264 |         mae = self.Eval_mae()
265 |         s = self.Eval_Smeasure()
266 | 
267 |         return mae, s
268 | 
269 |         #max_f = self.Eval_fmeasure()
270 |         #max_e = self.Eval_Emeasure()
271 | 
272 |         #self.LOG('{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..\n'.format(self.dataset, self.method, mae, max_f, max_e, s))
273 |         #return '[cost:{:.4f}s]{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..'.format(time.time()-start_time, self.dataset, self.method, mae, max_f, max_e, s)
274 | 
275 |     def Eval_mae(self):
276 |         avg_mae, img_num = 0.0, 0.0
277 |         with torch.no_grad():
278 |             trans = transforms.Compose([transforms.ToTensor()])
279 |             for pred, gt in self.loader:
280 |                 if self.cuda:
281 | 
282 |                     pred = trans(pred).cuda()
283 |                     gt = trans(gt).cuda()
284 |                 else:
285 |                     pred = trans(pred)
286 |                     gt = trans(gt)
287 |                 mea = torch.abs(pred - gt).mean()
288 |                 if mea == mea:  # for Nan
289 |                     avg_mae += mea
290 |                     img_num += 1.0
291 |             avg_mae /= img_num
292 | 
293 |             return avg_mae.item()
294 | 
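    # Usage sketch for the module-level metrics above (pred/gt are HxW float
    # tensors in [0, 1]; assumes a CUDA device, since the helpers call .cuda()
    # internally):
    #
    #   pred = torch.rand(352, 352).cuda()
    #   gt = (torch.rand(352, 352) > 0.5).float().cuda()
    #   print(eval_mae(pred, gt))       # mean absolute error
    #   print(eval_Smeasure(pred, gt))  # structure measure (object/region mix)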
295 |     def Eval_fmeasure(self):
296 |         print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
297 |         beta2 = 0.3
298 |         avg_p, avg_r, img_num = 0.0, 0.0, 0.0
299 |         with torch.no_grad():
300 |             trans = transforms.Compose([transforms.ToTensor()])
301 |             for pred, gt in self.loader:
302 |                 if self.cuda:
303 |                     pred = trans(pred).cuda()
304 |                     gt = trans(gt).cuda()
305 |                 else:
306 |                     pred = trans(pred)
307 |                     gt = trans(gt)
308 |                 prec, recall = self._eval_pr(pred, gt, 255)
309 |                 avg_p += prec
310 |                 avg_r += recall
311 |                 img_num += 1.0
312 |             avg_p /= img_num
313 |             avg_r /= img_num
314 |             score = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r)
315 |             score[score != score] = 0  # for Nan
316 | 
317 |             return score.max().item()
318 |     def Eval_Emeasure(self):
319 |         print('eval[EMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
320 |         avg_e, img_num = 0.0, 0.0
321 |         with torch.no_grad():
322 |             trans = transforms.Compose([transforms.ToTensor()])
323 |             for pred, gt in self.loader:
324 |                 if self.cuda:
325 |                     pred = trans(pred).cuda()
326 |                     gt = trans(gt).cuda()
327 |                 else:
328 |                     pred = trans(pred)
329 |                     gt = trans(gt)
330 |                 max_e = self._eval_e(pred, gt, 255)
331 |                 if max_e == max_e:
332 |                     avg_e += max_e
333 |                     img_num += 1.0
334 | 
335 |             avg_e /= img_num
336 |             return avg_e
337 |     def Eval_Smeasure(self):
338 |         #print('eval[SMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
339 |         alpha, avg_q, img_num = 0.5, 0.0, 0.0
340 |         with torch.no_grad():
341 |             trans = transforms.Compose([transforms.ToTensor()])
342 |             for pred, gt in self.loader:
343 |                 if self.cuda:
344 |                     pred = trans(pred).cuda()
345 |                     gt = trans(gt).cuda()
346 |                 else:
347 |                     pred = trans(pred)
348 |                     gt = trans(gt)
349 |                 y = gt.mean()
350 |                 if y == 0:
351 |                     x = pred.mean()
352 |                     Q = 1.0 - x
353 |                 elif y == 1:
354 |                     x = pred.mean()
355 |                     Q = x
356 |                 else:
357 |                     Q = alpha * self._S_object(pred, gt) + (1-alpha) * self._S_region(pred, gt)
358 |                     if Q.item() < 0:
359 |                         Q = torch.FloatTensor([0.0])
360 |                 img_num += 1.0
361 |                 avg_q += Q.item()
362 |             avg_q /= img_num
363 | 
364 |             return avg_q
365 |     def LOG(self, output):
366 |         with open(self.logfile, 'a') as f:
367 |             f.write(output)
368 | 
369 |     def _eval_e(self, y_pred, y, num):
370 |         if self.cuda:
371 |             score = torch.zeros(num).cuda()
372 |         else:
373 |             score = torch.zeros(num)
374 |         for i in range(num):
375 |             fm = y_pred - y_pred.mean()
376 |             gt = y - y.mean()
377 |             align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
378 |             enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
379 |             score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
380 |         return score.max()
381 | 
382 |     def _eval_pr(self, y_pred, y, num):
383 |         if self.cuda:
384 |             prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
385 |             thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
386 |         else:
387 |             prec, recall = torch.zeros(num), torch.zeros(num)
388 |             thlist = torch.linspace(0, 1 - 1e-10, num)
389 |         for i in range(num):
390 |             y_temp = (y_pred >= thlist[i]).float()
391 |             tp = (y_temp * y).sum()
392 |             prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
393 |         return prec, recall
394 | 
395 |     def _S_object(self, pred, gt):
396 |         fg = torch.where(gt==0, torch.zeros_like(pred), pred)
397 |         bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred)
398 |         o_fg = self._object(fg, gt)
399 |         o_bg = self._object(bg, 1-gt)
400 |         u = gt.mean()
401 |         Q = u * o_fg + (1-u) * o_bg
402 |         return Q
403 | 
404 |     def _object(self, pred, gt):
405 |         temp = pred[gt == 1]
406 |         x = temp.mean()
407 |         sigma_x = temp.std()
408 |         score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
409 | 
410 |         return score
411 | 
412 |     def _S_region(self, pred, gt):
413 |         X, Y = self._centroid(gt)
414 |         gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
415 |         p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
416 |         Q1 = self._ssim(p1, gt1)
417 |         Q2 = self._ssim(p2, gt2)
418 |         Q3 = self._ssim(p3, gt3)
419 |         Q4 = self._ssim(p4, gt4)
420 |         Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4
421 |         # print(Q)
422 | 
return Q 423 | 424 | def _centroid(self, gt): 425 | rows, cols = gt.size()[-2:] 426 | gt = gt.view(rows, cols) 427 | if gt.sum() == 0: 428 | if self.cuda: 429 | X = torch.eye(1).cuda() * round(cols / 2) 430 | Y = torch.eye(1).cuda() * round(rows / 2) 431 | else: 432 | X = torch.eye(1) * round(cols / 2) 433 | Y = torch.eye(1) * round(rows / 2) 434 | else: 435 | total = gt.sum() 436 | if self.cuda: 437 | i = torch.from_numpy(np.arange(0,cols)).cuda().float() 438 | j = torch.from_numpy(np.arange(0,rows)).cuda().float() 439 | else: 440 | i = torch.from_numpy(np.arange(0,cols)).float() 441 | j = torch.from_numpy(np.arange(0,rows)).float() 442 | X = torch.round((gt.sum(dim=0)*i).sum() / total) 443 | Y = torch.round((gt.sum(dim=1)*j).sum() / total) 444 | return X.long(), Y.long() 445 | 446 | def _divideGT(self, gt, X, Y): 447 | h, w = gt.size()[-2:] 448 | area = h*w 449 | gt = gt.view(h, w) 450 | LT = gt[:Y, :X] 451 | RT = gt[:Y, X:w] 452 | LB = gt[Y:h, :X] 453 | RB = gt[Y:h, X:w] 454 | X = X.float() 455 | Y = Y.float() 456 | w1 = X * Y / area 457 | w2 = (w - X) * Y / area 458 | w3 = X * (h - Y) / area 459 | w4 = 1 - w1 - w2 - w3 460 | return LT, RT, LB, RB, w1, w2, w3, w4 461 | 462 | def _dividePrediction(self, pred, X, Y): 463 | h, w = pred.size()[-2:] 464 | pred = pred.view(h, w) 465 | LT = pred[:Y, :X] 466 | RT = pred[:Y, X:w] 467 | LB = pred[Y:h, :X] 468 | RB = pred[Y:h, X:w] 469 | return LT, RT, LB, RB 470 | 471 | def _ssim(self, pred, gt): 472 | gt = gt.float() 473 | h, w = pred.size()[-2:] 474 | N = h*w 475 | x = pred.mean() 476 | y = gt.mean() 477 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20) 478 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20) 479 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20) 480 | 481 | aplha = 4 * x * y *sigma_xy 482 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2) 483 | 484 | if aplha != 0: 485 | Q = aplha / (beta + 1e-20) 486 | elif aplha == 0 and beta == 0: 487 | Q = 1.0 488 | else: 489 | Q = 0 490 | return Q 491 | -------------------------------------------------------------------------------- /utils/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | from torchvision import transforms 7 | 8 | 9 | class Eval_thread(): 10 | def __init__(self, loader, method, dataset, output_dir, cuda): 11 | self.loader = loader 12 | self.method = method 13 | self.dataset = dataset 14 | self.cuda = cuda 15 | self.logfile = os.path.join(output_dir, 'result.txt') 16 | def run(self): 17 | start_time = time.time() 18 | mae = self.Eval_mae() 19 | s = self.Eval_Smeasure() 20 | 21 | 22 | 23 | max_f = self.Eval_fmeasure() 24 | max_e = self.Eval_Emeasure() 25 | 26 | return mae,s,max_f,max_e 27 | 28 | 29 | #self.LOG('{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..\n'.format(self.dataset, self.method, mae, max_f, max_e, s)) 30 | #return '[cost:{:.4f}s]{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..'.format(time.time()-start_time, self.dataset, self.method, mae, max_f, max_e, s) 31 | 32 | def Eval_mae(self): 33 | #print('eval[MAE]:{} dataset with {} method.'.format(self.dataset, self.method)) 34 | avg_mae, img_num = 0.0, 0.0 35 | with torch.no_grad(): 36 | trans = transforms.Compose([transforms.ToTensor()]) 37 | for pred, gt in self.loader: 38 | if self.cuda: 39 | pred = trans(pred).cuda() 40 | gt = trans(gt).cuda() 41 | else: 42 | pred 
= trans(pred)
43 |                     gt = trans(gt)
44 |                 mea = torch.abs(pred - gt).mean()
45 |                 if mea == mea:  # for Nan
46 |                     avg_mae += mea
47 |                     img_num += 1.0
48 |             avg_mae /= img_num
49 | 
50 |         return avg_mae.item()
51 | 
52 |     def Eval_fmeasure(self):
53 |         print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
54 |         beta2 = 0.3
55 |         avg_p, avg_r, img_num = 0.0, 0.0, 0.0
56 |         with torch.no_grad():
57 |             trans = transforms.Compose([transforms.ToTensor()])
58 |             for pred, gt in self.loader:
59 |                 if self.cuda:
60 |                     pred = trans(pred).cuda()
61 |                     gt = trans(gt).cuda()
62 |                 else:
63 |                     pred = trans(pred)
64 |                     gt = trans(gt)
65 |                 prec, recall = self._eval_pr(pred, gt, 255)
66 |                 avg_p += prec
67 |                 avg_r += recall
68 |                 img_num += 1.0
69 |             avg_p /= img_num
70 |             avg_r /= img_num
71 |             score = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r)
72 |             score[score != score] = 0  # for Nan
73 | 
74 |             return score.max().item()
75 |     def Eval_Emeasure(self):
76 |         print('eval[EMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
77 |         avg_e, img_num = 0.0, 0.0
78 |         with torch.no_grad():
79 |             trans = transforms.Compose([transforms.ToTensor()])
80 |             for pred, gt in self.loader:
81 |                 if self.cuda:
82 |                     pred = trans(pred).cuda()
83 |                     gt = trans(gt).cuda()
84 |                 else:
85 |                     pred = trans(pred)
86 |                     gt = trans(gt)
87 |                 max_e = self._eval_e(pred, gt, 255)
88 |                 if max_e == max_e:
89 |                     avg_e += max_e
90 |                     img_num += 1.0
91 | 
92 |             avg_e /= img_num
93 |             return avg_e.item()
94 |     def Eval_Smeasure(self):
95 |         #print('eval[SMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
96 |         alpha, avg_q, img_num = 0.5, 0.0, 0.0
97 |         with torch.no_grad():
98 |             trans = transforms.Compose([transforms.ToTensor()])
99 |             for pred, gt in self.loader:
100 |                 if self.cuda:
101 |                     pred = trans(pred).cuda()
102 |                     gt = trans(gt).cuda()
103 |                 else:
104 |                     pred = trans(pred)
105 |                     gt = trans(gt)
106 |                 y = gt.mean()
107 |                 if y == 0:
108 |                     x = pred.mean()
109 |                     Q = 1.0 - x
110 |                 elif y == 1:
111 |                     x = pred.mean()
112 |                     Q = x
113 |                 else:
114 |                     Q = alpha * self._S_object(pred, gt) + (1-alpha) * self._S_region(pred, gt)
115 |                     if Q.item() < 0:
116 |                         Q = torch.FloatTensor([0.0])
117 |                 img_num += 1.0
118 |                 avg_q += Q.item()
119 |             avg_q /= img_num
120 | 
121 |             return avg_q
122 |     def LOG(self, output):
123 |         with open(self.logfile, 'a') as f:
124 |             f.write(output)
125 | 
126 |     def _eval_e(self, y_pred, y, num):
127 |         if self.cuda:
128 |             score = torch.zeros(num).cuda()
129 |         else:
130 |             score = torch.zeros(num)
131 |         for i in range(num):
132 |             fm = y_pred - y_pred.mean()
133 |             gt = y - y.mean()
134 |             align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
135 |             enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
136 |             score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
137 |         return score.max()
138 | 
139 |     def _eval_pr(self, y_pred, y, num):
140 |         if self.cuda:
141 |             prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
142 |             thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
143 |         else:
144 |             prec, recall = torch.zeros(num), torch.zeros(num)
145 |             thlist = torch.linspace(0, 1 - 1e-10, num)
146 |         for i in range(num):
147 |             y_temp = (y_pred >= thlist[i]).float()
148 |             tp = (y_temp * y).sum()
149 |             prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
150 |         return prec, recall
151 | 
152 |     def _S_object(self, pred, gt):
153 |         fg = torch.where(gt==0, torch.zeros_like(pred), pred)
154 |         bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred)
155 |         o_fg = self._object(fg, gt)
156 |         o_bg = 
    def _S_object(self, pred, gt):
        fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
        bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
        o_fg = self._object(fg, gt)
        o_bg = self._object(bg, 1 - gt)
        u = gt.mean()
        Q = u * o_fg + (1 - u) * o_bg
        return Q

    def _object(self, pred, gt):
        temp = pred[gt == 1]
        x = temp.mean()
        sigma_x = temp.std()
        score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
        return score

    def _S_region(self, pred, gt):
        X, Y = self._centroid(gt)
        gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
        p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
        Q1 = self._ssim(p1, gt1)
        Q2 = self._ssim(p2, gt2)
        Q3 = self._ssim(p3, gt3)
        Q4 = self._ssim(p4, gt4)
        Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
        return Q

    def _centroid(self, gt):
        rows, cols = gt.size()[-2:]
        gt = gt.view(rows, cols)
        if gt.sum() == 0:
            # Empty ground truth: fall back to the image centre.
            if self.cuda:
                X = torch.eye(1).cuda() * round(cols / 2)
                Y = torch.eye(1).cuda() * round(rows / 2)
            else:
                X = torch.eye(1) * round(cols / 2)
                Y = torch.eye(1) * round(rows / 2)
        else:
            total = gt.sum()
            if self.cuda:
                i = torch.from_numpy(np.arange(0, cols)).cuda().float()
                j = torch.from_numpy(np.arange(0, rows)).cuda().float()
            else:
                i = torch.from_numpy(np.arange(0, cols)).float()
                j = torch.from_numpy(np.arange(0, rows)).float()
            X = torch.round((gt.sum(dim=0) * i).sum() / total)
            Y = torch.round((gt.sum(dim=1) * j).sum() / total)
        return X.long(), Y.long()

    def _divideGT(self, gt, X, Y):
        h, w = gt.size()[-2:]
        area = h * w
        gt = gt.view(h, w)
        LT = gt[:Y, :X]
        RT = gt[:Y, X:w]
        LB = gt[Y:h, :X]
        RB = gt[Y:h, X:w]
        X = X.float()
        Y = Y.float()
        w1 = X * Y / area
        w2 = (w - X) * Y / area
        w3 = X * (h - Y) / area
        w4 = 1 - w1 - w2 - w3
        return LT, RT, LB, RB, w1, w2, w3, w4

    def _dividePrediction(self, pred, X, Y):
        h, w = pred.size()[-2:]
        pred = pred.view(h, w)
        LT = pred[:Y, :X]
        RT = pred[:Y, X:w]
        LB = pred[Y:h, :X]
        RB = pred[Y:h, X:w]
        return LT, RT, LB, RB

    def _ssim(self, pred, gt):
        gt = gt.float()
        h, w = pred.size()[-2:]
        N = h * w
        x = pred.mean()
        y = gt.mean()
        sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
        sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
        sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)

        alpha = 4 * x * y * sigma_xy
        beta = (x * x + y * y) * (sigma_x2 + sigma_y2)

        if alpha != 0:
            Q = alpha / (beta + 1e-20)
        elif alpha == 0 and beta == 0:
            Q = 1.0
        else:
            Q = 0
        return Q
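
# Usage sketch (illustrative only; `EvalDataset` and the paths below are
# hypothetical placeholders, not a verified API of this repository). The
# loader is expected to yield (pred, gt) pairs that transforms.ToTensor()
# accepts, e.g. PIL images of a predicted map and its ground-truth mask:
#
#     loader = EvalDataset('./results/CFANet/CVC-300/', './data/TestDataset/CVC-300/masks/')
#     thread = Eval_thread(loader, method='CFANet', dataset='CVC-300',
#                          output_dir='./results/', cuda=torch.cuda.is_available())
#     mae, s, max_f, max_e = thread.run()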
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
import torch


def eval_mae(y_pred, y):
    """
    Evaluate MAE (for the test or validation phase).
    :param y_pred: predicted map
    :param y: ground-truth map
    :return: mean absolute error
    """
    return torch.abs(y_pred - y).mean()


def numpy2tensor(numpy):
    """
    Convert a numpy array on the CPU to a tensor on the GPU.
    :param numpy: input numpy array
    :return: torch.from_numpy(numpy).cuda()
    """
    return torch.from_numpy(numpy).cuda()


def clip_gradient(optimizer, grad_clip):
    """
    Clamp every parameter gradient element-wise to [-grad_clip, grad_clip]
    to stabilize training.
    :param optimizer: optimizer whose parameter gradients are clipped
    :param grad_clip: clipping threshold
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30):
    # Scale the learning rate by decay_rate ** (epoch // decay_epoch).
    # Note that this rescales the optimizer's current lr in place, so calling
    # it every epoch compounds the decay once epoch reaches decay_epoch.
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay
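
# Usage sketch (illustrative only; `model` and `train_loader` are hypothetical
# placeholders defined elsewhere, and the hyperparameters are arbitrary):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     for epoch in range(1, 101):
#         adjust_lr(optimizer, epoch)  # see the compounding note in adjust_lr
#         for images, gts in train_loader:
#             optimizer.zero_grad()
#             pred = model(images)
#             loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, gts)
#             loss.backward()
#             clip_gradient(optimizer, grad_clip=0.5)
#             optimizer.step()
--------------------------------------------------------------------------------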