├── graph_cuts_loss.png
├── README.md
└── graph_cuts_loss.py
/graph_cuts_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzhenggit/graph_cuts_loss/HEAD/graph_cuts_loss.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Cuts Loss to Boost Model Accuracy and Generalizability for Medical Image Segmentation
2 | PyTorch implementation of our ICCVW paper '[Graph Cuts Loss to Boost Model Accuracy and Generalizability for Medical Image Segmentation](https://openaccess.thecvf.com/content/ICCV2021W/CVAMD/papers/Zheng_Graph_Cuts_Loss_To_Boost_Model_Accuracy_and_Generalizability_for_ICCVW_2021_paper.pdf)'.
3 |
4 |
5 |
6 |
7 |
8 | ## Acknowledgement
9 | **Implementations of the losses cited in our work are publicly available:** \
10 | clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation [(clDice)](https://github.com/jocpae/clDice) \
11 | An Elastic Interaction-Based Loss Function for Medical Image Segmentation [(EIB)](https://github.com/charrywhite/elastic_interaction_based_loss) \
12 | Learning Active Contour Models for Medical Image Segmentation [(AC)](https://github.com/xuuuuuuchen/Active-Contour-Loss)\
13 | Learning Euler's Elastica Model for Medical Image Segmentation [(ACE)](https://github.com/HiLab-git/ACELoss) \
14 | Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural Networks [(HD)](https://github.com/JunMa11/SegWithDistMap/blob/5a67153bc730eb82de396ef63f57594f558e23cd/code/train_LA_HD.py#L106) \
15 | Boundary loss for highly unbalanced segmentation [(BD)](https://github.com/LIVIAETS/boundary-loss)
16 |
17 | ## Note
18 | Contact: Zhou Zheng (zzheng@mori.m.is.nagoya-u.ac.jp)
--------------------------------------------------------------------------------
/graph_cuts_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn
4 |
5 |
6 | # original 2D GC loss with no approximation
class GC_2D_Original(torch.nn.Module):
    """2D graph cuts (GC) loss in its original form, with no boundary approximation.

    loss = lmda * BCE(input, target) + boundary_term                 (equation (5))

    boundary_term is the mean, over the four undirected directions of the
    8-neighbourhood (vertical, horizontal, two diagonals), of

        sum(w_uv * |y_u - y_v|) / sum(|y_u - y_v|),
        w_uv = exp(-(p_u - p_v)^2 / (2 * sigma^2)) / dist(u, v),

    i.e. the average Gaussian edge weight along the target boundary in that
    direction.

    Args:
        lmda: weight applied to the region (BCE) term.
        sigma: bandwidth of the Gaussian edge weight.
    """

    def __init__(self, lmda, sigma):
        super(GC_2D_Original, self).__init__()
        self.lmda = lmda
        self.sigma = sigma

    @staticmethod
    def _neighbour_diffs(x):
        """Return x_u - x_v for the vertical, horizontal and two diagonal neighbour pairs.

        Neighbourhood layout (rows = H, columns = W):
            x5 x1 x6
            x2 x  x4
            x7 x3 x8
        Only the last two axes are sliced; leading axes are left untouched.
        """
        return (
            x[..., 1:, :] - x[..., :-1, :],      # vertical:   x <-> x1, x3 <-> x
            x[..., :, 1:] - x[..., :, :-1],      # horizontal: x <-> x2, x4 <-> x
            x[..., 1:, 1:] - x[..., :-1, :-1],   # diagonal1:  x <-> x5, x8 <-> x
            x[..., 1:, :-1] - x[..., :-1, 1:],   # diagonal2:  x <-> x7, x6 <-> x
        )

    def forward(self, input, target):
        """Compute the scalar GC loss.

        Args:
            input: B * C * H * W probabilities (after a sigmoid operation).
            target: B * C * H * W labels with the same shape/dtype as ``input``
                (required by BCELoss); assumed to be binary {0, 1} masks -- TODO confirm.

        Returns:
            0-dim tensor ``lmda * region_term + boundary_term``.
        """
        # region term equals BCE
        region_term = torch.nn.BCELoss()(input=input, target=target)

        # dist(u, v): 1 for axis-aligned neighbours (e.g. x <-> x1), sqrt(2) for diagonals (e.g. x <-> x6)
        dists = (1.0, 1.0, 2.0 ** 0.5, 2.0 ** 0.5)
        # Without this, a direction with no target boundary pairs (e.g. an empty
        # mask) yields 0 / 0 = NaN and poisons the gradients; with it, that
        # direction simply contributes 0.
        smooth = 1e-5
        two_sigma_sq = 2 * self.sigma * self.sigma

        boundary_sum = 0
        for pred_diff, label_diff, dist in zip(self._neighbour_diffs(input),
                                               self._neighbour_diffs(target),
                                               dists):
            delta = torch.abs(label_diff)  # delta(y_u, y_v)
            weight = torch.exp(-(pred_diff ** 2) / two_sigma_sq) / dist
            boundary_sum = boundary_sum + torch.sum(weight * delta) / (torch.sum(delta) + smooth)
        boundary_term = boundary_sum / 4  # equation (5)

        return self.lmda * region_term + boundary_term
59 |
60 |
61 | # 2D GC loss with boundary approximation in equation (7) to eliminate sigma
class GC_2D(torch.nn.Module):
    """2D graph cuts (GC) loss using the sigma-free boundary approximation of equation (7).

    loss = lmda * BCE(input, target) + boundary_term, where

        boundary_term = 1 - mean_d( sum(|p_u - p_v| * |y_u - y_v|) / sum(|y_u - y_v|) )

    and d ranges over the four undirected directions of the 8-neighbourhood.
    For inputs in [0, 1] each ratio is at most 1, so boundary_term stays in [0, 1].

    Args:
        lmda: weight applied to the region (BCE) term.
    """

    def __init__(self, lmda):
        super(GC_2D, self).__init__()
        self.lmda = lmda

    @staticmethod
    def _neighbour_diffs(x):
        """Return x_u - x_v for the vertical, horizontal and two diagonal neighbour pairs.

        Neighbourhood layout (rows = H, columns = W):
            x5 x1 x6
            x2 x  x4
            x7 x3 x8
        Only the last two axes are sliced; leading axes are left untouched.
        """
        return (
            x[..., 1:, :] - x[..., :-1, :],      # vertical:   x <-> x1, x3 <-> x
            x[..., :, 1:] - x[..., :, :-1],      # horizontal: x <-> x2, x4 <-> x
            x[..., 1:, 1:] - x[..., :-1, :-1],   # diagonal1:  x <-> x5, x8 <-> x
            x[..., 1:, :-1] - x[..., :-1, 1:],   # diagonal2:  x <-> x7, x6 <-> x
        )

    def forward(self, input, target):
        """Compute the scalar GC loss.

        Args:
            input: B * C * H * W probabilities (after a sigmoid operation).
            target: B * C * H * W labels with the same shape/dtype as ``input``
                (required by BCELoss); assumed to be binary {0, 1} masks -- TODO confirm.

        Returns:
            0-dim tensor ``lmda * region_term + boundary_term``.
        """
        # region term equals BCE
        region_term = torch.nn.BCELoss()(input=input, target=target)

        # Without this, a direction with no target boundary pairs (e.g. an empty
        # mask) yields 0 / 0 = NaN and poisons the gradients; with it, that
        # direction simply contributes 0.
        smooth = 1e-5

        boundary_sum = 0
        for pred_diff, label_diff in zip(self._neighbour_diffs(input),
                                         self._neighbour_diffs(target)):
            delta = torch.abs(label_diff)  # delta(y_u, y_v)
            agreement = torch.abs(pred_diff) * delta  # |p_u - p_v| on target boundary pairs
            boundary_sum = boundary_sum + torch.sum(agreement) / (torch.sum(delta) + smooth)
        boundary_term = 1 - boundary_sum / 4  # equation (7), normalized to [0, 1]

        return self.lmda * region_term + boundary_term
109 |
110 |
111 | # 3D GC loss with boundary approximation in equation (7) to eliminate sigma
class GC_3D_v1(torch.nn.Module):
    """3D graph cuts (GC) loss with the sigma-free boundary approximation (equation (7)).

    loss = lmda * BCE(input, target) + boundary_term, where boundary_term is
    1 minus the mean, over the 13 undirected directions of the 26-neighbourhood,
    of sum(|p_u - p_v| * |y_u - y_v|) / (sum(|y_u - y_v|) + 1e-5).

    Args:
        lmda: weight applied to the region (BCE) term.
    """

    # One (dH, dW, dD) offset per undirected edge direction of the 26-neighbourhood,
    # listed in the order the terms are accumulated.
    _OFFSETS = (
        (1, 0, 0), (0, 1, 0), (0, 0, 1),
        (1, 1, 0), (1, -1, 0),
        (0, 1, 1), (0, 1, -1),
        (1, 0, -1), (1, 0, 1),
        (1, -1, -1), (1, 1, -1), (-1, -1, -1), (-1, 1, -1),
    )
    # offset component -> (slice of the minuend view, slice of the subtrahend view)
    _AXIS_SLICES = {
        1: (slice(1, None), slice(None, -1)),
        0: (slice(None), slice(None)),
        -1: (slice(None, -1), slice(1, None)),
    }

    def __init__(self, lmda):
        super(GC_3D_v1, self).__init__()
        self.lmda = lmda

    def forward(self, input, target):
        """Compute the scalar GC loss.

        Args:
            input: B * C * H * W * D probabilities (after a sigmoid operation).
            target: B * C * H * W * D labels, same shape/dtype as ``input``.
                Only the last three axes are treated as spatial.

        Returns:
            0-dim tensor ``lmda * region_term + boundary_term``.
        """
        criterion = torch.nn.BCELoss()
        region_term = criterion(input=input, target=target)

        smooth = 1e-5  # keeps a direction without target boundary pairs at 0 instead of 0/0
        per_direction = []
        for offset in self._OFFSETS:
            minuend = (Ellipsis,) + tuple(self._AXIS_SLICES[step][0] for step in offset)
            subtrahend = (Ellipsis,) + tuple(self._AXIS_SLICES[step][1] for step in offset)
            pred_gap = torch.abs(input[minuend] - input[subtrahend])     # |p_u - p_v|
            label_gap = torch.abs(target[minuend] - target[subtrahend])  # delta(y_u, y_v)
            per_direction.append(
                torch.sum(pred_gap * label_gap) / (torch.sum(label_gap) + smooth)
            )

        boundary_term = 1 - sum(per_direction) / 13
        return self.lmda * region_term + boundary_term
200 |
201 |
202 | # this 3D version further eliminates the abs operation
class GC_3D_v2(torch.nn.Module):
    """3D graph cuts (GC) loss, abs-free variant of the equation (7) approximation.

    Uses signed neighbour differences: for each of the 13 undirected directions
    of the 26-neighbourhood the ratio is
        sum((p_u - p_v) * (y_u - y_v)) / (sum((y_u - y_v)^2) + 1e-5),
    and boundary_term = 1 - mean of these ratios.

    Args:
        lmda: weight applied to the region (BCE) term.
    """

    # One (dH, dW, dD) offset per undirected edge direction of the 26-neighbourhood,
    # listed in the order the terms are accumulated.
    _OFFSETS = (
        (1, 0, 0), (0, 1, 0), (0, 0, 1),
        (1, 1, 0), (1, -1, 0),
        (0, 1, 1), (0, 1, -1),
        (1, 0, -1), (1, 0, 1),
        (1, -1, -1), (1, 1, -1), (-1, -1, -1), (-1, 1, -1),
    )
    # offset component -> (slice of the minuend view, slice of the subtrahend view)
    _AXIS_SLICES = {
        1: (slice(1, None), slice(None, -1)),
        0: (slice(None), slice(None)),
        -1: (slice(None, -1), slice(1, None)),
    }

    def __init__(self, lmda):
        super(GC_3D_v2, self).__init__()
        self.lmda = lmda

    def forward(self, input, target):
        """Compute the scalar GC loss.

        Args:
            input: B * C * H * W * D probabilities (after a sigmoid operation).
            target: B * C * H * W * D labels, same shape/dtype as ``input``.
                Only the last three axes are treated as spatial.

        Returns:
            0-dim tensor ``lmda * region_term + boundary_term``.
        """
        criterion = torch.nn.BCELoss()
        region_term = criterion(input=input, target=target)

        smooth = 1e-5  # keeps a direction without target boundary pairs at 0 instead of 0/0
        per_direction = []
        for offset in self._OFFSETS:
            minuend = (Ellipsis,) + tuple(self._AXIS_SLICES[step][0] for step in offset)
            subtrahend = (Ellipsis,) + tuple(self._AXIS_SLICES[step][1] for step in offset)
            pred_step = input[minuend] - input[subtrahend]      # signed p_u - p_v
            label_step = target[minuend] - target[subtrahend]   # signed y_u - y_v
            per_direction.append(
                torch.sum(pred_step * label_step) / (torch.sum(label_step * label_step) + smooth)
            )

        boundary_term = 1 - sum(per_direction) / 13
        return self.lmda * region_term + boundary_term
291 |
--------------------------------------------------------------------------------