├── .gitignore
├── ACELoss_pipeline.png
├── LICENSE
├── README.md
├── aceloss.py
├── figure1.png
├── figure2.png
├── table1.png
└── table2.png

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/ACELoss_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/ACELoss/a9677af35e9c7cf50ea2c6a3a68be3bc8c25fb1a/ACELoss_pipeline.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 xdluo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Active Contour Euler Elastica Loss Functions
Official implementation of the paper [Learning Euler's Elastica Model for Medical Image Segmentation](https://arxiv.org/pdf/2011.00526.pdf); a short version was accepted by ISBI 2021.
* Implemented a novel active contour-based loss function that combines a region term, a length term, and an elastica term (mean curvature).
* Reimplemented some popular active contour-based loss functions in different ways, such as a 3D Active Contour Loss based on either Sobel filters or max- and min-pooling.

## Introduction and Some Results
* ### **Pipeline of ACE loss**.
![](https://github.com/Luoxd1996/Active_Contour_Euler_Elastica_Loss/blob/main/ACELoss_pipeline.png)
* ### **2D results and visualization**.
![](https://github.com/Luoxd1996/Active_Contour_Euler_Elastica_Loss/blob/main/table1.png)
![](https://github.com/Luoxd1996/Active_Contour_Euler_Elastica_Loss/blob/main/figure1.png)
* ### **3D results and visualization**.
![](https://github.com/Luoxd1996/Active_Contour_Euler_Elastica_Loss/blob/main/table2.png)
![](https://github.com/Luoxd1996/Active_Contour_Euler_Elastica_Loss/blob/main/figure2.png)

* If you want to use these losses only as constraints (combined with Dice loss or cross-entropy loss), replace **torch.sum()** with **torch.mean()**.

## Requirements
Some important required packages include:
* [PyTorch][torch_link] >= 0.4.1.
* Python >= 3.6.

Follow the official guidance to install [PyTorch][torch_link].

[torch_link]:https://pytorch.org/

## Citation
If you find these active contour-based loss functions useful in your research, please consider citing:

    @article{chen2020aceloss,
      title={Learning Euler's Elastica Model for Medical Image Segmentation},
      author={Chen, Xu and Luo, Xiangde and Zhao, Yitian and Zhang, Shaoting and Wang, Guotai and Zheng, Yalin},
      journal={arXiv preprint arXiv:2011.00526},
      year={2020}
    }

    @inproceedings{chen2019learning,
      title={Learning Active Contour Models for Medical Image Segmentation},
      author={Chen, Xu and Williams, Bryan M and Vallabhaneni, Srinivasa R and Czanner, Gabriela and Williams, Rachel and Zheng, Yalin},
      booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
      pages={11632--11640},
      year={2019}
    }

## Other Active Contour Based Loss Functions
* Active Contour Loss ([ACLoss](https://github.com/xuuuuuuchen/Active-Contour-Loss)).
* Geodesic Active Contour Loss ([GAC](https://ieeexplore.ieee.org/document/9187860)).
* Elastic-Interaction-based Loss ([EILoss](https://github.com/charrywhite/elastic_interaction_based_loss)).

## Acknowledgement
* We thank [Dr. Jun Ma](https://github.com/JunMa11) for instructive discussions on the curvature implementation, and [Mr. Yechong Huang](https://github.com/huohuayuzhong) for his help with implementing the 3D curvature, Sobel, and Laplace operators.

--------------------------------------------------------------------------------
/aceloss.py:
--------------------------------------------------------------------------------
import time

import numpy as np
import torch
import torch.nn as nn


class ACLoss(nn.Module):
    """
    Active Contour Loss
    based on Sobel filter
    """

    def __init__(self, miu=1.0, classes=3):
        super(ACLoss, self).__init__()

        self.miu = miu
        self.classes = classes
        sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        sobel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])

        self.sobel_x = nn.Parameter(torch.from_numpy(sobel_x).float().expand(self.classes, 1, 3, 3),
                                    requires_grad=False)
        self.sobel_y = nn.Parameter(torch.from_numpy(sobel_y).float().expand(self.classes, 1, 3, 3),
                                    requires_grad=False)

        # depthwise (grouped) convolutions: one fixed Sobel kernel per class channel
        self.diff_x = nn.Conv2d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_x.weight = self.sobel_x
        self.diff_y = nn.Conv2d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_y.weight = self.sobel_y

    def forward(self, predication, label):
        grd_x = self.diff_x(predication)
        grd_y = self.diff_y(predication)

        # length: min-max normalise the per-pixel gradient magnitude, then sum
        # (normalising an already-summed scalar would always give 0, as in ACLoss3D)
        length = torch.sqrt(grd_x ** 2 + grd_y ** 2 + 1e-8)
        length = (length - length.min()) / (length.max() - length.min() + 1e-8)
        length = torch.sum(length)

        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = self.miu * region_in + region_out

        return region + length


class ACLossV2(nn.Module):
    """
    Active Contour Loss
    based on max-pooling & min-pooling
    """

    def __init__(self, miu=1.0, classes=3):
        super(ACLossV2, self).__init__()

        self.miu = miu
        self.classes = classes

    def forward(self, predication, label):
        # min-pooling implemented as a negated max-pooling
        min_pool_x = nn.functional.max_pool2d(
            predication * -1, (3, 3), 1, 1) * -1
        # soft contour: max-pool(min-pool(x)) - min-pool(x)
        contour = torch.relu(nn.functional.max_pool2d(
            min_pool_x, (3, 3), 1, 1) - min_pool_x)

        # length
        length = torch.sum(torch.abs(contour))

        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = self.miu * region_in + region_out

        return region + length


def ACELoss(y_pred, y_true, u=1, a=1, b=1):
    """
    Active Contour Loss
    based on total variation and mean curvature
    """
    def first_derivative(input):
        # central differences in the interior, one-sided differences at the borders
        u = input
        m = u.shape[2]
        n = u.shape[3]

        ci_0 = (u[:, :, 1, :] - u[:, :, 0, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :] - u[:, :, 0:m - 2, :]
        ci_2 = (u[:, :, -1, :] - u[:, :, m - 2, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2

        cj_0 = (u[:, :, :, 1] - u[:, :, :, 0]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:] - u[:, :, :, 0:n - 2]
        cj_2 = (u[:, :, :, -1] - u[:, :, :, n - 2]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2

        return ci, cj

    def second_derivative(input, ci, cj):
        u = input
        m = u.shape[2]
        n = u.shape[3]

        cii_0 = (u[:, :, 1, :] + u[:, :, 0, :] -
                 2 * u[:, :, 0, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :] + u[:, :, :-2, :] - 2 * u[:, :, 1:-1, :]
        cii_2 = (u[:, :, -1, :] + u[:, :, -2, :] -
                 2 * u[:, :, -1, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)

        cjj_0 = (u[:, :, :, 1] + u[:, :, :, 0] -
                 2 * u[:, :, :, 0]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:] + u[:, :, :, :-2] - 2 * u[:, :, :, 1:-1]
        cjj_2 = (u[:, :, :, -1] + u[:, :, :, -2] -
                 2 * u[:, :, :, -1]).unsqueeze(3)

        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)

        # mixed derivative: difference of ci along the j axis
        cij_0 = ci[:, :, :, 1:n]
        cij_1 = ci[:, :, :, -1].unsqueeze(3)

        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b

        return cii, cjj, cij

    def region(y_pred, y_true, u=1):
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        region = u * region_in + region_out
        return region

    def elastica(input, a=1, b=1):
        ci, cj = first_derivative(input)
        cii, cjj, cij = second_derivative(input, ci, cj)
        beta = 1e-8
        length = torch.sqrt(beta + ci ** 2 + cj ** 2)
        curvature = (beta + ci ** 2) * cjj + \
            (beta + cj ** 2) * cii - 2 * ci * cj * cij
        curvature = torch.abs(curvature) / ((ci ** 2 + cj ** 2) ** 1.5 + beta)
        elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        return elastica

    loss = region(y_pred, y_true, u=u) + elastica(y_pred, a=a, b=b)
    return loss


class ACLoss3D(nn.Module):
    """
    Active Contour Loss
    based on Sobel filter
    """

    def __init__(self, classes=4, alpha=1):
        super(ACLoss3D, self).__init__()
        self.alpha = alpha
        self.classes = classes
        sobel = np.array([[[1., 2., 1.],
                           [2., 4., 2.],
                           [1., 2., 1.]],

                          [[0., 0., 0.],
                           [0., 0., 0.],
                           [0., 0., 0.]],

                          [[-1., -2., -1.],
                           [-2., -4., -2.],
                           [-1., -2., -1.]]])

        self.sobel_x = nn.Parameter(
            torch.from_numpy(sobel.transpose(0, 1, 2)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)
        self.sobel_y = nn.Parameter(
            torch.from_numpy(sobel.transpose(1, 0, 2)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)
        self.sobel_z = nn.Parameter(
            torch.from_numpy(sobel.transpose(1, 2, 0)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)

        self.diff_x = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_x.weight = self.sobel_x
        self.diff_y = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_y.weight = self.sobel_y
        self.diff_z = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_z.weight = self.sobel_z

    def forward(self, predication, label):
        grd_x = self.diff_x(predication)
        grd_y = self.diff_y(predication)
        grd_z = self.diff_z(predication)

        # length
        length = torch.sqrt(grd_x ** 2 + grd_y ** 2 + grd_z ** 2 + 1e-8)
        length = (length - length.min()) / (length.max() - length.min() + 1e-8)
        length = torch.sum(length)

        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = region_in + region_out

        return self.alpha * region + length


class ACLoss3DV2(nn.Module):
    """
    Active Contour Loss
    based on min-pooling & max-pooling
    """

    def __init__(self, miu=1.0, classes=3):
        super(ACLoss3DV2, self).__init__()

        self.miu = miu
        self.classes = classes

    def forward(self, predication, label):
        min_pool_x = nn.functional.max_pool3d(
            predication * -1, (3, 3, 3), 1, 1) * -1
        contour = torch.relu(nn.functional.max_pool3d(
            min_pool_x, (3, 3, 3), 1, 1) - min_pool_x)

        # length
        length = torch.sum(torch.abs(contour))

        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = self.miu * region_in + region_out

        return region + length


class ACELoss3D(nn.Module):
    """
    Active contour based elastic model loss
    based on total variation and mean curvature
    """

    def __init__(self, alpha=1e-3, beta=1.0, miu=1, classes=3):
        super(ACELoss3D, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.miu = miu

    def first_derivative(self, input):
        u = input
        m = u.shape[2]
        n = u.shape[3]
        k = u.shape[4]

        ci_0 = (u[:, :, 1, :, :] - u[:, :, 0, :, :]).unsqueeze(2)
        ci_1 = u[:, :, 2:, :, :] - u[:, :, 0:m - 2, :, :]
        ci_2 = (u[:, :, -1, :, :] - u[:, :, m - 2, :, :]).unsqueeze(2)
        ci = torch.cat([ci_0, ci_1, ci_2], 2) / 2

        cj_0 = (u[:, :, :, 1, :] - u[:, :, :, 0, :]).unsqueeze(3)
        cj_1 = u[:, :, :, 2:, :] - u[:, :, :, 0:n - 2, :]
        cj_2 = (u[:, :, :, -1, :] - u[:, :, :, n - 2, :]).unsqueeze(3)
        cj = torch.cat([cj_0, cj_1, cj_2], 3) / 2

        ck_0 = (u[:, :, :, :, 1] - u[:, :, :, :, 0]).unsqueeze(4)
        ck_1 = u[:, :, :, :, 2:] - u[:, :, :, :, 0:k - 2]
        ck_2 = (u[:, :, :, :, -1] - u[:, :, :, :, k - 2]).unsqueeze(4)
        ck = torch.cat([ck_0, ck_1, ck_2], 4) / 2

        return ci, cj, ck

    def second_derivative(self, input, ci, cj, ck):
        u = input
        m = u.shape[2]
        n = u.shape[3]
        k = u.shape[4]

        cii_0 = (u[:, :, 1, :, :] + u[:, :, 0, :, :] -
                 2 * u[:, :, 0, :, :]).unsqueeze(2)
        cii_1 = u[:, :, 2:, :, :] + \
            u[:, :, :-2, :, :] - 2 * u[:, :, 1:-1, :, :]
        cii_2 = (u[:, :, -1, :, :] + u[:, :, -2, :, :] -
                 2 * u[:, :, -1, :, :]).unsqueeze(2)
        cii = torch.cat([cii_0, cii_1, cii_2], 2)

        cjj_0 = (u[:, :, :, 1, :] + u[:, :, :, 0, :] -
                 2 * u[:, :, :, 0, :]).unsqueeze(3)
        cjj_1 = u[:, :, :, 2:, :] + \
            u[:, :, :, :-2, :] - 2 * u[:, :, :, 1:-1, :]
        cjj_2 = (u[:, :, :, -1, :] + u[:, :, :, -2, :] -
                 2 * u[:, :, :, -1, :]).unsqueeze(3)

        cjj = torch.cat([cjj_0, cjj_1, cjj_2], 3)

        ckk_0 = (u[:, :, :, :, 1] + u[:, :, :, :, 0] -
                 2 * u[:, :, :, :, 0]).unsqueeze(4)
        ckk_1 = u[:, :, :, :, 2:] + \
            u[:, :, :, :, :-2] - 2 * u[:, :, :, :, 1:-1]
        ckk_2 = (u[:, :, :, :, -1] + u[:, :, :, :, -2] -
                 2 * u[:, :, :, :, -1]).unsqueeze(4)

        ckk = torch.cat([ckk_0, ckk_1, ckk_2], 4)

        cij_0 = ci[:, :, :, 1:n, :]
        cij_1 = ci[:, :, :, -1, :].unsqueeze(3)

        cij_a = torch.cat([cij_0, cij_1], 3)
        cij_2 = ci[:, :, :, 0, :].unsqueeze(3)
        cij_3 = ci[:, :, :, 0:n - 1, :]
        cij_b = torch.cat([cij_2, cij_3], 3)
        cij = cij_a - cij_b

        # slice the last axis by its own length k (slicing by n breaks when k > n)
        cik_0 = ci[:, :, :, :, 1:k]
        cik_1 = ci[:, :, :, :, -1].unsqueeze(4)

        cik_a = torch.cat([cik_0, cik_1], 4)
        cik_2 = ci[:, :, :, :, 0].unsqueeze(4)
        cik_3 = ci[:, :, :, :, 0:k - 1]
        cik_b = torch.cat([cik_2, cik_3], 4)
        cik = cik_a - cik_b

        cjk_0 = cj[:, :, :, :, 1:k]
        cjk_1 = cj[:, :, :, :, -1].unsqueeze(4)

        cjk_a = torch.cat([cjk_0, cjk_1], 4)
        cjk_2 = cj[:, :, :, :, 0].unsqueeze(4)
        cjk_3 = cj[:, :, :, :, 0:k - 1]
        cjk_b = torch.cat([cjk_2, cjk_3], 4)
        cjk = cjk_a - cjk_b

        return cii, cjj, ckk, cij, cik, cjk

    def region(self, y_pred, y_true, u=1):
        label = y_true.float()
        c_in = torch.ones_like(y_pred)
        c_out = torch.zeros_like(y_pred)
        region_in = torch.abs(torch.sum(y_pred * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - y_pred) * ((label - c_out) ** 2)))
        region = u * region_in + region_out
        return region

    def elastica(self, input, a=1, b=1):
        ci, cj, ck = self.first_derivative(input)
        cii, cjj, ckk, cij, cik, cjk = self.second_derivative(
            input, ci, cj, ck)
        beta = 1e-8
        length = torch.sqrt(beta + ci ** 2 + cj ** 2 + ck ** 2)
        curvature = (1 + ci ** 2 + cj ** 2) * ckk + (1 + cj ** 2 + ck ** 2) * cii + (
            1 + ci ** 2 + ck ** 2) * cjj - 2 * cik * cjk * cij
        curvature = torch.abs(curvature) / \
            ((1 + ci ** 2 + cj ** 2 + ck ** 2) ** 0.5 + beta)
        # (a + b * curvature^2) weighted by the gradient magnitude, as in the 2D ACELoss
        elastica = torch.sum((a + b * (curvature ** 2)) * torch.abs(length))
        return elastica

    def forward(self, y_pred, y_true):
        loss = self.region(y_pred, y_true, u=self.miu) + \
            self.elastica(y_pred, a=self.alpha, b=self.beta)
        return loss


class FastACELoss3D(nn.Module):
    """
    Active contour based elastic model loss
    based on Sobel and Laplace filters
    """

    def __init__(self, miu=1, alpha=1e-3, beta=2.0, classes=4, types="laplace"):
        super(FastACELoss3D, self).__init__()
        self.miu = miu
        self.alpha = alpha
        self.beta = beta
        self.classes = classes
        self.types = types
        sobel = np.array([[[1., 2., 1.],
                           [2., 4., 2.],
                           [1., 2., 1.]],

                          [[0., 0., 0.],
                           [0., 0., 0.],
                           [0., 0., 0.]],

                          [[-1., -2., -1.],
                           [-2., -4., -2.],
                           [-1., -2., -1.]]])
        laplace_kernel = np.ones((3, 3, 3))
        laplace_kernel[1, 1, 1] = -26

        self.sobel_x = nn.Parameter(
            torch.from_numpy(sobel.transpose(0, 1, 2)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)
        self.sobel_y = nn.Parameter(
            torch.from_numpy(sobel.transpose(1, 0, 2)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)
        self.sobel_z = nn.Parameter(
            torch.from_numpy(sobel.transpose(1, 2, 0)).float().unsqueeze(0).unsqueeze(0).expand(
                self.classes, 1, 3, 3, 3), requires_grad=False)
        self.laplace = nn.Parameter(
            torch.from_numpy(laplace_kernel).float().unsqueeze(
                0).unsqueeze(0).expand(self.classes, 1, 3, 3, 3),
            requires_grad=False)

        self.diff_x = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_x.weight = self.sobel_x
        self.diff_y = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_y.weight = self.sobel_y
        self.diff_z = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1, padding=1,
                                bias=False)
        self.diff_z.weight = self.sobel_z

        self.laplace_operator = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1,
                                          padding=1, bias=False)
        self.laplace_operator.weight = self.laplace

    def forward(self, predication, label):
        grd_x = self.diff_x(predication)
        grd_y = self.diff_y(predication)
        grd_z = self.diff_z(predication)
        diff = self.laplace_operator(predication)

        # length
        length = torch.sqrt(grd_x ** 2 + grd_y ** 2 + grd_z ** 2 + 1e-8)
        length = (length - length.min()) / (length.max() - length.min() + 1e-8)

        # curvature (compare explicitly: any non-empty string is truthy)
        if self.types == "laplace":
            curvature = torch.abs(diff)
            curvature = (curvature - curvature.min()) / \
                (curvature.max() - curvature.min() + 1e-8)
        else:
            # maybe more powerful
            curvature = torch.abs(
                diff) / ((grd_x ** 2 + grd_y ** 2 + grd_z ** 2 + 1) ** 0.5 + 1e-8)
            curvature = (curvature - curvature.min()) / \
                (curvature.max() - curvature.min() + 1e-8)
        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = self.miu * region_in + region_out

        # elastic
        elastic = torch.sum((self.alpha + self.beta * curvature ** 2) * length)
        return region + elastic


class FastACELoss3DV2(nn.Module):
    """
    Active contour based elastic model loss
    based on min-pooling & max-pooling and Laplace filter
    """

    def __init__(self, miu=1, alpha=1e-3, beta=2.0, classes=4, types="other"):
        super(FastACELoss3DV2, self).__init__()
        self.miu = miu
        self.alpha = alpha
        self.beta = beta
        self.classes = classes
        self.types = types
        laplace_kernel = np.ones((3, 3, 3))
        laplace_kernel[1, 1, 1] = -26

        self.laplace = nn.Parameter(
            torch.from_numpy(laplace_kernel).float().unsqueeze(
                0).unsqueeze(0).expand(self.classes, 1, 3, 3, 3),
            requires_grad=False)

        self.laplace_operator = nn.Conv3d(self.classes, self.classes, groups=self.classes, kernel_size=3, stride=1,
                                          padding=1, bias=False)
        self.laplace_operator.weight = self.laplace

    def forward(self, predication, label):
        min_pool_x = nn.functional.max_pool3d(predication * -1, 3, 1, 1) * -1
        contour = torch.relu(nn.functional.max_pool3d(
            min_pool_x, 3, 1, 1) - min_pool_x)

        diff = self.laplace_operator(predication)

        # length
        length = torch.abs(contour)

        # curvature (compare explicitly: any non-empty string is truthy)
        if self.types == "laplace":
            curvature = torch.abs(diff)
            curvature = (curvature - curvature.min()) / \
                (curvature.max() - curvature.min() + 1e-8)
        else:
            # maybe more powerful
            curvature = torch.abs(diff) / ((length ** 2 + 1) ** 0.5 + 1e-8)
            curvature = (curvature - curvature.min()) / \
                (curvature.max() - curvature.min() + 1e-8)
        # region
        label = label.float()
        c_in = torch.ones_like(predication)
        c_out = torch.zeros_like(predication)
        region_in = torch.abs(torch.sum(predication * ((label - c_in) ** 2)))
        region_out = torch.abs(
            torch.sum((1 - predication) * ((label - c_out) ** 2)))
        region = self.miu * region_in + region_out

        # elastic
        elastic = torch.sum((self.alpha + self.beta * curvature ** 2) * length)
        return region + elastic


if __name__ == "__main__":
    # test demo: random inputs and labels, so the printed values are only a smoke test
    device = "cuda" if torch.cuda.is_available() else "cpu"

    x2 = torch.rand((2, 3, 97, 80))
    y2 = torch.rand((2, 3, 97, 80))
    time1 = time.time()
    print("ACLoss:", ACLoss()(x2, y2).item())
    print(time.time() - time1)
    time2 = time.time()
    print("ACLossV2:", ACLossV2()(x2, y2).item())
    print(time.time() - time2)
    time3 = time.time()
    print("ACELoss:", ACELoss(x2, y2).item())
    print(time.time() - time3)

    x3 = torch.rand((2, 4, 112, 97, 80))
    y3 = torch.rand((2, 4, 112, 97, 80))
    time7 = time.time()
    print("ACLoss3D:", ACLoss3D()(x3, y3).item())
    print(time.time() - time7)
    time8 = time.time()
    print("ACLoss3DV2:", ACLoss3DV2()(x3, y3).item())
    print(time.time() - time8)

    # the heavier 3D losses run on the GPU when one is available
    x3, y3 = x3.to(device), y3.to(device)
    time9 = time.time()
    print("ACELoss3D:", ACELoss3D().to(device)(x3, y3).item())
    print(time.time() - time9)
    time10 = time.time()
    print("FastACELoss3D:", FastACELoss3D().to(device)(x3, y3).item())
    print(time.time() - time10)
    time11 = time.time()
    print("FastACELoss3DV2:", FastACELoss3DV2().to(device)(x3, y3).item())
    print(time.time() - time11)

--------------------------------------------------------------------------------
/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/ACELoss/a9677af35e9c7cf50ea2c6a3a68be3bc8c25fb1a/figure1.png
--------------------------------------------------------------------------------
/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/ACELoss/a9677af35e9c7cf50ea2c6a3a68be3bc8c25fb1a/figure2.png
--------------------------------------------------------------------------------
/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/ACELoss/a9677af35e9c7cf50ea2c6a3a68be3bc8c25fb1a/table1.png
--------------------------------------------------------------------------------
/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/ACELoss/a9677af35e9c7cf50ea2c6a3a68be3bc8c25fb1a/table2.png
--------------------------------------------------------------------------------
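The README suggests replacing `torch.sum()` with `torch.mean()` when these losses are used only as a constraint next to Dice or cross-entropy. Below is a minimal sketch of that idea. It assumes 2D softmax outputs and one-hot labels of shape (N, C, H, W). `mean_ac_loss` mirrors the region term and the max/min-pooling length term of `ACLossV2`, but averages instead of summing. `mean_ac_loss`, `soft_dice_loss`, `combined_loss` and the weight `lam` are hypothetical names for illustration and are not part of aceloss.py.

```python
import torch
import torch.nn.functional as F


def mean_ac_loss(pred, label, miu=1.0):
    """Hypothetical mean-reduced variant of the ACLossV2 region + length terms.

    pred, label: tensors of shape (N, C, H, W) with values in [0, 1].
    """
    label = label.float()
    # Soft contour: min-pooling (negated max-pooling) followed by max-pooling,
    # i.e. a morphological gradient of the prediction, as in ACLossV2.
    min_pool = -F.max_pool2d(-pred, kernel_size=3, stride=1, padding=1)
    contour = torch.relu(
        F.max_pool2d(min_pool, kernel_size=3, stride=1, padding=1) - min_pool)
    length = torch.mean(contour)

    # Region term with c_in = 1 and c_out = 0, averaged instead of summed.
    region_in = torch.mean(pred * (label - 1.0) ** 2)
    region_out = torch.mean((1.0 - pred) * label ** 2)
    return miu * region_in + region_out + length


def soft_dice_loss(pred, label, eps=1e-5):
    """Plain soft Dice loss averaged over batch and channels (an assumed baseline)."""
    label = label.float()
    dims = (2, 3)
    inter = torch.sum(pred * label, dims)
    denom = torch.sum(pred, dims) + torch.sum(label, dims)
    dice = (2.0 * inter + eps) / (denom + eps)
    return 1.0 - dice.mean()


def combined_loss(logits, label, lam=0.5):
    """Dice loss plus a mean-reduced active contour constraint; lam is an assumed weight."""
    pred = torch.softmax(logits, dim=1)
    return soft_dice_loss(pred, label) + lam * mean_ac_loss(pred, label)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 64, 64, requires_grad=True)
    target = F.one_hot(torch.randint(0, 3, (2, 64, 64)), num_classes=3)
    target = target.permute(0, 3, 1, 2).float()
    loss = combined_loss(logits, target)
    loss.backward()
    print(loss.item())
```

Because every term is an average, its scale no longer grows with the image size or batch size, which keeps it comparable to the Dice term; the weight `lam` still has to be tuned per task.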