├── .gitignore
├── LICENSE
├── README.md
└── blob_loss.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 neuronflow

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# blob loss: instance imbalance aware loss functions for semantic segmentation

## example implementation - computation time
Note that this example implementation is not optimized for computation time.
We would love to see your more efficient implementation!

## manuscript
[https://arxiv.org/abs/2205.08209](https://arxiv.org/abs/2205.08209)

## citation
Please cite blob loss when using it:

```
@misc{https://doi.org/10.48550/arxiv.2205.08209,
  doi = {10.48550/ARXIV.2205.08209},
  url = {https://arxiv.org/abs/2205.08209},
  author = {Kofler, Florian and Shit, Suprosanna and Ezhov, Ivan and Fidon, Lucas and Horvath, Izabela and Al-Maskari, Rami and Li, Hongwei and Bhatia, Harsharan and Loehr, Timo and Piraud, Marie and Erturk, Ali and Kirschke, Jan and Peeken, Jan and Vercauteren, Tom and Zimmer, Claus and Wiestler, Benedikt and Menze, Bjoern},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), Machine Learning (cs.LG), Image and Video Processing (eess.IV), FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering},
  title = {blob loss: instance imbalance aware loss functions for semantic segmentation},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
```

## evaluation
HINT: Our new project [panoptica](https://github.com/BrainLesion/panoptica) might be helpful in evaluating your multi-instance segmentation problem.
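
## usage
A minimal, self-contained sketch (not part of the repository): the BCE criterion, the toy labels, and the loss weights are illustrative, and in practice the logits come from your network. The full criterion dictionary format is documented in the docstring of `compute_loss` in `blob_loss.py`.

```python
import numpy as np
import torch
from scipy import ndimage
from torch.nn import BCEWithLogitsLoss

from blob_loss import compute_loss

# toy binary ground truth with two separate foreground "blobs"
binary_np = np.zeros((1, 1, 8, 8), dtype=np.uint8)
binary_np[..., 1:3, 1:3] = 1
binary_np[..., 5:7, 5:7] = 1

# one integer id per connected component ("blob")
multi_np, num_blobs = ndimage.label(binary_np)

binary_label = torch.from_numpy(binary_np).float()
multi_label = torch.from_numpy(multi_np.astype(np.float32))

# stand-in for the raw network logits
logits = torch.randn(1, 1, 8, 8, requires_grad=True)

criterion_dict = {
    "bce": {
        "name": "bce",
        "loss": BCEWithLogitsLoss(reduction="mean"),
        "weight": 1.0,
        "sigmoid": False,
    },
}

loss, main_loss, blob_loss = compute_loss(
    blob_loss_dict={"main_weight": 1, "blob_weight": 1},
    criterion_dict=criterion_dict,
    blob_criterion_dict=criterion_dict,
    raw_network_outputs=logits,
    binary_label=binary_label,
    multi_label=multi_label,
)
loss.backward()
```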

--------------------------------------------------------------------------------
/blob_loss.py:
--------------------------------------------------------------------------------
import torch


def vprint(*args):
    # set verbose to True to enable debug output
    verbose = False
    if verbose:
        print(*args)


def compute_compound_loss(
    criterion_dict: dict,
    raw_network_outputs: torch.Tensor,
    label: torch.Tensor,
    blob_loss_mode=False,
    masked=True,
):
    """
    Computes a compound loss by looping through a criterion dict and summing
    the weighted individual losses.
    """
    losses = []
    for entry in criterion_dict.values():
        name = entry["name"]
        vprint("loss name:", name)
        criterion = entry["loss"]
        weight = entry["weight"]
        sigmoid = entry["sigmoid"]

        # apply a sigmoid to the raw logits for criteria that expect probabilities
        if sigmoid:
            network_outputs = torch.sigmoid(raw_network_outputs)
        else:
            network_outputs = raw_network_outputs

        if not blob_loss_mode:
            vprint("computing main loss!")
            individual_loss = criterion(network_outputs, label)
        else:
            vprint("computing blob loss!")
            if masked:  # this is the default blob loss
                individual_loss = compute_blob_loss_multi(
                    criterion=criterion,
                    network_outputs=network_outputs,
                    multi_label=label,
                )
            else:  # without masking, for the ablation study
                individual_loss = compute_no_masking_multi(
                    criterion=criterion,
                    network_outputs=network_outputs,
                    multi_label=label,
                )

        weighted_loss = individual_loss * weight
        losses.append(weighted_loss)

    vprint("losses:", losses)
    loss = sum(losses)
    return loss

def compute_blob_loss_multi(
    criterion,
    network_outputs: torch.Tensor,
    multi_label: torch.Tensor,
):
    """
    1. loop through the elements in our batch
    2. loop through the blobs per element, compute the loss, and divide by the
       number of blobs to obtain the element loss
       2.1 we need to account for sigmoid and non-sigmoid in conjunction with BCE
    3. divide by the batch length to obtain a correct batch loss for backprop
    """
    batch_length = multi_label.shape[0]
    vprint("batch_length:", batch_length)

    element_blob_loss = []
    # loop over the batch elements
    for element in range(batch_length):
        element_label = multi_label[element : element + 1, ...]
        vprint("element label shape:", element_label.shape)

        element_output = network_outputs[element : element + 1, ...]

        # loop through the blob labels
        unique_labels = torch.unique(element_label)
        blob_count = len(unique_labels) - 1
        vprint("found this amount of blobs in batch element:", blob_count)

        label_loss = []
        for ula in unique_labels:
            if ula == 0:
                vprint("ula is 0, we do nothing")
            else:
                vprint("ula greater than 0:", ula)
                # build the mask: flip the foreground, so the background is
                # True and all blobs are False ...
                label_mask = element_label > 0
                label_mask = ~label_mask
                # ... then set the mask to True where our blob of interest is located
                label_mask[element_label == ula] = 1
                vprint("label_mask", label_mask)

                # one-hot label for the current blob
                the_label = element_label == ula
                the_label_int = the_label.int()
                vprint("the_label:", torch.count_nonzero(the_label))

                # mask all competing blobs out of the prediction
                masked_output = element_output * label_mask

                try:
                    # we try with int labels first, but some losses require floats
                    blob_loss = criterion(masked_output, the_label_int)
                except Exception:
                    # if int does not work, we try float
                    blob_loss = criterion(masked_output, the_label.float())
                vprint("blob_loss:", blob_loss)

                label_loss.append(blob_loss)

        # compute the mean over the blobs of this element
        vprint("label_loss:", label_loss)
        vprint("blobs in crop:", len(label_loss))
        if len(label_loss) > 0:
            mean_label_loss = sum(label_loss) / len(label_loss)
            vprint("mean_label_loss", mean_label_loss)
            element_blob_loss.append(mean_label_loss)

    # compute the mean over the elements of the batch
    vprint("element_blob_loss:", element_blob_loss)
    mean_element_blob_loss = 0
    vprint("elements in batch:", len(element_blob_loss))
    if len(element_blob_loss) > 0:
        mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)

    vprint("mean_element_blob_loss", mean_element_blob_loss)

    return mean_element_blob_loss
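

# Worked example for the masking above (illustration, not part of the original
# implementation): for an element_label of
#
#     [[1, 0],
#      [0, 2]]
#
# and ula == 1, label_mask becomes
#
#     [[True, True],
#      [True, False]]
#
# so the background stays visible, the competing blob (id 2) is multiplied out
# of the prediction, and the criterion compares the masked output against
# the_label == [[1, 0], [0, 0]], i.e. blob 1 versus background only.
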
def compute_no_masking_multi(
    criterion,
    network_outputs: torch.Tensor,
    multi_label: torch.Tensor,
):
    """
    Ablation variant of compute_blob_loss_multi: computes the per-blob losses
    without masking the competing blobs out of the prediction.
    """
    batch_length = multi_label.shape[0]
    vprint("batch_length:", batch_length)

    element_blob_loss = []
    # loop over the batch elements
    for element in range(batch_length):
        element_label = multi_label[element : element + 1, ...]
        vprint("element label shape:", element_label.shape)

        element_output = network_outputs[element : element + 1, ...]

        # loop through the blob labels
        unique_labels = torch.unique(element_label)
        blob_count = len(unique_labels) - 1
        vprint("found this amount of blobs in batch element:", blob_count)

        label_loss = []
        for ula in unique_labels:
            if ula == 0:
                vprint("ula is 0, we do nothing")
            else:
                vprint("ula greater than 0:", ula)

                # one-hot label for the current blob
                the_label = element_label == ula
                the_label_int = the_label.int()
                vprint("the_label:", torch.count_nonzero(the_label))

                # we compute the loss with no mask
                try:
                    # we try with int labels first, but some losses require floats
                    blob_loss = criterion(element_output, the_label_int)
                except Exception:
                    # if int does not work, we try float
                    blob_loss = criterion(element_output, the_label.float())
                vprint("blob_loss:", blob_loss)

                label_loss.append(blob_loss)

        # compute the mean over the blobs of this element
        vprint("label_loss:", label_loss)
        vprint("blobs in crop:", len(label_loss))
        if len(label_loss) > 0:
            mean_label_loss = sum(label_loss) / len(label_loss)
            vprint("mean_label_loss", mean_label_loss)
            element_blob_loss.append(mean_label_loss)

    # compute the mean over the elements of the batch
    vprint("element_blob_loss:", element_blob_loss)
    mean_element_blob_loss = 0
    vprint("elements in batch:", len(element_blob_loss))
    if len(element_blob_loss) > 0:
        mean_element_blob_loss = sum(element_blob_loss) / len(element_blob_loss)

    vprint("mean_element_blob_loss", mean_element_blob_loss)

    return mean_element_blob_loss

def compute_loss(
    blob_loss_dict: dict,
    criterion_dict: dict,
    blob_criterion_dict: dict,
    raw_network_outputs: torch.Tensor,
    binary_label: torch.Tensor,
    multi_label: torch.Tensor,
):
    """
    Computes the total loss: a global main loss plus the blob loss term, which
    is computed separately for each connected component. binary_label is the
    binarized label for the global part; multi_label carries a separate
    integer label for each connected component.

    Example inputs should look like:

    blob_loss_dict = {
        "main_weight": 1,
        "blob_weight": 0,
    }

    criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }

    blob_criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        },
        "dice": {
            "name": "dice",
            "loss": DiceLoss(
                include_background=True,
                to_onehot_y=False,
                sigmoid=True,
                softmax=False,
                squared_pred=False,
            ),
            "weight": 1.0,
            "sigmoid": False,
        },
    }
    """
    main_weight = blob_loss_dict["main_weight"]
    blob_weight = blob_loss_dict["blob_weight"]

    # defaults, so that unused terms are well defined
    main_loss = 0
    blob_loss = 0

    # main loss
    if main_weight > 0:
        vprint("main_weight greater than zero:", main_weight)
        main_loss = compute_compound_loss(
            criterion_dict=criterion_dict,
            raw_network_outputs=raw_network_outputs,
            label=binary_label,
            blob_loss_mode=False,
        )

    # blob loss
    if blob_weight > 0:
        vprint("blob_weight greater than zero:", blob_weight)
        blob_loss = compute_compound_loss(
            criterion_dict=blob_criterion_dict,
            raw_network_outputs=raw_network_outputs,
            label=multi_label,
            blob_loss_mode=True,
        )

    # final loss
    if blob_weight == 0 and main_weight > 0:
        vprint(
            "main_weight:",
            main_weight,
            "// blob_weight:",
            blob_weight,
            "// computing main loss only",
        )
        loss = main_loss

    elif main_weight == 0 and blob_weight > 0:
        vprint(
            "main_weight:",
            main_weight,
            "// blob_weight:",
            blob_weight,
            "// computing blob loss only",
        )
        loss = blob_loss

    elif main_weight > 0 and blob_weight > 0:
        vprint(
            "main_weight:",
            main_weight,
            "// blob_weight:",
            blob_weight,
            "// computing weighted sum of main and blob loss",
        )
        loss = main_loss * main_weight + blob_loss * blob_weight

    else:
        vprint("no positive loss weights, the loss is zero")
        loss = main_loss + blob_loss

    vprint("blob loss:", blob_loss)
    vprint("main loss:", main_loss)
    vprint("effective loss:", loss)

    return loss, main_loss, blob_loss


def get_loss_value(loss):
    # returns the loss as a plain Python number, e.g. for logging
    if loss == 0:
        return 0

    return loss.item()
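

# Minimal smoke test (illustration, not part of the original file): builds a
# toy batch with two blobs per element and checks that the combined loss runs
# end to end. BCEWithLogitsLoss is only an illustrative criterion choice.
if __name__ == "__main__":
    from torch.nn import BCEWithLogitsLoss

    criterion_dict = {
        "bce": {
            "name": "bce",
            "loss": BCEWithLogitsLoss(reduction="mean"),
            "weight": 1.0,
            "sigmoid": False,
        }
    }

    raw = torch.randn(2, 1, 8, 8, requires_grad=True)
    multi_label = torch.zeros(2, 1, 8, 8)
    multi_label[:, :, 1:3, 1:3] = 1  # blob with id 1
    multi_label[:, :, 5:7, 5:7] = 2  # blob with id 2
    binary_label = (multi_label > 0).float()

    loss, main_loss, blob_loss = compute_loss(
        blob_loss_dict={"main_weight": 1, "blob_weight": 1},
        criterion_dict=criterion_dict,
        blob_criterion_dict=criterion_dict,
        raw_network_outputs=raw,
        binary_label=binary_label,
        multi_label=multi_label,
    )
    loss.backward()
    print("loss:", get_loss_value(loss))

--------------------------------------------------------------------------------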