├── .gitignore
├── README.md
└── losses
    ├── binary_losses.py
    └── multiclass_losses.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Mac
2 | .DS_Store
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image segmentation loss functions implemented in Keras
2 | Binary and multiclass loss functions for image segmentation with one-hot encoded masks of
3 | shape=(batch_size, height, width, channels). Implemented in Keras.
4 | 
5 | ## Loss functions
6 | All loss functions are implemented as closures: an outer function takes the hyperparameters and returns the inner loss function that Keras calls with `y_true` and `y_pred`:
7 | 
8 | ```python
9 | def example_loss() -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
10 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
11 |         pass
12 | 
13 |     return loss
14 | ```
15 | 
16 | ### Binary loss functions
17 | | Loss function                    | Implementation                                                                                                    |
18 | |----------------------------------|-------------------------------------------------------------------------------------------------------------------|
19 | | Dice's coefficient loss          | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L50   |
20 | | Weighted Dice cross entropy loss | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L74   |
21 | | Tversky loss                     | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L114  |
22 | | Weighted cross entropy           | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L137  |
23 | | Balanced cross entropy           | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L177  |
24 | | Focal loss                       | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/binary_losses.py#L223  |
25 | 
26 | ### Multiclass loss functions
27 | | Loss function                             | Implementation                                                                                                        |
28 | |-------------------------------------------|-------------------------------------------------------------------------------------------------------------------------|
29 | | Weighted Tanimoto loss                    | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L8     |
30 | | Weighted Dice's coefficient loss          | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L42    |
31 | | Weighted squared Dice's coefficient loss  | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L74    |
32 | | Weighted cross entropy                    | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L107   |
33 | | Focal loss                                | https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L150   |
34 | 
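35 | ## Example usage
36 | A minimal sketch of how one of the losses can be plugged into training. It assumes `model` is an
37 | already-built Keras binary segmentation model with a single-channel sigmoid output; only the
38 | `loss=` argument comes from this repository:
39 | 
40 | ```python
41 | from losses.binary_losses import binary_tversky_loss
42 | 
43 | model.compile(optimizer="adam", loss=binary_tversky_loss(beta=0.7))
44 | ```
45 | 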
--------------------------------------------------------------------------------
/losses/binary_losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.keras.backend as K
3 | from typing import Callable
4 | 
5 | 
6 | def binary_tversky_coef(y_true: tf.Tensor, y_pred: tf.Tensor, beta: float, smooth: float = 1.) -> tf.Tensor:
7 |     """
8 |     The Tversky coefficient is a generalization of Dice's coefficient. It adds an extra weight (β) to false
9 |     positives and false negatives:
10 | 
11 |         TC(p, p̂) = p*p̂/[p*p̂ + β*(1-p)*p̂ + (1-β)*p*(1-p̂)]
12 | 
13 |     When β=1/2, the Tversky coefficient is equal to Dice's coefficient:
14 | 
15 |         TC(p, p̂) = p*p̂/[p*p̂ + (1/2)*(1-p)*p̂ + (1-(1/2))*p*(1-p̂)]
16 |                  = p*p̂/[p*p̂ + (1/2)*p̂ - (1/2)*p*p̂ + (1/2)*p*(1-p̂)]
17 |                  = p*p̂/[p*p̂ + (1/2)*p̂ - (1/2)*p*p̂ + (1/2)*p - (1/2)*p*p̂]
18 |                  = p*p̂/[p*p̂ - p*p̂ + (1/2)*p̂ + (1/2)*p]
19 |                  = p*p̂/[(1/2)*p̂ + (1/2)*p]
20 |                  = p*p̂/[(1/2)*(p̂+p)]
21 |                  = 2*p*p̂/(p̂+p)
22 | 
23 |     :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
24 |     :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
25 |     :param beta: Weight coefficient (float)
26 |     :param smooth: Smoothing factor (float, default=1.)
27 |     :return: Tversky coefficient per image in batch (tf.Tensor, shape=(batch_size,))
28 |     """
29 |     axis_to_reduce = range(1, K.ndim(y_pred))  # All axes but the first (batch)
30 |     numerator = K.sum(y_true * y_pred, axis=axis_to_reduce)  # p*p̂
31 |     denominator = y_true * y_pred + beta * (1 - y_true) * y_pred + (1 - beta) * y_true * (1 - y_pred)  # p*p̂ + β*(1-p)*p̂ + (1-β)*p*(1-p̂)
32 |     denominator = K.sum(denominator, axis=axis_to_reduce)
33 | 
34 |     return (numerator + smooth) / (denominator + smooth)  # (p*p̂ + smooth)/[p*p̂ + β*(1-p)*p̂ + (1-β)*p*(1-p̂) + smooth]
35 | 
36 | 
37 | def convert_to_logits(y_pred: tf.Tensor) -> tf.Tensor:
38 |     """
39 |     Convert sigmoid outputs back to logits.
40 | 
41 |     :param y_pred: Predictions after sigmoid (tf.Tensor, shape=(batch_size, height, width, 1))
42 |     :return: Logits (tf.Tensor, shape=(batch_size, height, width, 1))
43 |     """
44 |     # Clip to avoid log(0) and division by zero in the log-odds below
45 |     y_pred = K.clip(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())
46 | 
47 |     return K.log(y_pred / (1 - y_pred))
48 | 
49 | 
50 | def binary_dice_coef_loss(smooth: float = 1.) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
51 |     """
52 |     Dice coefficient loss:
53 | 
54 |         DL(p, p̂) = 1 - (2*p*p̂+smooth)/(p+p̂+smooth)
55 | 
56 |     Used as loss function for binary image segmentation with one-hot encoded masks.
57 | 
58 |     :param smooth: Smoothing factor (float, default=1.)
59 |     :return: Dice coefficient loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
60 |     """
61 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
62 |         """
63 |         Compute the Dice loss (Tversky loss with β=0.5).
64 | 
65 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
66 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
67 |         :return: Dice coefficient loss for each observation in batch (tf.Tensor, shape=(batch_size,))
68 |         """
69 |         return 1 - binary_tversky_coef(y_true=y_true, y_pred=y_pred, beta=0.5, smooth=smooth)
70 | 
71 |     return loss
72 | 
73 | 
74 | def binary_weighted_dice_crossentropy_loss(smooth: float = 1.,
75 |                                            beta: float = 0.5) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
76 |     """
77 |     Weighted Dice cross entropy combination loss is a weighted combination of Dice's coefficient loss and
78 |     binary cross entropy:
79 | 
80 |         DL(p, p̂) = 1 - (2*p*p̂+smooth)/(p+p̂+smooth)
81 |         CE(p, p̂) = -[p*log(p̂ + 1e-7) + (1-p)*log(1-p̂ + 1e-7)]
82 |         WDCE(p, p̂) = β*DL + (1-β)*CE
83 |                    = β*[1 - (2*p*p̂+smooth)/(p+p̂+smooth)] - (1-β)*[p*log(p̂ + 1e-7) + (1-p)*log(1-p̂ + 1e-7)]
84 | 
85 |     Used as loss function for binary image segmentation with one-hot encoded masks.
86 | 
87 |     :param smooth: Smoothing factor (float, default=1.)
88 |     :param beta: Loss weight coefficient (float, default=0.5)
89 |     :return: Dice cross entropy combination loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
90 |     """
91 |     assert 0. <= beta <= 1., "Loss weight has to be between 0.0 and 1.0"
92 | 
93 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
94 |         """
95 |         Compute the Dice cross entropy combination loss.
96 | 
97 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
98 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
99 |         :return: Dice cross entropy combination loss (tf.Tensor, shape=(batch_size,))
100 |         """
101 |         cross_entropy = K.binary_crossentropy(target=y_true, output=y_pred)
102 | 
103 |         # Average over each data point/image in batch
104 |         axis_to_reduce = range(1, K.ndim(cross_entropy))
105 |         cross_entropy = K.mean(x=cross_entropy, axis=axis_to_reduce)
106 | 
107 |         dice_coefficient = binary_tversky_coef(y_true=y_true, y_pred=y_pred, beta=0.5, smooth=smooth)
108 | 
109 |         return beta * (1. - dice_coefficient) + (1. - beta) * cross_entropy
110 | 
111 |     return loss
112 | 
113 | 
114 | def binary_tversky_loss(beta: float) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
115 |     """
116 |     Tversky loss:
117 | 
118 |         TL(p, p̂) = 1 - p*p̂/[p*p̂ + β*(1-p)*p̂ + (1-β)*p*(1-p̂)]
119 | 
120 |     Used as loss function for binary image segmentation with one-hot encoded masks.
121 | 
122 |     :param beta: Weight coefficient (float)
123 |     :return: Tversky loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
124 |     """
125 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
126 |         """
127 |         Compute the Tversky loss.
128 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
129 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
130 |         :return: Tversky loss (tf.Tensor, shape=(batch_size,))
131 |         """
132 |         return 1 - binary_tversky_coef(y_true=y_true, y_pred=y_pred, beta=beta)
133 | 
134 |     return loss
135 | 
136 | 
137 | def binary_weighted_cross_entropy(beta: float, is_logits: bool = False) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
138 |     """
139 |     Weighted cross entropy. All positive examples get weighted by the coefficient β:
140 | 
141 |         WCE(p, p̂) = −[β*p*log(p̂) + (1−p)*log(1−p̂)]
142 | 
143 |     To decrease the number of false negatives, set β>1. To decrease the number of false positives, set β<1.
144 | 
145 |     If the last layer of the network is a sigmoid, y_pred needs to be converted back into logits before computing
146 |     the weighted cross entropy. To do this, we use the same method as the Keras binary_crossentropy implementation:
147 |     https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/backend.py#L3525
148 | 
149 |     Used as loss function for binary image segmentation with one-hot encoded masks.
150 | 
151 |     :param beta: Weight coefficient (float)
152 |     :param is_logits: If y_pred are logits (bool, default=False)
153 |     :return: Weighted cross entropy loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
154 |     """
155 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
156 |         """
157 |         Computes the weighted cross entropy.
158 | 
159 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
160 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
161 |         :return: Weighted cross entropy (tf.Tensor, shape=(batch_size,))
162 |         """
163 |         if not is_logits:
164 |             y_pred = convert_to_logits(y_pred)
165 | 
166 |         wce_loss = tf.nn.weighted_cross_entropy_with_logits(labels=y_true, logits=y_pred, pos_weight=beta)
167 | 
168 |         # Average over each data point/image in batch
169 |         axis_to_reduce = range(1, K.ndim(wce_loss))
170 |         wce_loss = K.mean(wce_loss, axis=axis_to_reduce)
171 | 
172 |         return wce_loss
173 | 
174 |     return loss
175 | 
176 | 
177 | def binary_balanced_cross_entropy(beta: float, is_logits: bool = False) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
178 |     """
179 |     Balanced cross entropy. Similar to weighted cross entropy (see binary_weighted_cross_entropy),
180 |     but here both positive and negative examples get weighted:
181 | 
182 |         BCE(p, p̂) = −[β*p*log(p̂) + (1-β)*(1−p)*log(1−p̂)]
183 | 
184 |     If the last layer of the network is a sigmoid, y_pred needs to be converted back into logits before computing
185 |     the balanced cross entropy. To do this, we use the same method as the Keras binary_crossentropy implementation:
186 |     https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/backend.py#L3525
187 | 
188 |     Used as loss function for binary image segmentation with one-hot encoded masks.
189 | 
190 |     :param beta: Weight coefficient (float)
191 |     :param is_logits: If y_pred are logits (bool, default=False)
192 |     :return: Balanced cross entropy loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
193 |     """
194 |     if beta == 1.:  # To avoid division by zero in pos_weight = β/(1-β)
195 |         beta -= tf.keras.backend.epsilon()
196 | 
197 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
198 |         """
199 |         Computes the balanced cross entropy in the following way:
200 | 
201 |             BCE(p, p̂) = −[(β/(1-β))*p*log(p̂) + (1−p)*log(1−p̂)]*(1-β) = −[β*p*log(p̂) + (1-β)*(1−p)*log(1−p̂)]
202 | 
203 |         :param y_true: Ground truth (tf.Tensor, shape=(batch_size, height, width, 1))
204 |         :param y_pred: Predictions (tf.Tensor, shape=(batch_size, height, width, 1))
205 |         :return: Balanced cross entropy (tf.Tensor, shape=(batch_size,))
206 |         """
207 |         if not is_logits:
208 |             y_pred = convert_to_logits(y_pred)
209 | 
210 |         pos_weight = beta / (1 - beta)
211 |         bce_loss = tf.nn.weighted_cross_entropy_with_logits(labels=y_true, logits=y_pred, pos_weight=pos_weight)
212 |         bce_loss = bce_loss * (1 - beta)
213 | 
214 |         # Average over each data point/image in batch
215 |         axis_to_reduce = range(1, K.ndim(bce_loss))
216 |         bce_loss = K.mean(bce_loss, axis=axis_to_reduce)
217 | 
218 |         return bce_loss
219 | 
220 |     return loss
221 | 
222 | 
223 | def binary_focal_loss(beta: float, gamma: float = 2.) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
224 |     """
225 |     Focal loss is derived from balanced cross entropy; it adds an extra focus on hard examples in the
226 |     dataset:
227 | 
228 |         FL(p, p̂) = −[β*(1-p̂)ᵞ*p*log(p̂) + (1-β)*p̂ᵞ*(1−p)*log(1−p̂)]
229 | 
230 |     When γ = 0, we obtain balanced cross entropy.
231 | 
232 |     Paper: https://arxiv.org/pdf/1708.02002.pdf
233 | 
234 |     Used as loss function for binary image segmentation with one-hot encoded masks.
235 | 
236 |     :param beta: Weight coefficient (float)
237 |     :param gamma: Focusing parameter, γ ≥ 0 (float, default=2.)
238 |     :return: Focal loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
239 |     """
240 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
241 |         """
242 |         Computes the focal loss.
243 | 
244 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, 1))
245 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, 1))
246 |         :return: Focal loss (tf.Tensor, shape=(batch_size,))
247 |         """
248 |         y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())  # Clip to avoid log(0) in the terms below
249 |         f_loss = beta * (1 - y_pred) ** gamma * y_true * K.log(y_pred)  # β*(1-p̂)ᵞ*p*log(p̂)
250 |         f_loss += (1 - beta) * y_pred ** gamma * (1 - y_true) * K.log(1 - y_pred)  # (1-β)*p̂ᵞ*(1−p)*log(1−p̂)
251 |         f_loss = -f_loss  # −[β*(1-p̂)ᵞ*p*log(p̂) + (1-β)*p̂ᵞ*(1−p)*log(1−p̂)]
252 | 
253 |         # Average over each data point/image in batch
254 |         axis_to_reduce = range(1, K.ndim(f_loss))
255 |         f_loss = K.mean(f_loss, axis=axis_to_reduce)
256 | 
257 |         return f_loss
258 | 
259 |     return loss
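260 | 
261 | 
262 | if __name__ == "__main__":
263 |     # Illustrative smoke-test sketch (assumes TensorFlow 2.x eager execution): random binary
264 |     # masks and predictions of shape=(batch_size, height, width, 1); values are arbitrary.
265 |     y_true = tf.cast(tf.random.uniform((2, 8, 8, 1)) > 0.5, tf.float32)
266 |     y_pred = tf.random.uniform((2, 8, 8, 1), minval=K.epsilon(), maxval=1 - K.epsilon())
267 | 
268 |     for name, fn in [("dice", binary_dice_coef_loss()),
269 |                      ("tversky", binary_tversky_loss(beta=0.7)),
270 |                      ("weighted_ce", binary_weighted_cross_entropy(beta=2.)),
271 |                      ("balanced_ce", binary_balanced_cross_entropy(beta=0.8)),
272 |                      ("focal", binary_focal_loss(beta=0.8, gamma=2.))]:
273 |         print(name, fn(y_true, y_pred).numpy())  # One loss value per image in the batch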
--------------------------------------------------------------------------------
/losses/multiclass_losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import backend as K
3 | from tensorflow.keras.activations import softmax
4 | from typing import Callable, Union
5 | import numpy as np
6 | 
7 | 
8 | def multiclass_weighted_tanimoto_loss(class_weights: Union[list, np.ndarray, tf.Tensor]) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
9 |     """
10 |     Weighted Tanimoto loss.
11 | 
12 |     Defined in the paper "ResUNet-a: a deep learning framework for semantic segmentation of remotely sensed data",
13 |     under 3.2.4. Generalization to multiclass imbalanced problems. See https://arxiv.org/pdf/1904.00592.pdf
14 | 
15 |     Used as loss function for multi-class image segmentation with one-hot encoded masks.
16 | 
17 |     :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=n_classes)
18 |     :return: Weighted Tanimoto loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
19 |     """
20 |     if not isinstance(class_weights, tf.Tensor):
21 |         class_weights = tf.constant(class_weights)
22 | 
23 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
24 |         """
25 |         Compute weighted Tanimoto loss.
26 | 
27 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
28 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
29 |         :return: Weighted Tanimoto loss (tf.Tensor, shape=(batch_size,))
30 |         """
31 |         axis_to_reduce = range(1, K.ndim(y_pred))  # All axes but the first (batch)
32 |         numerator = y_true * y_pred * class_weights  # Broadcasting
33 |         numerator = K.sum(numerator, axis=axis_to_reduce)
34 | 
35 |         denominator = (y_true**2 + y_pred**2 - y_true * y_pred) * class_weights  # Broadcasting
36 |         denominator = K.sum(denominator, axis=axis_to_reduce)
37 |         return 1 - numerator / denominator
38 | 
39 |     return loss
40 | 
41 | 
42 | def multiclass_weighted_dice_loss(class_weights: Union[list, np.ndarray, tf.Tensor]) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
43 |     """
44 |     Weighted Dice loss.
45 | 
46 |     Used as loss function for multi-class image segmentation with one-hot encoded masks.
47 | 
48 |     :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=n_classes)
49 |     :return: Weighted Dice loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
50 |     """
51 |     if not isinstance(class_weights, tf.Tensor):
52 |         class_weights = tf.constant(class_weights)
53 | 
54 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
55 |         """
56 |         Compute weighted Dice loss.
57 | 
58 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
59 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
60 |         :return: Weighted Dice loss (tf.Tensor, shape=(batch_size,))
61 |         """
62 |         axis_to_reduce = range(1, K.ndim(y_pred))  # Reduce all axes but the first (batch)
63 |         numerator = y_true * y_pred * class_weights  # Broadcasting
64 |         numerator = 2. * K.sum(numerator, axis=axis_to_reduce)
65 | 
66 |         denominator = (y_true + y_pred) * class_weights  # Broadcasting
67 |         denominator = K.sum(denominator, axis=axis_to_reduce)
68 | 
69 |         return 1 - numerator / denominator
70 | 
71 |     return loss
72 | 
73 | 
74 | def multiclass_weighted_squared_dice_loss(class_weights: Union[list, np.ndarray, tf.Tensor]) -> Callable[[tf.Tensor, tf.Tensor],
75 |                                                                                                          tf.Tensor]:
76 |     """
77 |     Weighted squared Dice loss.
78 | 
79 |     Used as loss function for multi-class image segmentation with one-hot encoded masks.
80 | 
81 |     :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=n_classes)
82 |     :return: Weighted squared Dice loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
83 |     """
84 |     if not isinstance(class_weights, tf.Tensor):
85 |         class_weights = tf.constant(class_weights)
86 | 
87 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
88 |         """
89 |         Compute weighted squared Dice loss.
90 | 
91 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
92 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
93 |         :return: Weighted squared Dice loss (tf.Tensor, shape=(batch_size,))
94 |         """
95 |         axis_to_reduce = range(1, K.ndim(y_pred))  # Reduce all axes but the first (batch)
96 |         numerator = y_true * y_pred * class_weights  # Broadcasting
97 |         numerator = 2. * K.sum(numerator, axis=axis_to_reduce)
98 | 
99 |         denominator = (y_true**2 + y_pred**2) * class_weights  # Broadcasting
100 |         denominator = K.sum(denominator, axis=axis_to_reduce)
101 | 
102 |         return 1 - numerator / denominator
103 | 
104 |     return loss
105 | 
106 | 
107 | def multiclass_weighted_cross_entropy(class_weights: Union[list, np.ndarray, tf.Tensor], is_logits: bool = False) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
108 |     """
109 |     Multi-class weighted cross entropy.
110 | 
111 |         WCE(p, p̂) = −Σ p*log(p̂)*class_weights
112 | 
113 |     Used as loss function for multi-class image segmentation with one-hot encoded masks.
114 | 
115 |     :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=n_classes)
116 |     :param is_logits: If y_pred are logits (bool, default=False)
117 |     :return: Weighted cross entropy loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
118 |     """
119 |     if not isinstance(class_weights, tf.Tensor):
120 |         class_weights = tf.constant(class_weights)
121 | 
122 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
123 |         """
124 |         Computes the weighted cross entropy.
125 | 
126 |         :param y_true: Ground truth (tf.Tensor, shape=(batch_size, height, width, n_classes))
127 |         :param y_pred: Predictions (tf.Tensor, shape=(batch_size, height, width, n_classes))
128 |         :return: Weighted cross entropy (tf.Tensor, shape=(batch_size,))
129 |         """
130 |         assert len(class_weights) == y_pred.shape[-1], f"Number of class_weights ({len(class_weights)}) needs to be the same as number " \
131 |                                                        f"of classes ({y_pred.shape[-1]})"
132 | 
133 |         if is_logits:
134 |             y_pred = softmax(y_pred, axis=-1)
135 | 
136 |         y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())  # To avoid unwanted behaviour in K.log(y_pred)
137 | 
138 |         # p * log(p̂) * class_weights
139 |         wce_loss = y_true * K.log(y_pred) * class_weights
140 | 
141 |         # Average over each data point/image in batch
142 |         axis_to_reduce = range(1, K.ndim(wce_loss))
143 |         wce_loss = K.mean(wce_loss, axis=axis_to_reduce)
144 | 
145 |         return -wce_loss
146 | 
147 |     return loss
148 | 
149 | 
150 | def multiclass_focal_loss(class_weights: Union[list, np.ndarray, tf.Tensor],
151 |                           gamma: Union[list, np.ndarray, tf.Tensor]) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
152 |     """
153 |     Focal loss.
154 | 
155 |         FL(p, p̂) = -∑ class_weights*(1-p̂)ᵞ*p*log(p̂)
156 | 
157 |     Used as loss function for multi-class image segmentation with one-hot encoded masks.
158 | 
159 |     :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=n_classes)
160 |     :param gamma: Focusing parameters, γ_i ≥ 0 (Union[list, np.ndarray, tf.Tensor], len=n_classes)
161 |     :return: Focal loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
162 |     """
163 |     if not isinstance(class_weights, tf.Tensor):
164 |         class_weights = tf.constant(class_weights)
165 |     if not isinstance(gamma, tf.Tensor):
166 |         gamma = tf.constant(gamma)
167 | 
168 |     def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
169 |         """
170 |         Compute focal loss.
171 | 
172 |         :param y_true: True masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
173 |         :param y_pred: Predicted masks (tf.Tensor, shape=(batch_size, height, width, n_classes))
174 |         :return: Focal loss (tf.Tensor, shape=(batch_size,))
175 |         """
176 |         y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())  # To avoid log(0)
177 |         f_loss = -(class_weights * (1 - y_pred)**gamma * y_true * K.log(y_pred))
178 | 
179 |         # Average over each data point/image in batch
180 |         axis_to_reduce = range(1, K.ndim(f_loss))
181 |         f_loss = K.mean(f_loss, axis=axis_to_reduce)
182 | 
183 |         return f_loss
184 | 
185 |     return loss
186 | 
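187 | 
188 | if __name__ == "__main__":
189 |     # Illustrative smoke-test sketch (assumes TensorFlow 2.x eager execution): random one-hot
190 |     # masks and softmax predictions of shape=(batch_size, height, width, n_classes); the class
191 |     # weights and gamma values below are arbitrary example choices.
192 |     n_classes = 3
193 |     y_true = tf.one_hot(tf.random.uniform((2, 8, 8), maxval=n_classes, dtype=tf.int32), depth=n_classes)
194 |     y_pred = tf.nn.softmax(tf.random.normal((2, 8, 8, n_classes)), axis=-1)
195 |     class_weights = [1.0, 2.0, 0.5]
196 | 
197 |     for name, fn in [("tanimoto", multiclass_weighted_tanimoto_loss(class_weights)),
198 |                      ("dice", multiclass_weighted_dice_loss(class_weights)),
199 |                      ("squared_dice", multiclass_weighted_squared_dice_loss(class_weights)),
200 |                      ("weighted_ce", multiclass_weighted_cross_entropy(class_weights)),
201 |                      ("focal", multiclass_focal_loss(class_weights, gamma=[2.0, 2.0, 2.0]))]:
202 |         print(name, fn(y_true, y_pred).numpy())  # One loss value per image in the batch
--------------------------------------------------------------------------------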