├── README.md
├── imgs
│   ├── framework.png
│   └── qualitative_results.png
├── lib
│   ├── model.py
│   └── res2net_v1b_base.py
├── mindspore
│   ├── lib
│   │   ├── model.py
│   │   └── res2net_v1b_base.py
│   ├── test.py
│   ├── train.py
│   └── utils
│       └── dataloader.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── CODNet.cpython-36.pyc
    │   ├── CODNet.cpython-38.pyc
    │   ├── SINet.cpython-36.pyc
    │   ├── SearchAttention.cpython-36.pyc
    │   ├── SearchAttention.cpython-38.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-38.pyc
    │   └── res2net_v1b_base.cpython-38.pyc
    ├── dataloader.py
    ├── eva_funcs.py
    ├── evaluator.py
    └── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-level Feature Aggregation Network for Polyp Segmentation
2 |
3 | > **Authors:**
4 | > [Tao Zhou](https://taozh2017.github.io/),
5 | > [Yi Zhou](https://cse.seu.edu.cn/2021/0303/c23024a362239/page.htm),
6 | > [Kelei He](https://scholar.google.com/citations?user=0Do_BMIAAAAJ&hl=en),
7 | > [Chen Gong](https://gcatnjust.github.io/ChenGong/index.html),
8 | > [Jian Yang](https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en),
9 | > [Huazhu Fu](http://hzfu.github.io/), and
10 | > [Dinggang Shen](https://scholar.google.com/citations?user=v6VYQC8AAAAJ&hl=en).
11 |
12 |
13 | ## 1. Preface
14 |
15 | - This repository provides code for "_**Cross-level Feature Aggregation Network for Polyp Segmentation (CFANet)**_".
16 | ([paper](https://www.sciencedirect.com/science/article/pii/S0031320323002558))
17 |
18 | - If you have any questions about our paper, feel free to contact us. If you use CFANet in your research, please cite this paper ([BibTeX](#5-citation)).
19 |
20 |
21 | ### 1.1. :fire: NEWS :fire:
22 |
23 | - [2023/05/20] Release training/testing code.
24 |
25 | - [2023/05/10] Create repository.
26 |
27 |
28 | ### 1.2. Table of Contents
29 |
30 | - [Cross-level Feature Aggregation Network for Polyp Segmentation](#cross-level-feature-aggregation-network-for-polyp-segmentation)
31 | - [2. Overview](#2-overview)
32 | - [2.1. Introduction](#21-introduction)
33 | - [2.2. Framework Overview](#22-framework-overview)
34 | - [2.3. Qualitative Results](#23-qualitative-results)
35 | - [3. Proposed Baseline](#3-proposed-baseline)
36 | - [3.1. Training/Testing](#31-trainingtesting)
37 | - [3.2 Evaluating your trained model:](#32-evaluating-your-trained-model)
38 | - [3.3 Pre-computed maps:](#33-pre-computed-maps)
39 | - [4. MindSpore](#4-mindspore)
40 | - [5. Citation](#5-citation)
41 | - [6. License](#6-license)
42 |
43 | Table of contents generated with markdown-toc
44 |
45 | ## 2. Overview
46 |
47 | ### 2.1. Introduction
48 |
49 | Accurate segmentation of polyps from colonoscopy images plays a critical role in the diagnosis and treatment of colorectal cancer. Despite the progress made in polyp segmentation, several challenges remain: polyps vary widely in size and shape, and there is often no sharp boundary between a polyp and its surroundings. To address these challenges, we propose a novel Cross-level Feature Aggregation Network (CFA-Net) for polyp segmentation. Specifically, we first propose a boundary prediction network to generate boundary-aware features, which are incorporated into the segmentation network using a layer-wise strategy. In particular, we design a segmentation network with a two-stream structure to exploit hierarchical semantic information from cross-level features. Furthermore, a Cross-level Feature Fusion (CFF) module is proposed to integrate adjacent features from different levels, characterizing cross-level and multi-scale information to handle scale variations of polyps. Finally, a Boundary Aggregated Module (BAM) is proposed to incorporate boundary information into the segmentation network, enhancing these hierarchical features to generate finer segmentation maps. Quantitative and qualitative experiments on five public datasets demonstrate the effectiveness of our CFA-Net against other state-of-the-art polyp segmentation methods.
50 |
51 | ### 2.2. Framework Overview
52 |
53 |
54 | <p align="center"><img src="./imgs/framework.png"/></p>
55 |
56 | Figure 1: Overview of the proposed CFANet.
57 |
58 |
59 |
60 | ### 2.3. Qualitative Results
61 |
62 |
63 | <p align="center"><img src="./imgs/qualitative_results.png"/></p>
64 |
65 | Figure 2: Qualitative Results.
66 |
67 |
68 |
69 | ## 3. Proposed Baseline
70 |
71 | ### 3.1. Training/Testing
72 |
73 | The training and testing experiments are conducted using [PyTorch](https://github.com/pytorch/pytorch) with
74 | a single NVIDIA Tesla P40 GPU with 24 GB of memory.
75 |
76 | > Note that our model also supports low-memory GPUs; if needed, simply lower the batch size.
77 |
78 |
79 | 1. Configuring your environment (Prerequisites):
80 |
81 |     Note that CFANet has only been tested on Ubuntu with the following environment.
82 |     It may work on other operating systems as well, but we do not guarantee that it will.
83 |
84 | + Creating a virtual environment in terminal: `conda create -n CFANet python=3.6`.
85 |
86 | + Installing necessary packages: PyTorch 1.1
87 |
88 | 1. Downloading necessary data:
89 |
90 |     + Downloading the testing datasets and moving them into `./data/TestDataset/`,
91 |     which can be found in this [download link (Google Drive)](https://drive.google.com/file/d/1hwirZO201i_08fFgqmeqMuPuhPboHdVH/view?usp=sharing). It contains five sub-datasets: CVC-300 (60 test samples), CVC-ClinicDB (62 test samples), CVC-ColonDB (380 test samples), ETIS-LaribPolypDB (196 test samples), and Kvasir (100 test samples).
92 |
93 |     + Downloading the training dataset and moving it into `./data/TrainDataset/`,
94 |     which can be found in this [download link (Google Drive)](https://drive.google.com/file/d/1hzS21idjQlXnX9oxAgJI8KZzOBaz-OWj/view?usp=sharing). It contains two sub-datasets: Kvasir-SEG (900 training samples) and CVC-ClinicDB (550 training samples).
95 |
96 |     + Downloading the pretrained weights and moving them to `./checkpoint/CFANet.pth`,
97 | which can be found in this [download link (Google Drive)](https://drive.google.com/file/d/1pgvgYebjVVm-QZN-VbGdtYmAyccQmKxZ/view?usp=sharing).
98 |
99 |     + Downloading the Res2Net weights and moving them into `./lib/`,
100 |     which can be found in this [download link (Google Drive)](https://drive.google.com/file/d/1_1N-cx1UpRQo7Ybsjno1PAg4KE1T9e5J/view?usp=sharing). The resulting directory layout is sketched below.
101 |
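    After these downloads, the repository should look like the sketch below (paths taken from the steps above; the Res2Net filename is the one expected by `lib/res2net_v1b_base.py`):

        CFANet
        ├── checkpoint
        │   └── CFANet.pth
        ├── data
        │   ├── TestDataset
        │   │   ├── CVC-300
        │   │   ├── CVC-ClinicDB
        │   │   ├── CVC-ColonDB
        │   │   ├── ETIS-LaribPolypDB
        │   │   └── Kvasir
        │   └── TrainDataset
        └── lib
            └── res2net50_v1b_26w_4s-3cf99910.pth
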
102 | 1. Training Configuration:
103 |
104 |     + Assigning your customized paths, like `--save_model` and `--train_path`, in `train.py`; the training loss is sketched after this step.
105 |
106 | + Just enjoy it!
107 |
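The PyTorch `train.py` is not included in this dump, but `mindspore/train.py` (later in this listing) shows the total loss: a BCE edge term plus `structure_loss` on the three saliency outputs. Since the `structure_loss` body is not shown, the sketch below is the weighted BCE + weighted IoU loss commonly used in polyp-segmentation codebases; treat it as an editor's reference implementation, not necessarily the repo's exact one:

```python
import torch
import torch.nn.functional as F

def structure_loss(pred, mask):
    """Weighted BCE + weighted IoU (PraNet-style); pred are logits, mask is binary GT."""
    # Pixels whose local neighborhood disagrees with the mask (i.e., near boundaries)
    # receive a weight of up to 6; interior pixels keep weight 1.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()
```
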
108 | 1. Testing Configuration:
109 |
110 |     + After you have downloaded the pre-trained model and testing datasets, just run `test.py` to generate the final prediction maps,
111 |     setting `--pth_path` to your trained model directory (a minimal inference sketch is shown below).
112 |
113 | + Just enjoy it!
114 |
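For scripting inference directly, the snippet below is a minimal sketch based on the interfaces in `lib/model.py` (`CFANet` returns `(edge, sal1, sal2, sal3)`) and the README's default paths; it assumes the checkpoint stores a plain `state_dict`:

```python
import torch
import torch.nn.functional as F
from lib.model import CFANet

model = CFANet(channel=64)   # constructing CFANet loads the Res2Net weights from ./lib/
model.load_state_dict(torch.load('./checkpoint/CFANet.pth', map_location='cpu'))
model.eval()

image = torch.rand(1, 3, 352, 352)   # stand-in for a normalized 352x352 RGB frame
with torch.no_grad():
    _, _, _, res = model(image)      # the fused map sal_out3 is the final prediction
res = F.interpolate(res, size=(352, 352), mode='bilinear', align_corners=False)
res = torch.sigmoid(res)[0, 0]
res = (res - res.min()) / (res.max() - res.min() + 1e-8)   # min-max normalize, as in the test scripts
```
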
115 | ### 3.2 Evaluating your trained model:
116 |
117 | MATLAB: one-key evaluation is written in MATLAB ([link](https://drive.google.com/file/d/1_h4_CjD5GKEf7B1MRuzye97H0MXf2GE9/view?usp=sharing));
118 | follow the instructions in `./eval/main.m` and run it to generate the evaluation results in `./res/`; a rough Python alternative is sketched below.
119 | The complete evaluation toolbox (including data, map, eval code, and res): [new link](https://drive.google.com/file/d/1bnlz7nfJ9hhYsMLFSBr9smcI7k7p0pVy/view?usp=sharing).
120 |
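If MATLAB is unavailable, rough values of the common metrics (Dice, IoU, MAE) can be computed in Python directly from the saved maps. This is an editor-provided approximation, not the official toolbox, so the numbers may differ slightly from the MATLAB results:

```python
import cv2
import numpy as np

def eval_pair(pred_path, gt_path, eps=1e-8):
    # Dice / IoU / MAE for one prediction-GT pair (prediction thresholded at 0.5).
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE).astype(np.float64) / 255.0
    gt = (cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) > 127).astype(np.float64)
    mae = np.abs(pred - gt).mean()
    pred = (pred > 0.5).astype(np.float64)
    inter = (pred * gt).sum()
    dice = (2 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / (pred.sum() + gt.sum() - inter + eps)
    return dice, iou, mae
```
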
121 | ### 3.3 Pre-computed maps:
122 | They can be found in [download link](https://drive.google.com/file/d/1FY2FFDw-VLwmZ-JbJ-h4uAizcpgiY5vg/view?usp=drive_link).
123 |
124 |
125 | ## 4. MindSpore
126 |
127 | You need to run `cd mindspore` first.
128 |
129 | 1. Environment Configuration:
130 |
131 | + MindSpore: 2.0.0-alpha
132 |
133 | + Python: 3.8.0
134 |
135 | 2. Training Configuration:
136 |
137 |     + Assigning your customized paths, like `--save_model`, `--train_img_dir`, and so on, in `train.py`.
138 |
139 | + Just enjoy it!
140 |
141 | 3. Testing Configuration:
142 |
143 |     + After you have downloaded the pre-trained model and testing datasets, just run `test.py` to generate the final prediction maps,
144 |     setting `--pth_path` to your trained model directory.
145 |
146 | + Just enjoy it!
147 |
148 | ## 5. Citation
149 |
150 | Please cite our paper if you find the work useful:
151 |
152 | @article{zhou2023cross,
153 | title={Cross-level Feature Aggregation Network for Polyp Segmentation},
154 | author={Zhou, Tao and Zhou, Yi and He, Kelei and Gong, Chen and Yang, Jian and Fu, Huazhu and Shen, Dinggang},
155 | journal={Pattern Recognition},
156 | volume={140},
157 | pages={109555},
158 | year={2023},
159 | publisher={Elsevier}
160 | }
161 |
162 |
163 | ## 6. License
164 |
165 | The source code is free for research and education use only. Any commercial use must first obtain formal permission.
166 |
167 | ---
168 |
169 | **[⬆ back to top](#1-preface)**
170 |
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/imgs/framework.png
--------------------------------------------------------------------------------
/imgs/qualitative_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/imgs/qualitative_results.png
--------------------------------------------------------------------------------
/lib/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.res2net_v1b_base import Res2Net_model
4 |
5 |
6 | class global_module(nn.Module):
7 | def __init__(self, channels=64, r=4):
8 | super(global_module, self).__init__()
9 | out_channels = int(channels // r)
10 |         # global channel attention: GAP squeeze + bottleneck excitation (r = reduction ratio)
11 |
12 | self.global_att = nn.Sequential(
13 | nn.AdaptiveAvgPool2d(1),
14 | nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, padding=0),
15 | nn.BatchNorm2d(out_channels),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(out_channels, channels, kernel_size=1, stride=1, padding=0),
18 | nn.BatchNorm2d(channels)
19 | )
20 |
21 | self.sig = nn.Sigmoid()
22 |
23 | def forward(self, x):
24 |
25 | xg = self.global_att(x)
26 | out = self.sig(xg)
27 |
28 | return out
29 |
30 |
31 | class BasicConv2d(nn.Module):
32 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
33 | super(BasicConv2d, self).__init__()
34 | self.conv = nn.Conv2d(in_planes, out_planes,
35 | kernel_size=kernel_size, stride=stride,
36 | padding=padding, dilation=dilation, bias=False)
37 | self.bn = nn.BatchNorm2d(out_planes)
38 | self.relu = nn.ReLU(inplace=True)
39 |
40 |     def forward(self, x):
41 |         x = self.conv(x)
42 |         x = self.bn(x)
43 |         return x  # note: conv + BN only; self.relu is defined but not applied
44 |
45 |
46 | class ChannelAttention(nn.Module):
47 | def __init__(self, in_planes, ratio=16):
48 | super(ChannelAttention, self).__init__()
49 |
50 | self.max_pool = nn.AdaptiveMaxPool2d(1)
51 |
52 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
53 |         self.relu1 = nn.ReLU()
54 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
55 |
56 | self.sigmoid = nn.Sigmoid()
57 | def forward(self, x):
58 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
59 | out = max_out
60 | return self.sigmoid(out)
61 |
62 |
63 | class GateFusion(nn.Module):
64 | def __init__(self, in_planes):
65 |         super(GateFusion, self).__init__()
66 |
67 | self.gate_1 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True)
68 | self.gate_2 = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True)
69 |
70 |
71 | self.softmax = nn.Softmax(dim=1)
72 |
73 | def forward(self, x1, x2):
74 |
75 | ###
76 | cat_fea = torch.cat([x1,x2], dim=1)
77 |
78 | ###
79 | att_vec_1 = self.gate_1(cat_fea)
80 | att_vec_2 = self.gate_2(cat_fea)
81 |
82 | att_vec_cat = torch.cat([att_vec_1, att_vec_2], dim=1)
83 | att_vec_soft = self.softmax(att_vec_cat)
84 |
85 | att_soft_1, att_soft_2 = att_vec_soft[:, 0:1, :, :], att_vec_soft[:, 1:2, :, :]
86 | x_fusion = x1 * att_soft_1 + x2 * att_soft_2
87 |
88 | return x_fusion
89 |
90 |
91 | class BAM(nn.Module):
92 |     # Boundary Aggregated Module (BAM): enhances features with boundary-attended cues
93 | def __init__(self, channel):
94 | super(BAM, self).__init__()
95 |
96 |
97 | self.relu = nn.ReLU(True)
98 |
99 | self.global_att = global_module(channel)
100 |
101 | self.conv_layer = BasicConv2d(channel*2, channel, 3, padding=1)
102 |
103 |
104 | def forward(self, x, x_boun_atten):
105 |
106 | out1 = self.conv_layer(torch.cat((x, x_boun_atten), dim=1))
107 | out2 = self.global_att(out1)
108 | out3 = out1.mul(out2)
109 |
110 | out = x + out3
111 |
112 | return out
113 |
114 |
115 | class CFF(nn.Module):
116 | def __init__(self, in_channel1, in_channel2, out_channel):
117 |         super(CFF, self).__init__()
118 |
119 |
120 | act_fn = nn.ReLU(inplace=True)
121 |
122 | ## ---------------------------------------- ##
123 | self.layer0 = BasicConv2d(in_channel1, out_channel // 2, 1)
124 | self.layer1 = BasicConv2d(in_channel2, out_channel // 2, 1)
125 |
126 | self.layer3_1 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel // 2),act_fn)
127 | self.layer3_2 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel // 2),act_fn)
128 |
129 | self.layer5_1 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(out_channel // 2),act_fn)
130 | self.layer5_2 = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(out_channel // 2),act_fn)
131 |
132 | self.layer_out = nn.Sequential(nn.Conv2d(out_channel // 2, out_channel, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel),act_fn)
133 |
134 |
135 | def forward(self, x0, x1):
136 |
137 | ## ------------------------------------------------------------------ ##
138 | x0_1 = self.layer0(x0)
139 | x1_1 = self.layer1(x1)
140 |
141 | x_3_1 = self.layer3_1(torch.cat((x0_1, x1_1), dim=1))
142 | x_5_1 = self.layer5_1(torch.cat((x1_1, x0_1), dim=1))
143 |
144 | x_3_2 = self.layer3_2(torch.cat((x_3_1, x_5_1), dim=1))
145 | x_5_2 = self.layer5_2(torch.cat((x_5_1, x_3_1), dim=1))
146 |
147 | out = self.layer_out(x0_1 + x1_1 + torch.mul(x_3_2, x_5_2))
148 |
149 | return out
150 |
151 |
152 |
153 | ###############################################################################
154 | ## 2022/01/03
155 | ###############################################################################
156 | class CFANet(nn.Module):
157 |     # Res2Net-based encoder-decoder
158 | def __init__(self, channel=64, opt=None):
159 | super(CFANet, self).__init__()
160 |
161 | act_fn = nn.ReLU(inplace=True)
162 |
163 | self.resnet = Res2Net_model(50)
164 | self.downSample = nn.MaxPool2d(2, stride=2)
165 |
166 | ## ---------------------------------------- ##
167 |
168 | self.layer0 = nn.Sequential(nn.Conv2d(64, channel, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(channel),act_fn)
169 | self.layer1 = nn.Sequential(nn.Conv2d(256, channel, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(channel),act_fn)
170 |
171 | self.low_fusion = GateFusion(channel)
172 |
173 | self.high_fusion1 = CFF(256, 512, channel)
174 | self.high_fusion2 = CFF(1024, 2048, channel)
175 |
176 | ## ---------------------------------------- ##
177 | self.layer_edge0 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
178 | self.layer_edge1 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
179 | self.layer_edge2 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn)
180 | self.layer_edge3 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1))
181 |
182 | ## ---------------------------------------- ##
183 | self.layer_cat_ori1 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
184 | self.layer_hig01 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
185 |
186 | self.layer_cat11 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
187 | self.layer_hig11 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
188 |
189 | self.layer_cat21 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
190 | self.layer_hig21 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn)
191 |
192 | self.layer_cat31 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn)
193 | self.layer_hig31 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1))
194 |
195 | self.layer_cat_ori2 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
196 | self.layer_hig02 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
197 |
198 | self.layer_cat12 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
199 | self.layer_hig12 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
200 |
201 | self.layer_cat22 = nn.Sequential(nn.Conv2d(channel*2, channel, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(channel),act_fn)
202 | self.layer_hig22 = nn.Sequential(nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn)
203 |
204 | self.layer_cat32 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),act_fn)
205 | self.layer_hig32 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1))
206 |
207 | self.layer_fil = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1))
208 |
209 | ## ---------------------------------------- ##
210 |
211 | self.atten_edge_0 = ChannelAttention(channel)
212 | self.atten_edge_1 = ChannelAttention(channel)
213 | self.atten_edge_2 = ChannelAttention(channel)
214 | self.atten_edge_ori = ChannelAttention(channel)
215 |
216 |
217 | self.cat_01 = BAM(channel)
218 | self.cat_11 = BAM(channel)
219 | self.cat_21 = BAM(channel)
220 | self.cat_31 = BAM(channel)
221 |
222 | self.cat_02 = BAM(channel)
223 | self.cat_12 = BAM(channel)
224 | self.cat_22 = BAM(channel)
225 | self.cat_32 = BAM(channel)
226 |
227 |
228 | ## ---------------------------------------- ##
229 | self.downSample = nn.MaxPool2d(2, stride=2)
230 | self.up_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
231 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
232 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
233 |
234 |
235 |
236 | def forward(self, xx):
237 |
238 | # ---- feature abstraction -----
239 |
240 | x0, x1, x2, x3, x4 = self.resnet(xx)
241 |
242 |
243 |
244 | ## -------------------------------------- ##
245 |
246 | x0_1 = self.layer0(x0)
247 | x1_1 = self.layer1(x1)
248 |
249 | low_x = self.low_fusion(x0_1, x1_1) # 64*44
250 |
251 |
252 | edge_out0 = self.layer_edge0(self.up_2(low_x)) # 64*88
253 | edge_out1 = self.layer_edge1(self.up_2(edge_out0)) # 64*176
254 | edge_out2 = self.layer_edge2(self.up_2(edge_out1)) # 64*352
255 | edge_out3 = self.layer_edge3(edge_out2)
256 |
257 |
258 | etten_edge_ori = self.atten_edge_ori(low_x)
259 | etten_edge_0 = self.atten_edge_0(edge_out0)
260 | etten_edge_1 = self.atten_edge_1(edge_out1)
261 | etten_edge_2 = self.atten_edge_2(edge_out2)
262 |
263 |
264 | ## -------------------------------------- ##
265 | high_x01 = self.high_fusion1(self.downSample(x1), x2)
266 | high_x02 = self.high_fusion2(self.up_2(x3), self.up_4(x4))
267 |
268 | ## --------------- high 1 ----------------------- #
269 | cat_out_01 = self.cat_01(high_x01,low_x.mul(etten_edge_ori))
270 | hig_out01 = self.layer_hig01(self.up_2(cat_out_01))
271 |
272 | cat_out11 = self.cat_11(hig_out01,edge_out0.mul(etten_edge_0))
273 | hig_out11 = self.layer_hig11(self.up_2(cat_out11))
274 |
275 | cat_out21 = self.cat_21(hig_out11,edge_out1.mul(etten_edge_1))
276 | hig_out21 = self.layer_hig21(self.up_2(cat_out21))
277 |
278 | cat_out31 = self.cat_31(hig_out21,edge_out2.mul(etten_edge_2))
279 | sal_out1 = self.layer_hig31(cat_out31)
280 |
281 | ## ---------------- high 2 ---------------------- ##
282 | cat_out_02 = self.cat_02(high_x02,low_x.mul(etten_edge_ori))
283 | hig_out02 = self.layer_hig02(self.up_2(cat_out_02))
284 |
285 | cat_out12 = self.cat_12(hig_out02,edge_out0.mul(etten_edge_0))
286 | hig_out12 = self.layer_hig12(self.up_2(cat_out12))
287 |
288 | cat_out22 = self.cat_22(hig_out12,edge_out1.mul(etten_edge_1))
289 | hig_out22 = self.layer_hig22(self.up_2(cat_out22))
290 |
291 | cat_out32 = self.cat_32(hig_out22,edge_out2.mul(etten_edge_2))
292 | sal_out2 = self.layer_hig32(cat_out32)
293 |
294 | ## --------------------------------------------- ##
295 | sal_out3 = self.layer_fil(cat_out31+cat_out32)
296 |
297 | # ---- output ----
298 | return edge_out3, sal_out1, sal_out2, sal_out3
299 |
300 |
301 |
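302 | 
303 | if __name__ == '__main__':
304 |     # Editor's smoke test, mirroring the one in res2net_v1b_base.py: build CFANet
305 |     # and check the four output shapes on a random 352x352 input. Constructing
306 |     # CFANet loads the Res2Net weights from ./lib/ (see the README downloads).
307 |     images = torch.rand(1, 3, 352, 352)
308 |     model = CFANet(channel=64).eval()  # eval(): BN over 1x1 maps fails in train mode with batch 1
309 |     with torch.no_grad():
310 |         edge_out, sal1, sal2, sal3 = model(images)
311 |     print(edge_out.shape, sal1.shape, sal2.shape, sal3.shape)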
--------------------------------------------------------------------------------
/lib/res2net_v1b_base.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 | import torch.nn.functional as F
7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b']
8 |
9 |
10 | model_urls = {
11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
13 | }
14 |
15 |
16 | class Bottle2neck(nn.Module):
17 | expansion = 4
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'):
20 | """ Constructor
21 | Args:
22 | inplanes: input channel dimensionality
23 | planes: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | downsample: None when stride = 1
26 | baseWidth: basic width of conv3x3
27 |             scale: number of scales.
28 |             stype: 'normal' for a regular block; 'stage' for the first block of a new stage.
29 | """
30 | super(Bottle2neck, self).__init__()
31 |
32 | width = int(math.floor(planes * (baseWidth/64.0)))
33 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False)
34 | self.bn1 = nn.BatchNorm2d(width*scale)
35 |
36 | if scale == 1:
37 | self.nums = 1
38 | else:
39 | self.nums = scale -1
40 | if stype == 'stage':
41 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
42 | convs = []
43 | bns = []
44 | for i in range(self.nums):
45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
46 | bns.append(nn.BatchNorm2d(width))
47 | self.convs = nn.ModuleList(convs)
48 | self.bns = nn.ModuleList(bns)
49 |
50 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
52 |
53 | self.relu = nn.ReLU(inplace=True)
54 | self.downsample = downsample
55 | self.stype = stype
56 | self.scale = scale
57 | self.width = width
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | spx = torch.split(out, self.width, 1)
67 | for i in range(self.nums):
68 | if i==0 or self.stype=='stage':
69 | sp = spx[i]
70 | else:
71 | sp = sp + spx[i]
72 | sp = self.convs[i](sp)
73 | sp = self.relu(self.bns[i](sp))
74 | if i==0:
75 | out = sp
76 | else:
77 | out = torch.cat((out, sp), 1)
78 | if self.scale != 1 and self.stype=='normal':
79 | out = torch.cat((out, spx[self.nums]),1)
80 | elif self.scale != 1 and self.stype=='stage':
81 | out = torch.cat((out, self.pool(spx[self.nums])),1)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class Res2Net(nn.Module):
96 |
97 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
98 | self.inplanes = 64
99 | super(Res2Net, self).__init__()
100 | self.baseWidth = baseWidth
101 | self.scale = scale
102 | self.conv1 = nn.Sequential(
103 | nn.Conv2d(3, 32, 3, 2, 1, bias=False),
104 | nn.BatchNorm2d(32),
105 | nn.ReLU(inplace=True),
106 | nn.Conv2d(32, 32, 3, 1, 1, bias=False),
107 | nn.BatchNorm2d(32),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(32, 64, 3, 1, 1, bias=False)
110 | )
111 | self.bn1 = nn.BatchNorm2d(64)
112 | self.relu = nn.ReLU()
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
114 | self.layer1 = self._make_layer(block, 64, layers[0])
115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
118 | self.avgpool = nn.AdaptiveAvgPool2d(1)
119 | self.fc = nn.Linear(512 * block.expansion, num_classes)
120 |
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
124 | elif isinstance(m, nn.BatchNorm2d):
125 | nn.init.constant_(m.weight, 1)
126 | nn.init.constant_(m.bias, 0)
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.AvgPool2d(kernel_size=stride, stride=stride,
133 | ceil_mode=True, count_include_pad=False),
134 | nn.Conv2d(self.inplanes, planes * block.expansion,
135 | kernel_size=1, stride=1, bias=False),
136 | nn.BatchNorm2d(planes * block.expansion),
137 | )
138 |
139 | layers = []
140 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
141 | stype='stage', baseWidth = self.baseWidth, scale=self.scale))
142 | self.inplanes = planes * block.expansion
143 | for i in range(1, blocks):
144 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 |
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x)
153 | x0 = self.maxpool(x)
154 |
155 |
156 | x1 = self.layer1(x0)
157 | x2 = self.layer2(x1)
158 | x3 = self.layer3(x2)
159 | x4 = self.layer4(x3)
160 |
161 | x5 = self.avgpool(x4)
162 | x6 = x5.view(x5.size(0), -1)
163 | x7 = self.fc(x6)
164 |
165 | return x7
166 |
167 |
168 |
169 | class Res2Net_Ours(nn.Module):
170 |
171 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
172 | self.inplanes = 64
173 | super(Res2Net_Ours, self).__init__()
174 |
175 | self.baseWidth = baseWidth
176 | self.scale = scale
177 | self.conv1 = nn.Sequential(
178 | nn.Conv2d(3, 32, 3, 2, 1, bias=False),
179 | nn.BatchNorm2d(32),
180 | nn.ReLU(inplace=True),
181 | nn.Conv2d(32, 32, 3, 1, 1, bias=False),
182 | nn.BatchNorm2d(32),
183 | nn.ReLU(inplace=True),
184 | nn.Conv2d(32, 64, 3, 1, 1, bias=False)
185 | )
186 | self.bn1 = nn.BatchNorm2d(64)
187 | self.relu = nn.ReLU()
188 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
189 | self.layer1 = self._make_layer(block, 64, layers[0])
190 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
191 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
192 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
193 |
194 |
195 | for m in self.modules():
196 | if isinstance(m, nn.Conv2d):
197 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
198 | elif isinstance(m, nn.BatchNorm2d):
199 | nn.init.constant_(m.weight, 1)
200 | nn.init.constant_(m.bias, 0)
201 |
202 | def _make_layer(self, block, planes, blocks, stride=1):
203 | downsample = None
204 | if stride != 1 or self.inplanes != planes * block.expansion:
205 | downsample = nn.Sequential(
206 | nn.AvgPool2d(kernel_size=stride, stride=stride,
207 | ceil_mode=True, count_include_pad=False),
208 | nn.Conv2d(self.inplanes, planes * block.expansion,
209 | kernel_size=1, stride=1, bias=False),
210 | nn.BatchNorm2d(planes * block.expansion),
211 | )
212 |
213 | layers = []
214 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
215 | stype='stage', baseWidth = self.baseWidth, scale=self.scale))
216 | self.inplanes = planes * block.expansion
217 | for i in range(1, blocks):
218 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale))
219 |
220 | return nn.Sequential(*layers)
221 |
222 | def forward(self, x):
223 |
224 | x = self.conv1(x)
225 | x = self.bn1(x)
226 | x = self.relu(x)
227 | x0 = self.maxpool(x)
228 |
229 |
230 | x1 = self.layer1(x0)
231 | x2 = self.layer2(x1)
232 | x3 = self.layer3(x2)
233 | x4 = self.layer4(x3)
234 |
235 |
236 | return x0,x1,x2,x3,x4
237 |
238 |
239 |
240 | def res2net50_v1b(pretrained=False, **kwargs):
241 | """Constructs a Res2Net-50_v1b model.
242 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | """
246 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
247 | if pretrained:
248 | #model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu'))
249 |
250 | model_state = torch.load('./lib/res2net50_v1b_26w_4s-3cf99910.pth')
251 | model.load_state_dict(model_state)
252 |
253 | return model
254 |
255 | def res2net101_v1b(pretrained=False, **kwargs):
256 |     """Constructs a Res2Net-101_v1b_26w_4s model.
257 | Args:
258 | pretrained (bool): If True, returns a model pre-trained on ImageNet
259 | """
260 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
261 | if pretrained:
262 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
263 | return model
264 |
265 |
266 |
267 | def res2net50_v1b_Ours(pretrained=False, **kwargs):
268 | """Constructs a Res2Net-50_v1b model.
269 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
270 | Args:
271 | pretrained (bool): If True, returns a model pre-trained on ImageNet
272 | """
273 | model = Res2Net_Ours(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
274 | if pretrained:
275 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s']))
276 | return model
277 |
278 | def res2net101_v1b_Ours(pretrained=False, **kwargs):
279 |     """Constructs a Res2Net-101_v1b_26w_4s model.
280 | Args:
281 | pretrained (bool): If True, returns a model pre-trained on ImageNet
282 | """
283 | model = Res2Net_Ours(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
284 | if pretrained:
285 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
286 | return model
287 |
288 |
289 |
290 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
291 | """Constructs a Res2Net-50_v1b_26w_4s model.
292 | Args:
293 | pretrained (bool): If True, returns a model pre-trained on ImageNet
294 | """
295 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
296 | if pretrained:
297 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu'))
298 | return model
299 |
300 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
301 |     """Constructs a Res2Net-101_v1b_26w_4s model.
302 | Args:
303 | pretrained (bool): If True, returns a model pre-trained on ImageNet
304 | """
305 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
306 | if pretrained:
307 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
308 | return model
309 |
310 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
311 |     """Constructs a Res2Net-152_v1b_26w_4s model.
312 | Args:
313 | pretrained (bool): If True, returns a model pre-trained on ImageNet
314 | """
315 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs)
316 | if pretrained:
317 |         model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s']))  # note: this key is not defined in model_urls above; add its URL before using pretrained=True
318 | return model
319 |
320 |
321 |
322 |
323 | def Res2Net_model(ind=50):
324 | 
325 |     if ind == 50:
326 |         model_base = res2net50_v1b(pretrained=True)  # expects ./lib/res2net50_v1b_26w_4s-3cf99910.pth (see the README's Res2Net download step)
327 |         model = res2net50_v1b_Ours()
328 |     elif ind == 101:
329 |         model_base = res2net101_v1b(pretrained=True)
330 |         model = res2net101_v1b_Ours()
331 |     else:
332 |         raise ValueError('Res2Net_model: ind must be 50 or 101')
333 | 
334 | pretrained_dict = model_base.state_dict()
335 | model_dict = model.state_dict()
336 |
337 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
338 |
339 | model_dict.update(pretrained_dict)
340 | model.load_state_dict(model_dict)
341 |
342 | return model
343 |
344 |
345 |
346 |
347 |
348 | if __name__ == '__main__':
349 | images = torch.rand(1, 3, 352, 352)
350 | model = res2net50_v1b_26w_4s(pretrained=False)
351 |     model.eval()
352 | print(model(images).size())
353 |
--------------------------------------------------------------------------------
/mindspore/lib/model.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | from mindspore import Tensor, ops, Parameter, nn, common
4 | import mindspore as ms
5 | from lib.res2net_v1b_base import Res2Net_model
6 |
7 |
8 | class global_module(nn.Cell):
9 | def __init__(self, channels=64, r=4):
10 | super(global_module, self).__init__()
11 | out_channels = int(channels // r)
12 |         # global channel attention: GAP squeeze + bottleneck excitation (r = reduction ratio)
13 |
14 | self.global_att = nn.SequentialCell(
15 | nn.AdaptiveAvgPool2d(1),
16 | nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, pad_mode='valid', has_bias=True),
17 | nn.BatchNorm2d(out_channels, use_batch_statistics=True),
18 | nn.ReLU(),
19 | nn.Conv2d(out_channels, channels, kernel_size=1, stride=1, pad_mode='valid', has_bias=True),
20 | nn.BatchNorm2d(channels, use_batch_statistics=True)
21 | )
22 |
23 | self.sig = nn.Sigmoid()
24 |
25 | def construct(self, x):
26 | xg = self.global_att(x)
27 | out = self.sig(xg)
28 |
29 | return out
30 |
31 |
32 | class BasicConv2d(nn.Cell):
33 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, pad_mode='valid', dilation=1):
34 | super(BasicConv2d, self).__init__()
35 | self.conv = nn.Conv2d(in_planes, out_planes,
36 | kernel_size=kernel_size, stride=stride,
37 | pad_mode=pad_mode, dilation=dilation, has_bias=False)
38 | self.bn = nn.BatchNorm2d(out_planes)
39 | self.relu = nn.ReLU()
40 |
41 | def construct(self, x):
42 | x = self.conv(x)
43 | x = self.bn(x)
44 | return x
45 |
46 |
47 | class ChannelAttention(nn.Cell):
48 | def __init__(self, in_planes, ratio=16):
49 | super(ChannelAttention, self).__init__()
50 |
51 | # self.max_pool = nn.AdaptiveAvgPool2d(1)
52 |
53 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, pad_mode='valid', has_bias=False)
54 |         self.relu1 = nn.ReLU()
55 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, pad_mode='valid', has_bias=False)
56 |
57 | self.sigmoid = nn.Sigmoid()
58 |
59 | def construct(self, x):
60 | _, x = ops.max(x, axis=2, keep_dims=True)
61 | _, x = ops.max(x, axis=3, keep_dims=True)
62 | # x = self.max_pool(x)
63 | max_out = self.fc2(self.relu1(self.fc1(x)))
64 | out = max_out
65 | return self.sigmoid(out)
66 |
67 |
68 | class GateFusion(nn.Cell):
69 | def __init__(self, in_planes):
70 |         super(GateFusion, self).__init__()
71 |
72 | self.gate_1 = nn.Conv2d(in_planes * 2, 1, kernel_size=1, pad_mode='valid', has_bias=True)
73 | self.gate_2 = nn.Conv2d(in_planes * 2, 1, kernel_size=1, pad_mode='valid', has_bias=True)
74 |
75 | self.softmax = nn.Softmax(axis=1)
76 |
77 | def construct(self, x1, x2):
78 | ###
79 | cat_fea = ops.concat((x1, x2), axis=1)
80 |
81 | ###
82 | att_vec_1 = self.gate_1(cat_fea)
83 | att_vec_2 = self.gate_2(cat_fea)
84 |
85 | att_vec_cat = ops.concat((att_vec_1, att_vec_2), axis=1)
86 | att_vec_soft = self.softmax(att_vec_cat)
87 |
88 | att_soft_1, att_soft_2 = att_vec_soft[:, 0:1, :, :], att_vec_soft[:, 1:2, :, :]
89 | x_fusion = x1 * att_soft_1 + x2 * att_soft_2
90 |
91 | return x_fusion
92 |
93 |
94 | class BAM(nn.Cell):
95 | # Partial Decoder Component (Identification Module)
96 | def __init__(self, channel):
97 | super(BAM, self).__init__()
98 |
99 | self.relu = nn.ReLU()
100 |
101 | self.global_att = global_module(channel)
102 |
103 | self.conv_layer = BasicConv2d(channel * 2, channel, 3, pad_mode='same')
104 |
105 | def construct(self, x, x_boun_atten):
106 | out1 = self.conv_layer(ops.concat((x, x_boun_atten), axis=1))
107 | out2 = self.global_att(out1)
108 | out3 = out1 * out2
109 |
110 | out = x + out3
111 |
112 | return out
113 |
114 |
115 | class CFF(nn.Cell):
116 | def __init__(self, in_channel1, in_channel2, out_channel):
117 |         super(CFF, self).__init__()
118 |
119 | act_fn = nn.ReLU()
120 |
121 | ## ---------------------------------------- ##
122 | self.layer0 = BasicConv2d(in_channel1, out_channel // 2, 1)
123 | self.layer1 = BasicConv2d(in_channel2, out_channel // 2, 1)
124 |
125 | self.layer3_1 = nn.SequentialCell(
126 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
127 | nn.BatchNorm2d(out_channel // 2), act_fn)
128 | self.layer3_2 = nn.SequentialCell(
129 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
130 | nn.BatchNorm2d(out_channel // 2), act_fn)
131 |
132 | self.layer5_1 = nn.SequentialCell(
133 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, pad_mode='same', has_bias=True),
134 | nn.BatchNorm2d(out_channel // 2), act_fn)
135 | self.layer5_2 = nn.SequentialCell(
136 | nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, pad_mode='same', has_bias=True),
137 | nn.BatchNorm2d(out_channel // 2), act_fn)
138 |
139 | self.layer_out = nn.SequentialCell(
140 | nn.Conv2d(out_channel // 2, out_channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
141 | nn.BatchNorm2d(out_channel), act_fn)
142 |
143 | def construct(self, x0, x1):
144 | ## ------------------------------------------------------------------ ##
145 | x0_1 = self.layer0(x0)
146 | x1_1 = self.layer1(x1)
147 |
148 | x_3_1 = self.layer3_1(ops.concat((x0_1, x1_1), axis=1))
149 | x_5_1 = self.layer5_1(ops.concat((x1_1, x0_1), axis=1))
150 |
151 | x_3_2 = self.layer3_2(ops.concat((x_3_1, x_5_1), axis=1))
152 | x_5_2 = self.layer5_2(ops.concat((x_5_1, x_3_1), axis=1))
153 |
154 | out = self.layer_out(x0_1 + x1_1 + x_3_2 * x_5_2)
155 |
156 | return out
157 |
158 |
159 | ###############################################################################
160 | ## 2022/01/03
161 | ###############################################################################
162 | class CFANet(nn.Cell):
163 |     # Res2Net-based encoder-decoder
164 | def __init__(self, channel=64, opt=None):
165 | super(CFANet, self).__init__()
166 |
167 | act_fn = nn.ReLU()
168 |
169 | self.resnet = Res2Net_model(50)
170 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
171 |
172 | ## ---------------------------------------- ##
173 |
174 | self.layer0 = nn.SequentialCell(
175 | nn.Conv2d(64, channel, kernel_size=3, stride=2, pad_mode='same', has_bias=True),
176 | nn.BatchNorm2d(channel ),
177 | act_fn)
178 | self.layer1 = nn.SequentialCell(
179 | nn.Conv2d(256, channel, kernel_size=3, stride=2, pad_mode='same', has_bias=True),
180 | nn.BatchNorm2d(channel ),
181 | act_fn)
182 |
183 | self.low_fusion = GateFusion(channel)
184 |
185 | self.high_fusion1 = CFF(256, 512, channel)
186 | self.high_fusion2 = CFF(1024, 2048, channel)
187 |
188 | ## ---------------------------------------- ##
189 | self.layer_edge0 = nn.SequentialCell(
190 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
191 | nn.BatchNorm2d(channel ),
192 | act_fn)
193 | self.layer_edge1 = nn.SequentialCell(
194 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
195 | nn.BatchNorm2d(channel ),
196 | act_fn)
197 | self.layer_edge2 = nn.SequentialCell(
198 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
199 | nn.BatchNorm2d(64 ),
200 | act_fn)
201 | self.layer_edge3 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True)
202 |
203 | ## ---------------------------------------- ##
204 | self.layer_cat_ori1 = nn.SequentialCell(
205 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
206 | nn.BatchNorm2d(channel ),
207 | act_fn)
208 | self.layer_hig01 = nn.SequentialCell(
209 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
210 | nn.BatchNorm2d(channel ),
211 | act_fn)
212 |
213 | self.layer_cat11 = nn.SequentialCell(
214 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
215 | nn.BatchNorm2d(channel ),
216 | act_fn)
217 | self.layer_hig11 = nn.SequentialCell(
218 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
219 | nn.BatchNorm2d(channel ),
220 | act_fn)
221 |
222 | self.layer_cat21 = nn.SequentialCell(
223 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
224 | nn.BatchNorm2d(channel ),
225 | act_fn)
226 | self.layer_hig21 = nn.SequentialCell(
227 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
228 | nn.BatchNorm2d(64 ),
229 | act_fn)
230 |
231 | self.layer_cat31 = nn.SequentialCell(
232 | nn.Conv2d(64 * 2, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
233 | nn.BatchNorm2d(64 ),
234 | act_fn)
235 | self.layer_hig31 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True)
236 |
237 | self.layer_cat_ori2 = nn.SequentialCell(
238 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
239 | nn.BatchNorm2d(channel ),
240 | act_fn)
241 | self.layer_hig02 = nn.SequentialCell(
242 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
243 | nn.BatchNorm2d(channel ),
244 | act_fn)
245 |
246 | self.layer_cat12 = nn.SequentialCell(
247 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
248 | nn.BatchNorm2d(channel ),
249 | act_fn)
250 | self.layer_hig12 = nn.SequentialCell(
251 | nn.Conv2d(channel, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
252 | nn.BatchNorm2d(channel ),
253 | act_fn)
254 |
255 | self.layer_cat22 = nn.SequentialCell(
256 | nn.Conv2d(channel * 2, channel, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
257 | nn.BatchNorm2d(channel ),
258 | act_fn)
259 | self.layer_hig22 = nn.SequentialCell(
260 | nn.Conv2d(channel, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
261 | nn.BatchNorm2d(64 ),
262 | act_fn)
263 |
264 | self.layer_cat32 = nn.SequentialCell(
265 | nn.Conv2d(64 * 2, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=True),
266 | nn.BatchNorm2d(64),
267 | act_fn)
268 | self.layer_hig32 = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True)
269 |
270 | self.layer_fil = nn.Conv2d(64, 1, kernel_size=1, pad_mode='valid', has_bias=True)
271 |
272 | ## ---------------------------------------- ##
273 |
274 | self.atten_edge_0 = ChannelAttention(channel)
275 | self.atten_edge_1 = ChannelAttention(channel)
276 | self.atten_edge_2 = ChannelAttention(channel)
277 | self.atten_edge_ori = ChannelAttention(channel)
278 |
279 | self.cat_01 = BAM(channel)
280 | self.cat_11 = BAM(channel)
281 | self.cat_21 = BAM(channel)
282 | self.cat_31 = BAM(channel)
283 |
284 | self.cat_02 = BAM(channel)
285 | self.cat_12 = BAM(channel)
286 | self.cat_22 = BAM(channel)
287 | self.cat_32 = BAM(channel)
288 |
289 | ## ---------------------------------------- ##
290 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
291 | self.up = nn.ResizeBilinear()
292 |
293 | def construct(self, xx):
294 | # ---- feature abstraction -----
295 |
296 | x0, x1, x2, x3, x4 = self.resnet(xx)
297 |
298 | ## -------------------------------------- ##
299 |
300 | x0_1 = self.layer0(x0)
301 | x1_1 = self.layer1(x1)
302 |
303 | low_x = self.low_fusion(x0_1, x1_1) # 64*44
304 |
305 | edge_out0 = self.layer_edge0(self.up(low_x, scale_factor=2, align_corners=True)) # 64*88
306 | edge_out1 = self.layer_edge1(self.up(edge_out0, scale_factor=2, align_corners=True)) # 64*176
307 | edge_out2 = self.layer_edge2(self.up(edge_out1, scale_factor=2, align_corners=True)) # 64*352
308 | edge_out3 = self.layer_edge3(edge_out2)
309 |
310 | etten_edge_ori = self.atten_edge_ori(low_x)
311 | etten_edge_0 = self.atten_edge_0(edge_out0)
312 | etten_edge_1 = self.atten_edge_1(edge_out1)
313 | etten_edge_2 = self.atten_edge_2(edge_out2)
314 |
315 | ## -------------------------------------- ##
316 | high_x01 = self.high_fusion1(self.downSample(x1), x2)
317 | high_x02 = self.high_fusion2(self.up(x3, scale_factor=2, align_corners=True),
318 | self.up(x4, scale_factor=4, align_corners=True))
319 |
320 | ## --------------- high 1 ----------------------- #
321 | cat_out_01 = self.cat_01(high_x01, low_x * etten_edge_ori)
322 | hig_out01 = self.layer_hig01(self.up(cat_out_01, scale_factor=2, align_corners=True))
323 |
324 | cat_out11 = self.cat_11(hig_out01, edge_out0 * (etten_edge_0))
325 | hig_out11 = self.layer_hig11(self.up(cat_out11, scale_factor=2, align_corners=True))
326 |
327 | cat_out21 = self.cat_21(hig_out11, edge_out1 * (etten_edge_1))
328 | hig_out21 = self.layer_hig21(self.up(cat_out21, scale_factor=2, align_corners=True))
329 |
330 | cat_out31 = self.cat_31(hig_out21, edge_out2 * (etten_edge_2))
331 | sal_out1 = self.layer_hig31(cat_out31)
332 |
333 | ## ---------------- high 2 ---------------------- ##
334 | cat_out_02 = self.cat_02(high_x02, low_x * (etten_edge_ori))
335 | hig_out02 = self.layer_hig02(self.up(cat_out_02, scale_factor=2, align_corners=True))
336 |
337 | cat_out12 = self.cat_12(hig_out02, edge_out0 * (etten_edge_0))
338 | hig_out12 = self.layer_hig12(self.up(cat_out12, scale_factor=2, align_corners=True))
339 |
340 | cat_out22 = self.cat_22(hig_out12, edge_out1 * (etten_edge_1))
341 | hig_out22 = self.layer_hig22(self.up(cat_out22, scale_factor=2, align_corners=True))
342 |
343 | cat_out32 = self.cat_32(hig_out22, edge_out2 * (etten_edge_2))
344 | sal_out2 = self.layer_hig32(cat_out32)
345 |
346 | ## --------------------------------------------- ##
347 | sal_out3 = self.layer_fil(cat_out31 + cat_out32)
348 |
349 | # ---- output ----
350 | return edge_out3, sal_out1, sal_out2, sal_out3
351 |
--------------------------------------------------------------------------------
/mindspore/lib/res2net_v1b_base.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | from mindspore import Tensor, ops, Parameter, nn, common
4 | import mindspore as ms
5 | import math
6 |
7 | __all__ = ['Res2Net', 'res2net50_v1b']
8 |
9 | # model_urls = {
10 | # 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
11 | # 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
12 | # }
13 |
14 |
15 | class Bottle2neck(nn.Cell):
16 | expansion = 4
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
19 | """ Constructor
20 | Args:
21 | inplanes: input channel dimensionality
22 | planes: output channel dimensionality
23 | stride: conv stride. Replaces pooling layer.
24 | downsample: None when stride = 1
25 | baseWidth: basic width of conv3x3
26 | scale: number of scale.
27 |             scale: number of scales.
28 |             stype: 'normal' for a regular block; 'stage' for the first block of a new stage.
29 | super(Bottle2neck, self).__init__()
30 |
31 | width = int(math.floor(planes * (baseWidth / 64.0)))
32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, pad_mode='valid', has_bias=False)
33 | self.bn1 = nn.BatchNorm2d(width * scale)
34 |
35 | if scale == 1:
36 | self.nums = 1
37 | else:
38 | self.nums = scale - 1
39 | if stype == 'stage':
40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, pad_mode="same")
41 | convs = []
42 | bns = []
43 | for i in range(self.nums):
44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, pad_mode='same', has_bias=False))
45 | bns.append(nn.BatchNorm2d(width))
46 | self.convs = nn.CellList(convs)
47 | self.bns = nn.CellList(bns)
48 |
49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, pad_mode='valid', has_bias=False)
50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
51 |
52 | self.relu = nn.ReLU()
53 | self.downsample = downsample
54 | self.stype = stype
55 | self.scale = scale
56 | self.width = width
57 |
58 | def construct(self, x):
59 | residual = x
60 |
61 | out = self.conv1(x)
62 | out = self.bn1(out)
63 | out = self.relu(out)
64 |
65 | sp = None
66 | spx = ops.Split(1, self.scale)(out)
67 | for i in range(self.nums):
68 | if i == 0 or self.stype == 'stage':
69 | sp = spx[i]
70 | else:
71 | sp = sp + spx[i]
72 | sp = self.convs[i](sp)
73 | sp = self.relu(self.bns[i](sp))
74 | if i == 0:
75 | out = sp
76 | else:
77 | out = ops.Concat(1)((out, sp))
78 | if self.scale != 1 and self.stype == 'normal':
79 | out = ops.Concat(1)((out, spx[self.nums])) # torch.cat((out, spx[self.nums]), 1)
80 | elif self.scale != 1 and self.stype == 'stage':
81 | out = ops.Concat(1)((out, self.pool(spx[self.nums]))) # torch.cat((out, self.pool(spx[self.nums])), 1)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class Res2Net(nn.Cell):
96 |
97 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
98 | self.inplanes = 64
99 | super(Res2Net, self).__init__()
100 | self.baseWidth = baseWidth
101 | self.scale = scale
102 | self.conv1 = nn.SequentialCell(
103 | nn.Conv2d(3, 32, kernel_size=3, stride=2, pad_mode='same', has_bias=False),
104 | nn.BatchNorm2d(32),
105 | nn.ReLU(),
106 | nn.Conv2d(32, 32, kernel_size=3, stride=1, pad_mode='same', has_bias=False),
107 | nn.BatchNorm2d(32),
108 | nn.ReLU(),
109 | nn.Conv2d(32, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=False)
110 | )
111 | self.bn1 = nn.BatchNorm2d(64)
112 | self.relu = nn.ReLU()
113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
114 | self.layer1 = self._make_layer(block, 64, layers[0])
115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
116 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
118 | self.avgpool = nn.AdaptiveAvgPool2d(1)
119 | self.fc = nn.Dense(512 * block.expansion, num_classes)
120 |
121 | # for m in self.cells():
122 | # if isinstance(m, nn.Conv2d):
123 | # common.initializer.HeNormal(mode='fan_out', nonlinearity='relu')(m.weight)
124 | # elif isinstance(m, nn.BatchNorm2d):
125 | # print(m)
126 | # common.initializer.Constant(1)(m.weight)
127 | # common.initializer.Constant(0)(m.bias)
128 |
129 | def _make_layer(self, block, planes, blocks, stride=1):
130 | downsample = None
131 | if stride != 1 or self.inplanes != planes * block.expansion:
132 | downsample = nn.SequentialCell(
133 | nn.AvgPool2d(kernel_size=stride, stride=stride),
134 | nn.Conv2d(self.inplanes, planes * block.expansion,
135 | kernel_size=1, stride=1, pad_mode='valid', has_bias=False),
136 | nn.BatchNorm2d(planes * block.expansion),
137 | )
138 |
139 | layers = []
140 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
141 | stype='stage', baseWidth=self.baseWidth, scale=self.scale))
142 | self.inplanes = planes * block.expansion
143 | for i in range(1, blocks):
144 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))
145 |
146 | return nn.SequentialCell(*layers)
147 |
148 | def construct(self, x):
149 |
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x)
153 | x0 = self.maxpool(x)
154 |
155 | x1 = self.layer1(x0)
156 | x2 = self.layer2(x1)
157 | x3 = self.layer3(x2)
158 | x4 = self.layer4(x3)
159 |
160 | x5 = self.avgpool(x4)
161 |         x6 = x5.view(x5.shape[0], -1)  # MindSpore Tensor.size is a property, so use .shape[0]
162 | x7 = self.fc(x6)
163 |
164 | return x7
165 |
166 |
167 | class Res2Net_Ours(nn.Cell):
168 |
169 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
170 | self.inplanes = 64
171 | super(Res2Net_Ours, self).__init__()
172 |
173 | self.baseWidth = baseWidth
174 | self.scale = scale
175 | self.conv1 = nn.SequentialCell(
176 | nn.Conv2d(3, 32, kernel_size=3, stride=2, pad_mode='same', has_bias=False),
177 | nn.BatchNorm2d(32),
178 | nn.ReLU(),
179 | nn.Conv2d(32, 32, kernel_size=3, stride=1, pad_mode='same', has_bias=False),
180 | nn.BatchNorm2d(32),
181 | nn.ReLU(),
182 | nn.Conv2d(32, 64, kernel_size=3, stride=1, pad_mode='same', has_bias=False)
183 | )
184 | self.bn1 = nn.BatchNorm2d(64)
185 | self.relu = nn.ReLU()
186 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
187 | self.layer1 = self._make_layer(block, 64, layers[0])
188 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
189 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
191 |
192 | # for m in self.cells():
193 | # if isinstance(m, nn.Conv2d):
194 | # common.initializer.HeNormal(mode='fan_out', nonlinearity='relu')(m.weight)
195 | # elif isinstance(m, nn.BatchNorm2d):
196 | # print(m)
197 | # common.initializer.Constant(1)(m.weight)
198 | # common.initializer.Constant(0)(m.bias)
199 |
200 | def _make_layer(self, block, planes, blocks, stride=1):
201 | downsample = None
202 | if stride != 1 or self.inplanes != planes * block.expansion:
203 | downsample = nn.SequentialCell(
204 | nn.AvgPool2d(kernel_size=stride, stride=stride),
205 | nn.Conv2d(self.inplanes, planes * block.expansion,
206 | kernel_size=1, stride=1, pad_mode='valid', has_bias=False),
207 | nn.BatchNorm2d(planes * block.expansion),
208 | )
209 |
210 | layers = []
211 | layers.append(block(self.inplanes, planes, stride, downsample=downsample,
212 | stype='stage', baseWidth=self.baseWidth, scale=self.scale))
213 | self.inplanes = planes * block.expansion
214 | for i in range(1, blocks):
215 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))
216 |
217 | return nn.SequentialCell(*layers)
218 |
219 | def construct(self, x):
220 |
221 | x = self.conv1(x)
222 | x = self.bn1(x)
223 | x = self.relu(x)
224 | x0 = self.maxpool(x)
225 |
226 | x1 = self.layer1(x0)
227 | x2 = self.layer2(x1)
228 | x3 = self.layer3(x2)
229 | x4 = self.layer4(x3)
230 |
231 | return x0, x1, x2, x3, x4
232 |
233 |
234 | def res2net50_v1b(pretrained=False, **kwargs):
235 | """Constructs a Res2Net-50_v1b model.
236 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
237 | Args:
238 | pretrained (bool): If True, returns a model pre-trained on ImageNet
239 | """
240 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
241 | if pretrained:
242 | # model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'],map_location='cpu'))
243 | param_dict = ms.load_checkpoint(
244 | './lib/res2net50_v1b_26w_4s-3cf99910.ckpt')
245 | ms.load_param_into_net(model, param_dict)
246 |
247 | return model
248 |
249 |
250 | def res2net50_v1b_Ours(pretrained=False, **kwargs):
251 | """Constructs a Res2Net-50_v1b model.
252 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
253 | Args:
254 | pretrained (bool): If True, returns a model pre-trained on ImageNet
255 | """
256 | model = Res2Net_Ours(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
257 | if pretrained:
258 | param_dict = ms.load_checkpoint(
259 | './lib/res2net50_v1b_26w_4s-3cf99910.ckpt')
260 | ms.load_param_into_net(model, param_dict)
261 | return model
262 |
263 |
264 | def Res2Net_model(ind=50):
265 |     if ind != 50:
266 |         raise ValueError('Res2Net_model: only ind == 50 is supported in the MindSpore port')
267 |     model_base = res2net50_v1b(pretrained=False)
268 |     model = res2net50_v1b_Ours()
269 | pretrained_dict = model_base.parameters_dict()
270 | model_dict = model.parameters_dict()
271 |
272 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
273 |
274 | ms.load_param_into_net(model, pretrained_dict)
275 |
276 | return model
277 |
--------------------------------------------------------------------------------
/mindspore/test.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import mindspore.numpy as np
4 | import os, argparse
5 | import mindspore as ms
6 | from mindspore import ops, nn
7 | from lib.model import CFANet
8 | from utils.dataloader import get_loader, get_loader_test
9 |
10 | import cv2
11 |
12 | ms.set_context(device_target='GPU')
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
16 | parser.add_argument('--pth_path', type=str, default='./checkpoint/CFANet.ckpt')
17 |
18 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
19 |
20 | print('-----------starting -------------')
21 |
22 | # data_path = '/test/Polpy/Dataset/TestDataset/{}/'.format(_data_name)
23 | data_path = '../data/TestDataset/{}/'.format(_data_name)
24 | save_path = './Snapshot/seg_maps/{}/'.format(_data_name)
25 |
26 | opt = parser.parse_args()
27 | model = CFANet(channel=64)
28 |
29 | param_dict = ms.load_checkpoint(opt.pth_path)
30 | ms.load_param_into_net(model, param_dict)
31 | model.set_train(False)
32 |
33 | os.makedirs(save_path, exist_ok=True)
34 |
35 | image_root = '{}/images/'.format(data_path)
36 | gt_root = '{}/masks/'.format(data_path)
37 |
38 | test_loader = get_loader_test(image_root, gt_root, testsize=opt.testsize)
39 |
40 | for i, pack in enumerate(test_loader.create_tuple_iterator()):
41 | print('-------------- processing {} --------------'.format(i))
42 |
43 | image, gt, name = pack
44 | image = ops.ExpandDims()(image, 0)
45 | gt = np.asarray(gt, np.float32)
46 | gt /= (gt.max() + 1e-8)
47 |
48 | _, _, _, res = model(image)
49 |
50 | res = nn.ResizeBilinear()(res, size=gt.shape, align_corners=False)
51 | res = res.sigmoid()[0]
52 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
53 | cv2.imwrite(save_path + str(name.asnumpy()), ops.Transpose()(res, (1, 2, 0)).asnumpy() * 255)
54 |
--------------------------------------------------------------------------------
/mindspore/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pdb
4 |
5 | import numpy as np
6 | import logging
7 | from datetime import datetime
8 |
9 | import mindspore as ms
10 | from mindspore import context
11 | from mindspore import nn, ops
12 |
13 | from lib.model import CFANet
14 | from utils.dataloader import get_loader, get_loader_test
15 |
16 | ms.set_context(device_target='GPU', mode=context.GRAPH_MODE)
17 |
18 | import cv2
19 |
20 |
21 | class Trainer:
22 | def __init__(self, train_loader, test_loader, model, optimizer, opt):
23 | self.train_loader = train_loader
24 | self.test_loader = test_loader
25 | self.model = model
26 | self.optimizer = optimizer
27 | self.opt = opt
28 | self.total_step = self.train_loader.get_dataset_size() # TODO
29 |
30 | self.grad_fn = ops.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True)
31 | self.loss_func = nn.BCEWithLogitsLoss()
32 | self.size_rates = [0.75, 1, 1.25]
33 | self.resize = nn.ResizeBilinear()
34 | self.eval_loss_func = nn.L1Loss()
35 | self.best_mae = 1
36 | self.best_epoch = 0
37 | self.decay_rate = 0.1
38 | self.decay_epoch = 30
39 |
40 | def forward_fn(self, images, gts, egs):
41 | cam_edge, sal_out1, sal_out2, sal_out3 = self.model(images)
42 | loss_edge = self.loss_func(cam_edge, egs)
43 | loss_sal1 = self.structure_loss(sal_out1, gts)
44 | loss_sal2 = self.structure_loss(sal_out2, gts)
45 | loss_sal3 = self.structure_loss(sal_out3, gts)
46 |
47 | loss_total = loss_edge + loss_sal1 + loss_sal2 + loss_sal3
48 | return loss_total, loss_edge, loss_sal1, loss_sal2, loss_sal3
49 |
50 | def train_step(self, images, gts, egs):
51 | (loss, loss_edge, loss_sal1, loss_sal2, loss_sal3), grads = self.grad_fn(images, gts, egs)
52 | self.optimizer(grads)
53 | return loss, loss_edge, loss_sal1, loss_sal2, loss_sal3
54 |
55 | def train(self, epochs):
56 | for epoch in range(1, epochs + 1):
57 | self.model.set_train(True)
58 | self.adjust_lr(epoch)
59 | for step, data_pack in enumerate(self.train_loader.create_tuple_iterator(), start=1):
60 | images, gts, egs = data_pack
61 | for rate in self.size_rates:
62 | # ---- rescale ----
63 | trainsize = int(round(self.opt.trainsize * rate / 32) * 32)
64 | if rate != 1:
65 | images = self.resize(images, size=(trainsize, trainsize), align_corners=True)
66 | gts = self.resize(gts, size=(trainsize, trainsize), align_corners=True)
67 | egs = self.resize(egs, size=(trainsize, trainsize), align_corners=True)
68 | loss, loss_edge, loss_sal1, loss_sal2, loss_sal3 = self.train_step(images, gts, egs)
69 |
70 | # -- output loss -- #
71 | if step % 10 == 0 or step == self.total_step:
72 | print(
73 | '[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}]'.
74 | format(datetime.now(), epoch, epochs, step, self.total_step, loss_edge.asnumpy(),
75 | loss_sal1.asnumpy(), loss_sal2.asnumpy(), loss_sal3.asnumpy(), loss.asnumpy()))
76 |
77 | logging.info(
78 | '#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}'.
79 | format(epoch, epochs, step, self.total_step, loss_edge.asnumpy(), loss_sal1.asnumpy(),
80 | loss_sal2.asnumpy(), loss_sal3.asnumpy(), loss.asnumpy()))
81 |
82 | self.test(epoch)
83 |
84 | if epoch % self.opt.save_epoch == 0:
85 | ms.save_checkpoint(self.model, os.path.join(self.opt.save_model, 'CODNet_%d.ckpt' % epoch))
86 |
87 | def test(self, epoch):
88 | self.model.set_train(False)
89 | mae_sum = 0
90 | for i, pack in enumerate(self.test_loader.create_tuple_iterator(), start=1):
91 | # ---- data prepare ----
92 | image, gt, name = pack
93 | image = ops.ExpandDims()(image, 0)
94 |
95 | # ---- forward ----
96 | _, _, _, res = self.model(image)
97 | res = self.resize(res, size=gt.shape, align_corners=False)
98 | # pdb.set_trace()
99 | res = nn.Sigmoid()(res[0][0]) # TODO
100 | # pdb.set_trace()
101 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
102 | # pdb.set_trace()
103 | mae_sum += self.eval_loss_func(res, gt.astype(ms.float32) / 255.0)  # normalize uint8 mask to [0, 1]
104 | # pdb.set_trace()
105 | # mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1])
106 |
107 | # ---- recording loss ----
108 | mae = mae_sum / self.test_loader.get_dataset_size()
109 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, self.best_mae, self.best_epoch))
110 | if epoch == 1:
111 | self.best_mae = mae
112 | self.best_epoch = epoch
113 | else:
114 | if mae < self.best_mae:
115 | self.best_mae = mae
116 | self.best_epoch = epoch
117 |
118 | ms.save_checkpoint(self.model, os.path.join(self.opt.save_model, 'Cod_best.ckpt'))
119 | print('best epoch:{}'.format(epoch))
120 |
121 |
122 | def structure_loss(self, pred, mask):
123 | pred = nn.Sigmoid()(pred)
124 | weit = 1 + 5 * ops.Abs()(ops.AvgPool(kernel_size=31, strides=1, pad_mode='same')(mask) - mask)
125 | wbce = nn.BCELoss(reduction='none')(pred, mask)
126 | wbce = (weit * wbce).sum(axis=(2, 3)) / weit.sum(axis=(2, 3))
127 |
128 | inter = ((pred * mask) * weit).sum(axis=(2, 3))
129 | union = ((pred + mask) * weit).sum(axis=(2, 3))
130 | wiou = 1 - (inter + 1) / (union - inter + 1)
131 |
132 | return (wbce + wiou).mean()
133 |
134 | def adjust_lr(self, epoch):
135 | decay = self.decay_rate ** (epoch // self.decay_epoch)
136 | self.optimizer.learning_rate.set_data(ms.Tensor(self.opt.lr * decay, ms.float32))
137 |
138 |
139 | if __name__ == "__main__":
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument('--epoch', type=int, default=200, help='number of training epochs')
142 | parser.add_argument('--lr', type=float, default=1e-4, help='init learning rate, try `lr=1e-4`')
143 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size (Note: ~500MB per img in GPU)')
144 | parser.add_argument('--trainsize', type=int, default=352,
145 | help='the size of training image, try small resolutions for speed (like 256)')
146 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
147 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate per decay step')
148 | parser.add_argument('--decay_epoch', type=int, default=30, help='every N epochs decay lr')
149 | parser.add_argument('--gpu', type=int, default=0, help='choose which gpu you use')
150 | parser.add_argument('--save_epoch', type=int, default=5, help='every N epochs save your trained snapshot')
151 | parser.add_argument('--save_model', type=str, default='./Snapshot/CFANet/')
152 |
153 | parser.add_argument('--train_img_dir', type=str, default='./data/TrainDataset/images/')
154 | parser.add_argument('--train_gt_dir', type=str, default='./data/TrainDataset/masks/')
155 | parser.add_argument('--train_eg_dir', type=str, default='./data/TrainDataset/edges/')
156 |
157 | parser.add_argument('--test_img_dir', type=str, default='./data/TestDataset/CVC-300/images/')
158 | parser.add_argument('--test_gt_dir', type=str, default='./data/TestDataset/CVC-300/masks/')
159 | parser.add_argument('--test_eg_dir', type=str, default='./data/TestDataset/CVC-300/edges/')
160 |
161 | opt = parser.parse_args()
162 |
163 | ms.set_context(device_id=opt.gpu)
164 |
165 | save_path = opt.save_model
166 | os.makedirs(save_path, exist_ok=True)
167 |
168 | logging.basicConfig(filename=opt.save_model + '/log.log',
169 | format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level=logging.INFO, filemode='a',
170 | datefmt='%Y-%m-%d %I:%M:%S %p')
171 | logging.info("COD-Train")
172 | logging.info("Config")
173 | logging.info(
174 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(opt.epoch,
175 | opt.lr,
176 | opt.batchsize,
177 | opt.trainsize,
178 | opt.clip,
179 | opt.decay_rate,
180 | opt.save_model,
181 | opt.decay_epoch))
182 |
183 | # TIP: you can also use a wider network (e.g., channel=64) for better performance
184 | model = CFANet(channel=64)
185 | # print('-' * 30, model, '-' * 30)
186 |
187 | total = sum([param.size for param in model.get_parameters()])
188 | print('Number of parameters: %.2fM' % (total / 1e6))
189 |
190 | optimizer = nn.Adam(model.trainable_params(), learning_rate=opt.lr)
191 |
192 | train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, opt.train_eg_dir, batchsize=opt.batchsize,
193 | trainsize=opt.trainsize, num_workers=12)
194 | test_loader = get_loader_test(opt.test_img_dir, opt.test_gt_dir, testsize=opt.trainsize)
195 |
196 | total_step = train_loader.get_dataset_size()
197 | print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n"
198 | "Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr,
199 | opt.batchsize, opt.save_model, total_step), '-' * 30)
200 |
201 | train = Trainer(train_loader, test_loader, model, optimizer, opt)
202 | train.train(opt.epoch)
203 |
--------------------------------------------------------------------------------
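For reference, `Trainer.forward_fn` and `structure_loss` above implement the following objective: with sigmoid prediction $p$, ground-truth mask $m$, and a boundary-emphasizing weight map $w$,

$$
w = 1 + 5\,\bigl|\mathrm{AvgPool}_{31\times 31}(m) - m\bigr|,
$$

$$
\mathcal{L}_{wBCE} = \frac{\sum_i w_i\,\mathrm{BCE}(p_i, m_i)}{\sum_i w_i},
\qquad
\mathcal{L}_{wIoU} = 1 - \frac{\sum_i w_i\,p_i\,m_i + 1}{\sum_i w_i\,(p_i + m_i) - \sum_i w_i\,p_i\,m_i + 1},
$$

and the total training loss is $\mathcal{L}_{edge} + \sum_{k=1}^{3} \bigl(\mathcal{L}_{wBCE} + \mathcal{L}_{wIoU}\bigr)_k$, where $\mathcal{L}_{edge}$ is `BCEWithLogitsLoss` on the predicted edge map.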
/mindspore/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import random
4 | import numpy as np
5 | from PIL import ImageEnhance
6 |
7 | from mindspore.dataset import transforms, vision, text
8 | from mindspore.dataset import GeneratorDataset
9 | from mindspore import ops
10 |
11 |
12 | # several data augmentation strategies
13 | def cv_random_flip(img, label, edge):
14 | # left right flip
15 | flip_flag = random.randint(0, 1)
16 | if flip_flag == 1:
17 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
18 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
19 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT)
20 | return img, label, edge
21 |
22 |
23 | def randomCrop(image, label, edge):
24 | border = 30
25 | image_width = image.size[0]
26 | image_height = image.size[1]
27 | crop_win_width = np.random.randint(image_width - border, image_width)
28 | crop_win_height = np.random.randint(image_height - border, image_height)
29 | random_region = (
30 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
31 | (image_height + crop_win_height) >> 1)
32 | return image.crop(random_region), label.crop(random_region), edge.crop(random_region)
33 |
34 |
35 | def randomRotation(image, label, edge):
36 | mode = Image.BICUBIC
37 | if random.random() > 0.8:
38 | random_angle = np.random.randint(-15, 15)
39 | image = image.rotate(random_angle, mode)
40 | label = label.rotate(random_angle, mode)
41 | edge = edge.rotate(random_angle, mode)
42 | return image, label, edge
43 |
44 |
45 | def colorEnhance(image):
46 | bright_intensity = random.randint(5, 15) / 10.0
47 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
48 | contrast_intensity = random.randint(5, 15) / 10.0
49 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
50 | color_intensity = random.randint(0, 20) / 10.0
51 | image = ImageEnhance.Color(image).enhance(color_intensity)
52 | sharp_intensity = random.randint(0, 30) / 10.0
53 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
54 | return image
55 |
56 |
57 | def randomGaussian(image, mean=0.1, sigma=0.35):
58 | def gaussianNoisy(im, mean=mean, sigma=sigma):
59 | for _i in range(len(im)):
60 | im[_i] += random.gauss(mean, sigma)
61 | return im
62 |
63 | img = np.asarray(image)
64 | height, width = img.shape
65 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
66 | img = img.reshape([height, width])
67 | return Image.fromarray(np.uint8(img))
68 |
69 |
70 | def randomPeper(img):
71 | img = np.array(img)
72 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
73 | for i in range(noiseNum):
74 |
75 | randX = random.randint(0, img.shape[0] - 1)
76 |
77 | randY = random.randint(0, img.shape[1] - 1)
78 |
79 | if random.randint(0, 1) == 0:
80 |
81 | img[randX, randY] = 0
82 |
83 | else:
84 |
85 | img[randX, randY] = 255
86 | return Image.fromarray(img)
87 |
88 |
89 | def randomPeper_eg(img, edge):
90 | img = np.array(img)
91 | edge = np.array(edge)
92 |
93 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
94 | for i in range(noiseNum):
95 |
96 | randX = random.randint(0, img.shape[0] - 1)
97 |
98 | randY = random.randint(0, img.shape[1] - 1)
99 |
100 | if random.randint(0, 1) == 0:
101 |
102 | img[randX, randY] = 0
103 | edge[randX, randY] = 0
104 |
105 | else:
106 |
107 | img[randX, randY] = 255
108 | edge[randX, randY] = 255
109 |
110 | return Image.fromarray(img), Image.fromarray(edge)
111 |
112 |
113 | # dataset for training
114 | class PolypObjDataset:
115 | def __init__(self, image_root, gt_root, edge_root, trainsize):
116 | self.trainsize = trainsize
117 | # get filenames
118 |
119 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
120 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
121 | self.egs = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')]
122 |
123 | # self.grads = [grad_root + f for f in os.listdir(grad_root) if f.endswith('.jpg')
124 | # or f.endswith('.png')]
125 | # self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
126 | # or f.endswith('.png')]
127 | # sorted files
128 | self.images = sorted(self.images)
129 | self.gts = sorted(self.gts)
130 | self.egs = sorted(self.egs)
131 |
132 | # self.grads = sorted(self.grads)
133 | # self.depths = sorted(self.depths)
134 | # filter out image/ground-truth pairs whose sizes do not match
135 | self.filter_files()
136 |
137 | # get size of dataset
138 | self.size = len(self.images)
139 |
140 | def __getitem__(self, index):
141 | # read imgs/gts/grads/depths
142 | image = self.rgb_loader(self.images[index])
143 | gt = self.binary_loader(self.gts[index])
144 | eg = self.binary_loader(self.egs[index])
145 |
146 | # data augmentation
147 | image, gt, eg = cv_random_flip(image, gt, eg)
148 | image, gt, eg = randomCrop(image, gt, eg)
149 | image, gt, eg = randomRotation(image, gt, eg)
150 |
151 | image = colorEnhance(image)
152 | gt, eg = randomPeper_eg(gt, eg)
153 | return image, gt, eg
154 |
155 | def filter_files(self):
156 | assert len(self.images) == len(self.gts) == len(self.egs)
157 | images = []
158 | gts = []
159 | for img_path, gt_path in zip(self.images, self.gts):
160 | img = Image.open(img_path)
161 | gt = Image.open(gt_path)
162 | if img.size == gt.size:
163 | images.append(img_path)
164 | gts.append(gt_path)
165 | self.images = images
166 | self.gts = gts
167 |
168 | def rgb_loader(self, path):
169 | with open(path, 'rb') as f:
170 | img = Image.open(f)
171 | return img.convert('RGB')
172 |
173 | def binary_loader(self, path):
174 | with open(path, 'rb') as f:
175 | img = Image.open(f)
176 | return img.convert('L')
177 |
178 | def __len__(self):
179 | return self.size
180 |
181 |
182 | # dataloader for training
183 | def get_loader(image_root, gt_root, eg_root, batchsize, trainsize,
184 | shuffle=True, num_workers=12, pin_memory=True):
185 | # transforms
186 | img_transform = transforms.Compose([
187 | vision.Resize((trainsize, trainsize)),
188 | vision.ToTensor(),
189 | vision.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], is_hwc=False)])
190 | gt_transform = transforms.Compose([
191 | vision.Resize((trainsize, trainsize)),
192 | vision.ToTensor()])
193 |
194 | eg_transform = transforms.Compose([
195 | vision.Resize((trainsize, trainsize)),
196 | vision.ToTensor()])
197 |
198 | dataset = PolypObjDataset(image_root, gt_root, eg_root, trainsize)
199 | data_loader = GeneratorDataset(source=dataset, column_names=["images", "gts", "egs"], shuffle=shuffle)
200 | data_loader = data_loader.map(img_transform, ["images"])
201 | data_loader = data_loader.map(gt_transform, ["gts"])
202 | data_loader = data_loader.map(eg_transform, ["egs"])
203 |
204 | data_loader = data_loader.batch(batch_size=batchsize, num_parallel_workers=num_workers)
205 | return data_loader
206 |
207 |
208 | def get_loader_test(image_root, gt_root, testsize):
209 | testdata = test_dataset(image_root, gt_root, testsize)
210 | test_loader = GeneratorDataset(source=testdata, column_names=["image", "gt", "name"])
211 | transform = transforms.Compose([
212 | vision.Resize((testsize, testsize)),
213 | vision.ToTensor(),
214 | vision.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], is_hwc=False)])
215 | test_loader = test_loader.map(transform, "image")
216 | return test_loader
217 |
218 |
219 | class test_dataset:
220 | """load test dataset (batchsize=1)"""
221 |
222 | def __init__(self, image_root, gt_root, testsize):
223 | self.testsize = testsize
224 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
225 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
226 | self.images = sorted(self.images)
227 | self.gts = sorted(self.gts)
228 | # self.gt_transform = transforms.ToTensor()
229 | self.size = len(self.images)
230 |
231 | def __getitem__(self, index):
232 | image = self.rgb_loader(self.images[index])
233 | # image = self.transform(image).unsqueeze(0)
234 | gt = self.binary_loader(self.gts[index])
235 | name = self.images[index].split('/')[-1]
236 | if name.endswith('.jpg'):
237 | name = name.split('.jpg')[0] + '.png'
238 | return image, gt, name
239 |
240 | def rgb_loader(self, path):
241 | with open(path, 'rb') as f:
242 | img = Image.open(f)
243 | return img.convert('RGB')
244 |
245 | def binary_loader(self, path):
246 | with open(path, 'rb') as f:
247 | img = Image.open(f)
248 | return img.convert('L')
249 |
250 | def __len__(self):
251 | return self.size
252 |
--------------------------------------------------------------------------------
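A minimal sketch of constructing both loaders above (paths follow the `train.py` defaults; batch shapes assume 3-channel images and 1-channel masks/edges at `trainsize=352`):

```python
from utils.dataloader import get_loader, get_loader_test

train_loader = get_loader('./data/TrainDataset/images/', './data/TrainDataset/masks/',
                          './data/TrainDataset/edges/', batchsize=10, trainsize=352)
test_loader = get_loader_test('./data/TestDataset/CVC-300/images/',
                              './data/TestDataset/CVC-300/masks/', testsize=352)

for images, gts, egs in train_loader.create_tuple_iterator():
    print(images.shape, gts.shape, egs.shape)  # (10, 3, 352, 352), (10, 1, 352, 352), (10, 1, 352, 352)
    break
```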
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import os
5 | import argparse
6 |
7 | from lib.model import CFANet
8 | from utils.dataloader import get_loader,test_dataset
9 | from utils.eva_funcs import eval_Smeasure,eval_mae,numpy2tensor
10 |
11 |
12 | import scipy.io as scio
13 | import cv2
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
17 | parser.add_argument('--pth_path', type=str, default='./checkpoint/CFANet.pth')
18 |
19 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
20 |
21 | print('----------- starting -------------')
22 |
23 | data_path = '/test/Polpy/Dataset/TestDataset/{}/'.format(_data_name)
24 | save_path = './Snapshot/seg_maps/{}/'.format(_data_name)
25 |
26 |
27 |
28 | opt = parser.parse_args()
29 | model = CFANet(channel=64).cuda()
30 |
31 |
32 | model.load_state_dict(torch.load(opt.pth_path))
33 | model.cuda()
34 | model.eval()
35 |
36 | os.makedirs(save_path, exist_ok=True)
37 |
38 | image_root = '{}/images/'.format(data_path)
39 | gt_root = '{}/masks/'.format(data_path)
40 |
41 |
42 |
43 | test_loader = test_dataset(image_root, gt_root, opt.testsize)
44 |
45 | for i in range(test_loader.size):
46 |
47 | print('-------------- processing {} --------------'.format(i))
48 |
49 | image, gt, name = test_loader.load_data()
50 |
51 | gt = np.asarray(gt, np.float32)
52 | gt /= (gt.max() + 1e-8)
53 | image = image.cuda()
54 |
55 | _,_,_,res = model(image)
56 |
57 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
58 | res = res.sigmoid().data.cpu().numpy().squeeze()
59 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
60 |
61 | cv2.imwrite(save_path+name, res*255)
--------------------------------------------------------------------------------
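`eval_mae`, `eval_Smeasure`, and `numpy2tensor` are imported above but never called; a hypothetical scoring loop over the saved maps could look like this (file paths are placeholders, and a CUDA device is assumed to match the helpers' defaults):

```python
import cv2
import numpy as np
from utils.eva_funcs import eval_mae, eval_Smeasure, numpy2tensor

pred = cv2.imread('./Snapshot/seg_maps/Kvasir/example.png', cv2.IMREAD_GRAYSCALE)
gt = cv2.imread('../data/TestDataset/Kvasir/masks/example.png', cv2.IMREAD_GRAYSCALE)

# normalize both maps to [0, 1] and move them onto the GPU
pred_t = numpy2tensor((pred / 255.0).astype(np.float32))
gt_t = numpy2tensor((gt / 255.0).astype(np.float32))

print('MAE:', eval_mae(pred_t, gt_t), 'S-measure:', eval_Smeasure(pred_t, gt_t))
```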
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import argparse
5 |
6 | from lib.model import CFANet
7 | from utils.dataloader import get_loader,test_dataset
8 | from utils.trainer import adjust_lr
9 | from datetime import datetime
10 |
11 | import torch.nn.functional as F
12 | import numpy as np
13 | import logging
14 |
15 | best_mae = 1
16 | best_epoch = 0
17 |
18 |
19 | def structure_loss(pred, mask):
20 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)
21 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
22 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
23 |
24 | pred = torch.sigmoid(pred)
25 | inter = ((pred*mask)*weit).sum(dim=(2,3))
26 | union = ((pred+mask)*weit).sum(dim=(2,3))
27 | wiou = 1-(inter+1)/(union-inter+1)
28 | return (wbce+wiou).mean()
29 |
30 |
31 | def train(train_loader, model, optimizer, epoch, opt, loss_func, total_step):
32 | """
33 | Training iteration
34 | :param train_loader:
35 | :param model:
36 | :param optimizer:
37 | :param epoch:
38 | :param opt:
39 | :param loss_func:
40 | :param total_step:
41 | :return:
42 | """
43 | model.train()
44 |
45 |
46 | size_rates = [0.75, 1, 1.25]
47 |
48 |
49 | for step, data_pack in enumerate(train_loader):
50 |
51 | images, gts, egs = data_pack
52 |
53 | for rate in size_rates:
54 |
55 | optimizer.zero_grad()
56 |
57 | images = images.cuda()
58 | gts = gts.cuda()
59 | egs = egs.cuda()
60 |
61 |
62 | # ---- rescale ----
63 | trainsize = int(round(opt.trainsize*rate/32)*32)
64 | if rate != 1:
65 | images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
66 | gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
67 | egs = F.interpolate(egs, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
68 |
69 |
70 | cam_edge, sal_out1, sal_out2, sal_out3 = model(images)
71 | loss_edge = loss_func(cam_edge, egs)
72 | loss_sal1 = structure_loss(sal_out1, gts)
73 | loss_sal2 = structure_loss(sal_out2, gts)
74 | loss_sal3 = structure_loss(sal_out3, gts)
75 |
76 | loss_total = loss_edge + loss_sal1 + loss_sal2 + loss_sal3
77 |
78 | loss_total.backward()
79 | optimizer.step()
80 |
81 | if step % 10 == 0 or step == total_step:
82 | print('[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}]'.
83 | format(datetime.now(), epoch, opt.epoch, step, total_step, loss_edge.data, loss_sal1.data, loss_sal2.data, loss_sal3.data, loss_total.data))
84 |
85 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss_edge: {:.4f} Loss_sal1: {:0.4f} Loss_sal2: {:0.4f} Loss_sal3: {:0.4f} Loss_total: {:0.4f}'.
86 | format( epoch, opt.epoch, step, total_step, loss_edge.data, loss_sal1.data, loss_sal2.data, loss_sal3.data, loss_total.data))
87 |
88 |
89 | if (epoch) % opt.save_epoch == 0:
90 | torch.save(model.state_dict(), save_path + 'CODNet_%d.pth' % (epoch))
91 |
92 |
93 | def test(test_loader,model,epoch,save_path):
94 |
95 | global best_mae,best_epoch
96 | model.eval()
97 |
98 | with torch.no_grad():
99 | mae_sum=0
100 | for i in range(test_loader.size):
101 | image, gt, name = test_loader.load_data()
102 | gt = np.asarray(gt, np.float32)
103 |
104 | gt /= (gt.max() + 1e-8)
105 |
106 | image = image.cuda()
107 |
108 | _,_,_,res = model(image)
109 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
110 | res = res.sigmoid().data.cpu().numpy().squeeze()
111 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
112 | mae_sum +=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
113 |
114 | mae = mae_sum / test_loader.size
115 |
116 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))
117 | if epoch == 1:
118 | best_mae, best_epoch = mae, epoch
119 | else:
120 | if mae < best_mae:
121 | best_mae = mae
122 | best_epoch = epoch
123 |
124 | torch.save(model.state_dict(), save_path+'/Cod_best.pth')
125 | print('best epoch:{}'.format(epoch))
126 |
127 |
128 |
129 |
130 | if __name__ == "__main__":
131 | parser = argparse.ArgumentParser()
132 | parser.add_argument('--epoch', type=int, default=200, help='number of training epochs')
133 | parser.add_argument('--lr', type=float, default=1e-4, help='init learning rate, try `lr=1e-4`')
134 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size (Note: ~500MB per img in GPU)')
135 | parser.add_argument('--trainsize', type=int, default=352, help='the size of training image, try small resolutions for speed (like 256)')
136 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
137 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate per decay step')
138 | parser.add_argument('--decay_epoch', type=int, default=30, help='every N epochs decay lr')
139 | parser.add_argument('--gpu', type=int, default=0, help='choose which gpu you use')
140 | parser.add_argument('--save_epoch', type=int, default=5, help='every N epochs save your trained snapshot')
141 | parser.add_argument('--save_model', type=str, default='./Snapshot/CFANet/')
142 |
143 |
144 | parser.add_argument('--train_img_dir', type=str, default='./data/TrainDataset/images/')
145 | parser.add_argument('--train_gt_dir', type=str, default='./data/TrainDataset/masks/')
146 | parser.add_argument('--train_eg_dir', type=str, default='./data/TrainDataset/edges/')
147 |
148 | parser.add_argument('--test_img_dir', type=str, default='./data/TestDataset/CVC-300/images/')
149 | parser.add_argument('--test_gt_dir', type=str, default='./data/TestDataset/CVC-300/masks/')
150 | parser.add_argument('--test_eg_dir', type=str, default='./data/TestDataset/CVC-300/edges/')
151 |
152 |
153 | opt = parser.parse_args()
154 |
155 | torch.cuda.set_device(opt.gpu)
156 |
157 | save_path = opt.save_model
158 | os.makedirs(save_path, exist_ok=True)
159 |
160 |
161 | logging.basicConfig(filename=opt.save_model+'/log.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d %I:%M:%S %p')
162 | logging.info("COD-Train")
163 | logging.info("Config")
164 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.decay_rate,opt.save_model,opt.decay_epoch))
165 |
166 |
167 |
168 | # TIP: you can also use a wider network (e.g., channel=64) for better performance
169 | model = CFANet(channel=64).cuda()
170 | #print('-' * 30, model, '-' * 30)
171 |
172 | total = sum([param.nelement() for param in model.parameters()])
173 | print('Number of parameters: %.2fM' % (total/1e6))
174 |
175 |
176 |
177 | optimizer = torch.optim.Adam(model.parameters(), opt.lr)
178 | LogitsBCE = torch.nn.BCEWithLogitsLoss()
179 |
180 | #net, optimizer = amp.initialize(model_SINet, optimizer, opt_level='O1') # NOTES: Ox not 0x
181 |
182 | train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, opt.train_eg_dir, batchsize=opt.batchsize,trainsize=opt.trainsize, num_workers=12)
183 | test_loader = test_dataset(opt.test_img_dir, opt.test_gt_dir, testsize=opt.trainsize)
184 |
185 | total_step = len(train_loader)
186 |
187 | print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n"
188 | "Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr,
189 | opt.batchsize, opt.save_model, total_step), '-' * 30)
190 |
191 | for epoch_iter in range(1, opt.epoch + 1):
192 |
193 | adjust_lr(optimizer, epoch_iter, opt.decay_rate, opt.decay_epoch)
194 |
195 | train(train_loader, model, optimizer, epoch_iter,opt, LogitsBCE, total_step)
196 | #test(test_loader, model, epoch_iter, opt.save_model)
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
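`utils/trainer.py` supplies the `adjust_lr` imported above but its body is not part of this dump; a hypothetical implementation consistent with the step decay in the MindSpore `Trainer.adjust_lr` (lr = init_lr · decay_rate^(epoch // decay_epoch)) and with the call signature used here would be:

```python
def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30):
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        # cache the initial lr on first use so the decay is not compounded across epochs
        param_group.setdefault('initial_lr', param_group['lr'])
        param_group['lr'] = param_group['initial_lr'] * decay
```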
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taozh2017/CFANet/7b4d91fd77b2d8857036f09172bc319d3dcefd5a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label,edge):
12 | # left right flip
13 | flip_flag = random.randint(0, 1)
14 | if flip_flag == 1:
15 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
16 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
17 | edge = edge.transpose(Image.FLIP_LEFT_RIGHT)
18 | return img, label,edge
19 |
20 |
21 | def randomCrop(image, label,edge):
22 | border = 30
23 | image_width = image.size[0]
24 | image_height = image.size[1]
25 | crop_win_width = np.random.randint(image_width - border, image_width)
26 | crop_win_height = np.random.randint(image_height - border, image_height)
27 | random_region = (
28 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
29 | (image_height + crop_win_height) >> 1)
30 | return image.crop(random_region), label.crop(random_region), edge.crop(random_region)
31 |
32 |
33 | def randomRotation(image, label, edge):
34 | mode = Image.BICUBIC
35 | if random.random() > 0.8:
36 | random_angle = np.random.randint(-15, 15)
37 | image = image.rotate(random_angle, mode)
38 | label = label.rotate(random_angle, mode)
39 | edge = edge.rotate(random_angle, mode)
40 | return image, label, edge
41 |
42 |
43 | def colorEnhance(image):
44 | bright_intensity = random.randint(5, 15) / 10.0
45 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
46 | contrast_intensity = random.randint(5, 15) / 10.0
47 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
48 | color_intensity = random.randint(0, 20) / 10.0
49 | image = ImageEnhance.Color(image).enhance(color_intensity)
50 | sharp_intensity = random.randint(0, 30) / 10.0
51 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
52 | return image
53 |
54 |
55 | def randomGaussian(image, mean=0.1, sigma=0.35):
56 | def gaussianNoisy(im, mean=mean, sigma=sigma):
57 | for _i in range(len(im)):
58 | im[_i] += random.gauss(mean, sigma)
59 | return im
60 |
61 | img = np.asarray(image)
62 | height, width = img.shape
63 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
64 | img = img.reshape([height, width])
65 | return Image.fromarray(np.uint8(img))
66 |
67 |
68 | def randomPeper(img):
69 | img = np.array(img)
70 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
71 | for i in range(noiseNum):
72 |
73 | randX = random.randint(0, img.shape[0] - 1)
74 |
75 | randY = random.randint(0, img.shape[1] - 1)
76 |
77 | if random.randint(0, 1) == 0:
78 |
79 | img[randX, randY] = 0
80 |
81 | else:
82 |
83 | img[randX, randY] = 255
84 | return Image.fromarray(img)
85 |
86 |
87 | def randomPeper_eg(img, edge):
88 |
89 | img = np.array(img)
90 | edge = np.array(edge)
91 |
92 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
93 | for i in range(noiseNum):
94 |
95 | randX = random.randint(0, img.shape[0] - 1)
96 |
97 | randY = random.randint(0, img.shape[1] - 1)
98 |
99 | if random.randint(0, 1) == 0:
100 |
101 | img[randX, randY] = 0
102 | edge[randX, randY] = 0
103 |
104 | else:
105 |
106 | img[randX, randY] = 255
107 | edge[randX, randY] = 255
108 |
109 | return Image.fromarray(img), Image.fromarray(edge)
110 |
111 |
112 |
113 | # dataset for training
114 | class PolypObjDataset(data.Dataset):
115 | def __init__(self, image_root, gt_root, edge_root, trainsize):
116 | self.trainsize = trainsize
117 | # get filenames
118 |
119 |
120 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
121 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
122 | self.egs = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg') or f.endswith('.png')]
123 |
124 |
125 |
126 | # self.grads = [grad_root + f for f in os.listdir(grad_root) if f.endswith('.jpg')
127 | # or f.endswith('.png')]
128 | # self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
129 | # or f.endswith('.png')]
130 | # sorted files
131 | self.images = sorted(self.images)
132 | self.gts = sorted(self.gts)
133 | self.egs = sorted(self.egs)
134 |
135 |
136 | # self.grads = sorted(self.grads)
137 | # self.depths = sorted(self.depths)
138 | # filter out image/ground-truth pairs whose sizes do not match
139 | self.filter_files()
140 | # transforms
141 | self.img_transform = transforms.Compose([
142 | transforms.Resize((self.trainsize, self.trainsize)),
143 | transforms.ToTensor(),
144 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
145 | self.gt_transform = transforms.Compose([
146 | transforms.Resize((self.trainsize, self.trainsize)),
147 | transforms.ToTensor()])
148 |
149 | self.eg_transform = transforms.Compose([
150 | transforms.Resize((self.trainsize, self.trainsize)),
151 | transforms.ToTensor()])
152 |
153 |
154 | # get size of dataset
155 | self.size = len(self.images)
156 |
157 | def __getitem__(self, index):
158 | # read imgs/gts/grads/depths
159 | image = self.rgb_loader(self.images[index])
160 | gt = self.binary_loader(self.gts[index])
161 | eg = self.binary_loader(self.egs[index])
162 |
163 | # data augmentation
164 | image, gt, eg = cv_random_flip(image, gt, eg)
165 | image, gt, eg = randomCrop(image, gt, eg)
166 | image, gt, eg = randomRotation(image, gt, eg)
167 |
168 | image = colorEnhance(image)
169 | gt,eg = randomPeper_eg(gt,eg)
170 |
171 | image = self.img_transform(image)
172 | gt = self.gt_transform(gt)
173 | eg = self.eg_transform(eg)
174 |
175 | return image, gt, eg
176 |
177 | def filter_files(self):
178 | assert len(self.images) == len(self.gts) == len(self.egs)
179 | images = []
180 | gts = []
181 | for img_path, gt_path in zip(self.images, self.gts):
182 | img = Image.open(img_path)
183 | gt = Image.open(gt_path)
184 | if img.size == gt.size:
185 | images.append(img_path)
186 | gts.append(gt_path)
187 | self.images = images
188 | self.gts = gts
189 |
190 | def rgb_loader(self, path):
191 | with open(path, 'rb') as f:
192 | img = Image.open(f)
193 | return img.convert('RGB')
194 |
195 | def binary_loader(self, path):
196 | with open(path, 'rb') as f:
197 | img = Image.open(f)
198 | return img.convert('L')
199 |
200 | def __len__(self):
201 | return self.size
202 |
203 |
204 | # dataloader for training
205 | def get_loader(image_root, gt_root, eg_root, batchsize, trainsize,
206 | shuffle=True, num_workers=12, pin_memory=True):
207 | dataset = PolypObjDataset(image_root, gt_root, eg_root,trainsize)
208 | data_loader = data.DataLoader(dataset=dataset,
209 | batch_size=batchsize,
210 | shuffle=shuffle,
211 | num_workers=num_workers,
212 | pin_memory=pin_memory)
213 | return data_loader
214 |
215 |
216 | # test dataset and loader
217 | class test_dataset_ori:
218 | def __init__(self, image_root, gt_root, testsize):
219 | self.testsize = testsize
220 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
221 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
222 | self.images = sorted(self.images)
223 | self.gts = sorted(self.gts)
224 | self.transform = transforms.Compose([
225 | transforms.Resize((self.testsize, self.testsize)),
226 | transforms.ToTensor(),
227 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
228 | self.gt_transform = transforms.ToTensor()
229 | self.size = len(self.images)
230 | self.index = 0
231 |
232 | def load_data(self):
233 | image = self.rgb_loader(self.images[self.index])
234 | image = self.transform(image).unsqueeze(0)
235 |
236 | gt = self.binary_loader(self.gts[self.index])
237 |
238 | name = self.images[self.index].split('/')[-1]
239 |
240 | image_for_post = self.rgb_loader(self.images[self.index])
241 | image_for_post = image_for_post.resize(gt.size)
242 |
243 | if name.endswith('.jpg'):
244 | name = name.split('.jpg')[0] + '.png'
245 |
246 | self.index += 1
247 | self.index = self.index % self.size
248 |
249 | return image, gt, name, np.array(image_for_post)
250 |
251 | def rgb_loader(self, path):
252 | with open(path, 'rb') as f:
253 | img = Image.open(f)
254 | return img.convert('RGB')
255 |
256 | def binary_loader(self, path):
257 | with open(path, 'rb') as f:
258 | img = Image.open(f)
259 | return img.convert('L')
260 |
261 | def __len__(self):
262 | return self.size
263 |
264 |
265 | class test_dataset:
266 | """load test dataset (batchsize=1)"""
267 | def __init__(self, image_root, gt_root, testsize):
268 | self.testsize = testsize
269 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
270 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
271 | self.images = sorted(self.images)
272 | self.gts = sorted(self.gts)
273 | self.transform = transforms.Compose([
274 | transforms.Resize((self.testsize, self.testsize)),
275 | transforms.ToTensor(),
276 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
277 | self.gt_transform = transforms.ToTensor()
278 | self.size = len(self.images)
279 | self.index = 0
280 |
281 | def load_data(self):
282 | image = self.rgb_loader(self.images[self.index])
283 | image = self.transform(image).unsqueeze(0)
284 | gt = self.binary_loader(self.gts[self.index])
285 | name = self.images[self.index].split('/')[-1]
286 | if name.endswith('.jpg'):
287 | name = name.split('.jpg')[0] + '.png'
288 | self.index += 1
289 | return image, gt, name
290 |
291 | def rgb_loader(self, path):
292 | with open(path, 'rb') as f:
293 | img = Image.open(f)
294 | return img.convert('RGB')
295 |
296 | def binary_loader(self, path):
297 | with open(path, 'rb') as f:
298 | img = Image.open(f)
299 | return img.convert('L')
300 |
301 |
302 |
--------------------------------------------------------------------------------
/utils/eva_funcs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Sep 29 17:21:18 2020
5 |
6 | @author: taozhou
7 | """
8 |
9 | import os
10 | import time
11 |
12 | import numpy as np
13 | import torch
14 | from torchvision import transforms
15 |
16 |
17 | ###############################################################################
18 | ## basic funcs
19 | ###############################################################################
20 |
21 | def numpy2tensor(numpy):
22 | """
23 | convert numpy_array in cpu to tensor in gpu
24 | :param numpy:
25 | :return: torch.from_numpy(numpy).cuda()
26 | """
27 | return torch.from_numpy(numpy).cuda()
28 |
29 | def fun_eval_e(y_pred, y, num, cuda=True):
30 |
31 | if cuda:
32 | score = torch.zeros(num).cuda()
33 | else:
34 | score = torch.zeros(num)
35 |
36 | for i in range(num):
37 |
38 | fm = y_pred - y_pred.mean()
39 | gt = y - y.mean()
40 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
41 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
42 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
43 | return score.max()
44 |
45 |
46 | def fun_eval_pr(y_pred, y, num, cuda=True):
47 |
48 | if cuda:
49 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
50 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
51 | else:
52 | prec, recall = torch.zeros(num), torch.zeros(num)
53 | thlist = torch.linspace(0, 1 - 1e-10, num)
54 |
55 | for i in range(num):
56 | y_temp = (y_pred >= thlist[i]).float()
57 | tp = (y_temp * y).sum()
58 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
59 | return prec, recall
60 |
61 |
62 | def fun_S_object(pred, gt):
63 |
64 | fg = torch.where(gt==0, torch.zeros_like(pred), pred)
65 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred)
66 | o_fg = fun_object(fg, gt)
67 | o_bg = fun_object(bg, 1-gt)
68 | u = gt.mean()
69 | Q = u * o_fg + (1-u) * o_bg
70 | return Q
71 |
72 |
73 | def fun_object(pred, gt):
74 |
75 | temp = pred[gt == 1]
76 | x = temp.mean()
77 | sigma_x = temp.std()
78 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
79 |
80 | return score
81 |
82 |
83 | def fun_S_region(pred, gt):
84 |
85 | X, Y = fun_centroid(gt)
86 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = fun_divideGT(gt, X, Y)
87 | p1, p2, p3, p4 = fun_dividePrediction(pred, X, Y)
88 | Q1 = fun_ssim(p1, gt1)
89 | Q2 = fun_ssim(p2, gt2)
90 | Q3 = fun_ssim(p3, gt3)
91 | Q4 = fun_ssim(p4, gt4)
92 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4
93 |
94 | return Q
95 |
96 | def fun_centroid(gt, cuda=True):
97 |
98 | rows, cols = gt.size()[-2:]
99 | gt = gt.view(rows, cols)
100 |
101 | if gt.sum() == 0:
102 |
103 | if cuda:
104 | X = torch.eye(1).cuda() * round(cols / 2)
105 | Y = torch.eye(1).cuda() * round(rows / 2)
106 | else:
107 | X = torch.eye(1) * round(cols / 2)
108 | Y = torch.eye(1) * round(rows / 2)
109 |
110 | else:
111 | total = gt.sum()
112 |
113 | if cuda:
114 | i = torch.from_numpy(np.arange(0,cols)).cuda().float()
115 | j = torch.from_numpy(np.arange(0,rows)).cuda().float()
116 | else:
117 | i = torch.from_numpy(np.arange(0,cols)).float()
118 | j = torch.from_numpy(np.arange(0,rows)).float()
119 |
120 | X = torch.round((gt.sum(dim=0)*i).sum() / total)
121 | Y = torch.round((gt.sum(dim=1)*j).sum() / total)
122 |
123 | return X.long(), Y.long()
124 |
125 |
126 | def fun_divideGT(gt, X, Y):
127 |
128 | h, w = gt.size()[-2:]
129 | area = h*w
130 | gt = gt.view(h, w)
131 | LT = gt[:Y, :X]
132 | RT = gt[:Y, X:w]
133 | LB = gt[Y:h, :X]
134 | RB = gt[Y:h, X:w]
135 | X = X.float()
136 | Y = Y.float()
137 | w1 = X * Y / area
138 | w2 = (w - X) * Y / area
139 | w3 = X * (h - Y) / area
140 | w4 = 1 - w1 - w2 - w3
141 |
142 | return LT, RT, LB, RB, w1, w2, w3, w4
143 |
144 | def fun_dividePrediction(pred, X, Y):
145 |
146 | h, w = pred.size()[-2:]
147 | pred = pred.view(h, w)
148 | LT = pred[:Y, :X]
149 | RT = pred[:Y, X:w]
150 | LB = pred[Y:h, :X]
151 | RB = pred[Y:h, X:w]
152 |
153 | return LT, RT, LB, RB
154 |
155 |
156 | def fun_ssim(pred, gt):
157 |
158 | gt = gt.float()
159 | h, w = pred.size()[-2:]
160 | N = h*w
161 | x = pred.mean()
162 | y = gt.mean()
163 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20)
164 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20)
165 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20)
166 |
167 | aplha = 4 * x * y *sigma_xy
168 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2)
169 |
170 | if aplha != 0:
171 | Q = aplha / (beta + 1e-20)
172 | elif aplha == 0 and beta == 0:
173 | Q = 1.0
174 | else:
175 | Q = 0
176 |
177 | return Q
178 |
179 | ###############################################################################
180 | ## metric funcs
181 | ###############################################################################
182 | def eval_mae(pred,gt,cuda=True):
183 |
184 | with torch.no_grad():
185 |
186 | trans = transforms.Compose([transforms.ToTensor()])
187 |
188 | if cuda:
189 | pred = pred.cuda()
190 | gt = gt.cuda()
191 | # else:
192 | # pred = trans(pred)
193 | # gt = trans(gt)
194 |
195 | mae = torch.abs(pred - gt).mean()
196 |
197 | return mae.cpu().detach().numpy()
198 |
199 |
200 | def eval_Smeasure(pred,gt,cuda=True):
201 |
202 | alpha, avg_q, img_num = 0.5, 0.0, 0.0
203 |
204 | with torch.no_grad():
205 |
206 | trans = transforms.Compose([transforms.ToTensor()])
207 |
208 | y = gt.mean()
209 |
210 | ##
211 | if y == 0:
212 | x = pred.mean()
213 | Q = 1.0 - x
214 | elif y == 1:
215 | x = pred.mean()
216 | Q = x
217 | else:
218 | Q = alpha * fun_S_object(pred, gt) + (1-alpha) * fun_S_region(pred, gt)
219 | if Q.item() < 0:
220 | Q = torch.FloatTensor([0.0])
221 |
222 | return Q.item()
223 |
224 |
225 | def eval_fmeasure(pred, gt, cuda=True):
226 | print('eval[FMeasure]')
227 |
228 | beta2 = 0.3
229 | avg_p, avg_r, img_num = 0.0, 0.0, 0.0
230 |
231 | ##
232 | with torch.no_grad():
233 | trans = transforms.Compose([transforms.ToTensor()])
234 | if cuda:
235 | pred = trans(pred).cuda()
236 | gt = trans(gt).cuda()
237 | else:
238 | pred = trans(pred)
239 | gt = trans(gt)
240 |
241 | prec, recall = fun_eval_pr(pred, gt, 255)
242 |
243 | return prec, recall
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 | class Eval_thread():
256 | def __init__(self, loader, method, dataset, output_dir, cuda):
257 | self.loader = loader
258 | self.method = method
259 | self.dataset = dataset
260 | self.cuda = cuda
261 | self.logfile = os.path.join(output_dir, 'result.txt')
262 | def run(self):
263 | start_time = time.time()
264 | mae = self.Eval_mae()
265 | s = self.Eval_Smeasure()
266 |
267 | return mae,s
268 |
269 | #max_f = self.Eval_fmeasure()
270 | #max_e = self.Eval_Emeasure()
271 |
272 | #self.LOG('{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..\n'.format(self.dataset, self.method, mae, max_f, max_e, s))
273 | #return '[cost:{:.4f}s]{} dataset with {} method get {:.4f} mae, {:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure..'.format(time.time()-start_time, self.dataset, self.method, mae, max_f, max_e, s)
274 |
275 | def Eval_mae(self):
276 | avg_mae, img_num = 0.0, 0.0
277 | with torch.no_grad():
278 | trans = transforms.Compose([transforms.ToTensor()])
279 | for pred, gt in self.loader:
280 | if self.cuda:
281 |
282 | pred = trans(pred).cuda()
283 | gt = trans(gt).cuda()
284 | else:
285 | pred = trans(pred)
286 | gt = trans(gt)
287 | mea = torch.abs(pred - gt).mean()
288 | if mea == mea: # for Nan
289 | avg_mae += mea
290 | img_num += 1.0
291 | avg_mae /= img_num
292 |
293 | return avg_mae.item()
294 |
295 | def Eval_fmeasure(self):
296 | print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
297 | beta2 = 0.3
298 | avg_p, avg_r, img_num = 0.0, 0.0, 0.0
299 | with torch.no_grad():
300 | trans = transforms.Compose([transforms.ToTensor()])
301 | for pred, gt in self.loader:
302 | if self.cuda:
303 | pred = trans(pred).cuda()
304 | gt = trans(gt).cuda()
305 | else:
306 | pred = trans(pred)
307 | gt = trans(gt)
308 | prec, recall = self._eval_pr(pred, gt, 255)
309 | avg_p += prec
310 | avg_r += recall
311 | img_num += 1.0
312 | avg_p /= img_num
313 | avg_r /= img_num
314 | score = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r)
315 | score[score != score] = 0 # for Nan
316 |
317 | return score.max().item()
318 | def Eval_Emeasure(self):
319 | print('eval[EMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
320 | avg_e, img_num = 0.0, 0.0
321 | with torch.no_grad():
322 | trans = transforms.Compose([transforms.ToTensor()])
323 | for pred, gt in self.loader:
324 | if self.cuda:
325 | pred = trans(pred).cuda()
326 | gt = trans(gt).cuda()
327 | else:
328 | pred = trans(pred)
329 | gt = trans(gt)
330 | max_e = self._eval_e(pred, gt, 255)
331 | if max_e == max_e:
332 | avg_e += max_e
333 | img_num += 1.0
334 |
335 | avg_e /= img_num
336 | return avg_e
337 | def Eval_Smeasure(self):
338 | #print('eval[SMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
339 | alpha, avg_q, img_num = 0.5, 0.0, 0.0
340 | with torch.no_grad():
341 | trans = transforms.Compose([transforms.ToTensor()])
342 | for pred, gt in self.loader:
343 | if self.cuda:
344 | pred = trans(pred).cuda()
345 | gt = trans(gt).cuda()
346 | else:
347 | pred = trans(pred)
348 | gt = trans(gt)
349 | y = gt.mean()
350 | if y == 0:
351 | x = pred.mean()
352 | Q = 1.0 - x
353 | elif y == 1:
354 | x = pred.mean()
355 | Q = x
356 | else:
357 | Q = alpha * self._S_object(pred, gt) + (1-alpha) * self._S_region(pred, gt)
358 | if Q.item() < 0:
359 | Q = torch.FloatTensor([0.0])
360 | img_num += 1.0
361 | avg_q += Q.item()
362 | avg_q /= img_num
363 |
364 | return avg_q
365 | def LOG(self, output):
366 | with open(self.logfile, 'a') as f:
367 | f.write(output)
368 |
369 | def _eval_e(self, y_pred, y, num):
370 | if self.cuda:
371 | score = torch.zeros(num).cuda()
372 | else:
373 | score = torch.zeros(num)
374 | for i in range(num):
375 | fm = y_pred - y_pred.mean()
376 | gt = y - y.mean()
377 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
378 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
379 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
380 | return score.max()
381 |
382 | def _eval_pr(self, y_pred, y, num):
383 | if self.cuda:
384 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
385 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
386 | else:
387 | prec, recall = torch.zeros(num), torch.zeros(num)
388 | thlist = torch.linspace(0, 1 - 1e-10, num)
389 | for i in range(num):
390 | y_temp = (y_pred >= thlist[i]).float()
391 | tp = (y_temp * y).sum()
392 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
393 | return prec, recall
394 |
395 | def _S_object(self, pred, gt):
396 | fg = torch.where(gt==0, torch.zeros_like(pred), pred)
397 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred)
398 | o_fg = self._object(fg, gt)
399 | o_bg = self._object(bg, 1-gt)
400 | u = gt.mean()
401 | Q = u * o_fg + (1-u) * o_bg
402 | return Q
403 |
404 | def _object(self, pred, gt):
405 | temp = pred[gt == 1]
406 | x = temp.mean()
407 | sigma_x = temp.std()
408 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
409 |
410 | return score
411 |
412 | def _S_region(self, pred, gt):
413 | X, Y = self._centroid(gt)
414 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
415 | p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
416 | Q1 = self._ssim(p1, gt1)
417 | Q2 = self._ssim(p2, gt2)
418 | Q3 = self._ssim(p3, gt3)
419 | Q4 = self._ssim(p4, gt4)
420 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4
421 | # print(Q)
422 | return Q
423 |
424 | def _centroid(self, gt):
425 | rows, cols = gt.size()[-2:]
426 | gt = gt.view(rows, cols)
427 | if gt.sum() == 0:
428 | if self.cuda:
429 | X = torch.eye(1).cuda() * round(cols / 2)
430 | Y = torch.eye(1).cuda() * round(rows / 2)
431 | else:
432 | X = torch.eye(1) * round(cols / 2)
433 | Y = torch.eye(1) * round(rows / 2)
434 | else:
435 | total = gt.sum()
436 | if self.cuda:
437 | i = torch.from_numpy(np.arange(0,cols)).cuda().float()
438 | j = torch.from_numpy(np.arange(0,rows)).cuda().float()
439 | else:
440 | i = torch.from_numpy(np.arange(0,cols)).float()
441 | j = torch.from_numpy(np.arange(0,rows)).float()
442 | X = torch.round((gt.sum(dim=0)*i).sum() / total)
443 | Y = torch.round((gt.sum(dim=1)*j).sum() / total)
444 | return X.long(), Y.long()
445 |
446 | def _divideGT(self, gt, X, Y):
447 | h, w = gt.size()[-2:]
448 | area = h*w
449 | gt = gt.view(h, w)
450 | LT = gt[:Y, :X]
451 | RT = gt[:Y, X:w]
452 | LB = gt[Y:h, :X]
453 | RB = gt[Y:h, X:w]
454 | X = X.float()
455 | Y = Y.float()
456 | w1 = X * Y / area
457 | w2 = (w - X) * Y / area
458 | w3 = X * (h - Y) / area
459 | w4 = 1 - w1 - w2 - w3
460 | return LT, RT, LB, RB, w1, w2, w3, w4
461 |
462 | def _dividePrediction(self, pred, X, Y):
463 | h, w = pred.size()[-2:]
464 | pred = pred.view(h, w)
465 | LT = pred[:Y, :X]
466 | RT = pred[:Y, X:w]
467 | LB = pred[Y:h, :X]
468 | RB = pred[Y:h, X:w]
469 | return LT, RT, LB, RB
470 |
471 | def _ssim(self, pred, gt):
472 | gt = gt.float()
473 | h, w = pred.size()[-2:]
474 | N = h*w
475 | x = pred.mean()
476 | y = gt.mean()
477 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20)
478 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20)
479 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20)
480 |
481 | aplha = 4 * x * y *sigma_xy
482 | beta = (x*x + y*y) * (sigma_x2 + sigma_y2)
483 |
484 | if aplha != 0:
485 | Q = aplha / (beta + 1e-20)
486 | elif aplha == 0 and beta == 0:
487 | Q = 1.0
488 | else:
489 | Q = 0
490 | return Q
491 |
--------------------------------------------------------------------------------
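In summary, the MAE and S-measure computed by the helpers above are

$$
\mathrm{MAE} = \operatorname{mean}\bigl(|P - G|\bigr),
\qquad
S_\alpha = \alpha\, S_{object} + (1 - \alpha)\, S_{region}, \quad \alpha = 0.5,
$$

with the degenerate cases $S = 1 - \operatorname{mean}(P)$ when $G$ is all background and $S = \operatorname{mean}(P)$ when $G$ is all foreground, exactly as handled in `Eval_Smeasure`.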
/utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | from torchvision import transforms
7 |
8 |
9 | class Eval_thread():
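10 |     """Compute MAE, S-measure, max F-measure and max E-measure for one
11 |     (method, dataset) pair of prediction / ground-truth maps."""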
10 | def __init__(self, loader, method, dataset, output_dir, cuda):
11 | self.loader = loader
12 | self.method = method
13 | self.dataset = dataset
14 | self.cuda = cuda
15 | self.logfile = os.path.join(output_dir, 'result.txt')
16 |     def run(self):
17 |         start_time = time.time()
18 |         mae = self.Eval_mae()
19 |         s = self.Eval_Smeasure()
20 |         max_f = self.Eval_fmeasure()
21 |         max_e = self.Eval_Emeasure()
22 |
23 |         # Optionally log the results:
24 |         # self.LOG('[cost:{:.4f}s] {} dataset with {} method: {:.4f} mae, '
25 |         #          '{:.4f} max-fmeasure, {:.4f} max-Emeasure, {:.4f} S-measure.\n'
26 |         #          .format(time.time() - start_time, self.dataset, self.method,
27 |         #                  mae, max_f, max_e, s))
28 |
29 |         return mae, s, max_f, max_e
30 |
31 |
32 | def Eval_mae(self):
33 | #print('eval[MAE]:{} dataset with {} method.'.format(self.dataset, self.method))
34 | avg_mae, img_num = 0.0, 0.0
35 | with torch.no_grad():
36 | trans = transforms.Compose([transforms.ToTensor()])
37 | for pred, gt in self.loader:
38 | if self.cuda:
39 | pred = trans(pred).cuda()
40 | gt = trans(gt).cuda()
41 | else:
42 | pred = trans(pred)
43 | gt = trans(gt)
44 | mea = torch.abs(pred - gt).mean()
45 | if mea == mea: # for Nan
46 | avg_mae += mea
47 | img_num += 1.0
48 | avg_mae /= img_num
49 |
50 | return avg_mae.item()
51 |
52 | def Eval_fmeasure(self):
53 | print('eval[FMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
54 | beta2 = 0.3
55 | avg_p, avg_r, img_num = 0.0, 0.0, 0.0
56 | with torch.no_grad():
57 | trans = transforms.Compose([transforms.ToTensor()])
58 | for pred, gt in self.loader:
59 | if self.cuda:
60 | pred = trans(pred).cuda()
61 | gt = trans(gt).cuda()
62 | else:
63 | pred = trans(pred)
64 | gt = trans(gt)
65 | prec, recall = self._eval_pr(pred, gt, 255)
66 | avg_p += prec
67 | avg_r += recall
68 | img_num += 1.0
69 | avg_p /= img_num
70 | avg_r /= img_num
71 | score = (1 + beta2) * avg_p * avg_r / (beta2 * avg_p + avg_r)
72 | score[score != score] = 0 # for Nan
73 |
74 | return score.max().item()
75 | def Eval_Emeasure(self):
76 | print('eval[EMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
77 | avg_e, img_num = 0.0, 0.0
78 | with torch.no_grad():
79 | trans = transforms.Compose([transforms.ToTensor()])
80 | for pred, gt in self.loader:
81 | if self.cuda:
82 | pred = trans(pred).cuda()
83 | gt = trans(gt).cuda()
84 | else:
85 | pred = trans(pred)
86 | gt = trans(gt)
87 | max_e = self._eval_e(pred, gt, 255)
88 | if max_e == max_e:
89 | avg_e += max_e
90 | img_num += 1.0
91 |
92 | avg_e /= img_num
93 | return avg_e.item()
94 | def Eval_Smeasure(self):
95 | #print('eval[SMeasure]:{} dataset with {} method.'.format(self.dataset, self.method))
96 | alpha, avg_q, img_num = 0.5, 0.0, 0.0
97 | with torch.no_grad():
98 | trans = transforms.Compose([transforms.ToTensor()])
99 | for pred, gt in self.loader:
100 | if self.cuda:
101 | pred = trans(pred).cuda()
102 | gt = trans(gt).cuda()
103 | else:
104 | pred = trans(pred)
105 | gt = trans(gt)
106 | y = gt.mean()
107 | if y == 0:
108 | x = pred.mean()
109 | Q = 1.0 - x
110 | elif y == 1:
111 | x = pred.mean()
112 | Q = x
113 | else:
114 | Q = alpha * self._S_object(pred, gt) + (1-alpha) * self._S_region(pred, gt)
115 |                     if Q.item() < 0:
116 |                         Q = torch.FloatTensor([0.0])
117 | img_num += 1.0
118 | avg_q += Q.item()
119 | avg_q /= img_num
120 |
121 | return avg_q
122 | def LOG(self, output):
123 | with open(self.logfile, 'a') as f:
124 | f.write(output)
125 |
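126 |     # Enhanced-alignment (E) measure: binarize the prediction at `num`
127 |     # thresholds and keep the best alignment score over all thresholds.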
126 |     def _eval_e(self, y_pred, y, num):
127 |         if self.cuda:
128 |             score = torch.zeros(num).cuda()
129 |             thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
130 |         else:
131 |             score = torch.zeros(num)
132 |             thlist = torch.linspace(0, 1 - 1e-10, num)
133 |         for i in range(num):
134 |             y_pred_th = (y_pred >= thlist[i]).float()  # binarize at the i-th threshold (mirrors _eval_pr)
135 |             fm = y_pred_th - y_pred_th.mean()
136 |             gt = y - y.mean()
137 |             align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
138 |             enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
139 |             score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
140 |         return score.max()
138 |
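139 |     # Precision/recall at `num` evenly spaced thresholds in [0, 1): the
140 |     # prediction is binarized at each threshold before comparison with the GT.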
139 | def _eval_pr(self, y_pred, y, num):
140 | if self.cuda:
141 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
142 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
143 | else:
144 | prec, recall = torch.zeros(num), torch.zeros(num)
145 | thlist = torch.linspace(0, 1 - 1e-10, num)
146 | for i in range(num):
147 | y_temp = (y_pred >= thlist[i]).float()
148 | tp = (y_temp * y).sum()
149 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
150 | return prec, recall
151 |
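152 |     # Object-aware term of the S-measure: foreground and background scores
153 |     # are combined using the GT foreground ratio u as the weight.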
152 | def _S_object(self, pred, gt):
153 | fg = torch.where(gt==0, torch.zeros_like(pred), pred)
154 | bg = torch.where(gt==1, torch.zeros_like(pred), 1-pred)
155 | o_fg = self._object(fg, gt)
156 | o_bg = self._object(bg, 1-gt)
157 | u = gt.mean()
158 | Q = u * o_fg + (1-u) * o_bg
159 | return Q
160 |
161 | def _object(self, pred, gt):
162 | temp = pred[gt == 1]
163 | x = temp.mean()
164 | sigma_x = temp.std()
165 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
166 |
167 | return score
168 |
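169 |     # Region-aware term of the S-measure: divide pred and GT into four
170 |     # quadrants at the GT centroid, then area-weight the per-quadrant SSIM.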
169 | def _S_region(self, pred, gt):
170 | X, Y = self._centroid(gt)
171 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
172 | p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
173 | Q1 = self._ssim(p1, gt1)
174 | Q2 = self._ssim(p2, gt2)
175 | Q3 = self._ssim(p3, gt3)
176 | Q4 = self._ssim(p4, gt4)
177 | Q = w1*Q1 + w2*Q2 + w3*Q3 + w4*Q4
178 | # print(Q)
179 | return Q
180 |
181 | def _centroid(self, gt):
182 | rows, cols = gt.size()[-2:]
183 | gt = gt.view(rows, cols)
184 | if gt.sum() == 0:
185 | if self.cuda:
186 | X = torch.eye(1).cuda() * round(cols / 2)
187 | Y = torch.eye(1).cuda() * round(rows / 2)
188 | else:
189 | X = torch.eye(1) * round(cols / 2)
190 | Y = torch.eye(1) * round(rows / 2)
191 | else:
192 | total = gt.sum()
193 | if self.cuda:
194 | i = torch.from_numpy(np.arange(0,cols)).cuda().float()
195 | j = torch.from_numpy(np.arange(0,rows)).cuda().float()
196 | else:
197 | i = torch.from_numpy(np.arange(0,cols)).float()
198 | j = torch.from_numpy(np.arange(0,rows)).float()
199 | X = torch.round((gt.sum(dim=0)*i).sum() / total)
200 | Y = torch.round((gt.sum(dim=1)*j).sum() / total)
201 | return X.long(), Y.long()
202 |
203 | def _divideGT(self, gt, X, Y):
204 | h, w = gt.size()[-2:]
205 | area = h*w
206 | gt = gt.view(h, w)
207 | LT = gt[:Y, :X]
208 | RT = gt[:Y, X:w]
209 | LB = gt[Y:h, :X]
210 | RB = gt[Y:h, X:w]
211 | X = X.float()
212 | Y = Y.float()
213 | w1 = X * Y / area
214 | w2 = (w - X) * Y / area
215 | w3 = X * (h - Y) / area
216 | w4 = 1 - w1 - w2 - w3
217 | return LT, RT, LB, RB, w1, w2, w3, w4
218 |
219 | def _dividePrediction(self, pred, X, Y):
220 | h, w = pred.size()[-2:]
221 | pred = pred.view(h, w)
222 | LT = pred[:Y, :X]
223 | RT = pred[:Y, X:w]
224 | LB = pred[Y:h, :X]
225 | RB = pred[Y:h, X:w]
226 | return LT, RT, LB, RB
227 |
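228 |     # SSIM-style structural comparison of corresponding pred/GT regions, with
229 |     # (N - 1)-normalized variances and covariance.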
228 | def _ssim(self, pred, gt):
229 | gt = gt.float()
230 | h, w = pred.size()[-2:]
231 | N = h*w
232 | x = pred.mean()
233 | y = gt.mean()
234 | sigma_x2 = ((pred - x)*(pred - x)).sum() / (N - 1 + 1e-20)
235 | sigma_y2 = ((gt - y)*(gt - y)).sum() / (N - 1 + 1e-20)
236 | sigma_xy = ((pred - x)*(gt - y)).sum() / (N - 1 + 1e-20)
237 |
238 |         alpha = 4 * x * y * sigma_xy
239 |         beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
240 |
241 |         if alpha != 0:
242 |             Q = alpha / (beta + 1e-20)
243 |         elif alpha == 0 and beta == 0:
244 |             Q = 1.0
245 |         else:
246 |             Q = 0
247 | return Q
248 |
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from datetime import datetime
4 | import os
5 | #from apex import amp
6 | import torch.nn.functional as F
7 |
8 |
9 | def eval_mae(y_pred, y):
10 | """
11 | evaluate MAE (for test or validation phase)
12 | :param y_pred:
13 | :param y:
14 | :return: Mean Absolute Error
15 | """
16 | return torch.abs(y_pred - y).mean()
17 |
18 |
19 |
20 |
21 | def numpy2tensor(numpy):
22 |     """
23 |     convert a NumPy array on the CPU to a torch tensor on the GPU
24 |     :param numpy: input numpy.ndarray
25 |     :return: torch.from_numpy(numpy).cuda()
26 |     """
27 |     return torch.from_numpy(numpy).cuda()
28 |
29 |
30 | def clip_gradient(optimizer, grad_clip):
31 |     """
32 |     clamp every parameter gradient to [-grad_clip, grad_clip] to stabilize training
33 |     :param optimizer: optimizer whose parameter gradients are clipped
34 |     :param grad_clip: element-wise clipping bound
35 |     :return:
36 |     """
37 | for group in optimizer.param_groups:
38 | for param in group['params']:
39 | if param.grad is not None:
40 | param.grad.data.clamp_(-grad_clip, grad_clip)
41 |
42 |
43 | def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30):
44 |     # scale lr by decay_rate ** (epoch // decay_epoch); the factor multiplies
45 |     # the current lr in-place, so calls on consecutive epochs compound the decay
46 |     decay = decay_rate ** (epoch // decay_epoch)
47 |     for param_group in optimizer.param_groups:
48 |         param_group['lr'] *= decay
47 |
48 |
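49 |
50 | # Minimal usage sketch (an assumption -- the actual loop lives in train.py):
51 | #   for epoch in range(1, num_epochs + 1):
52 | #       optimizer.zero_grad()
53 | #       loss.backward()
54 | #       clip_gradient(optimizer, grad_clip=0.5)
55 | #       optimizer.step()
56 | #       adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30)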
--------------------------------------------------------------------------------