├── src
│   ├── test
│   │   ├── __init__.py
│   │   └── test_focal_loss.py
│   └── loss_function
│       ├── __init__.py
│       └── losses.py
├── focal_loss.png
├── Makefile
├── requirements.txt
├── .github
│   └── workflows
│       └── python-app.yml
└── README.md
/src/test/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/loss_function/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/focal_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/umbertogriffo/focal-loss-keras/HEAD/focal_loss.png
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: check clean setup-env export-env test

all: check setup-env test

check:
	which pip3
	which python3

clean:
	rm -rf .pyenv/
	rm -rf .pytest_cache/

setup-env:
	virtualenv .pyenv; \
	. .pyenv/bin/activate; \
	pip3 install -r requirements.txt

export-env:
	. .pyenv/bin/activate; \
	pip3 freeze > requirements.txt

test:
	. .pyenv/bin/activate && cd src/test; \
	pytest -s -vv;
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.9.0
astor==0.8.1
attrs==19.3.0
dill==0.3.2
gast==0.3.3
google-pasta==0.2.0
grpcio==1.30.0
h5py==2.10.0
importlib-metadata==1.7.0
Keras==2.4.3
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
Markdown==3.2.2
more-itertools==8.4.0
numpy==1.18.1
opt-einsum==3.2.1
packaging==20.4
pluggy==0.13.1
protobuf==3.12.2
py==1.9.0
pyparsing==2.4.7
pytest==5.4.1
PyYAML==5.3.1
scipy==1.4.1
six==1.15.0
tensorboard==2.2.2
tensorflow==2.3.1
tensorflow-estimator==2.2.0
termcolor==1.1.0
tqdm==4.42.1
wcwidth==0.2.5
Werkzeug==1.0.1
wrapt==1.12.1
zipp==3.1.0
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: CI

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.8
      uses: actions/setup-python@v2
      with:
        python-version: 3.8
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Focal Loss

![CI](https://github.com/umbertogriffo/focal-loss-keras/workflows/CI/badge.svg)

The [focal loss](https://arxiv.org/abs/1708.02002) down-weights well-classified examples, which has the net effect of putting more training emphasis on the data that is hard to classify. For the binary case, `FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)`, where `p_t` is the predicted probability of the true class: the modulating factor `(1 - p_t)^gamma` shrinks the loss contribution of confident, correct predictions. In a practical setting where we have a data imbalance, our majority class will quickly become well-classified, since we have much more data for it. Thus, in order to ensure that we also achieve high accuracy on our minority class, we can use the focal loss to give those minority-class examples more relative weight during training.

![](https://github.com/umbertogriffo/focal-loss-keras/blob/master/focal_loss.png)

The focal loss can easily be implemented in Keras as a custom loss function.

## Usage

Compile your model with focal loss as follows:

**Binary**
>model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)

**Categorical**
>model.compile(loss=[categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2)], metrics=["accuracy"], optimizer=adam)

Alpha is used to specify the weight of the different categories/labels; the size of the array needs to be consistent with the number of classes.
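As a more complete illustration, here is a minimal end-to-end sketch. The toy model, the random data, and the assumption that `src/` is on the `PYTHONPATH` (so that `loss_function.losses` is importable) are ours, for illustration only:

``` python
import numpy as np
import tensorflow as tf

# Assumes `src/` is on the PYTHONPATH so the repo's losses are importable.
from loss_function.losses import categorical_focal_loss

# Toy 3-class classifier; any Keras model with a softmax output works the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(
    loss=[categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2)],
    metrics=["accuracy"],
    optimizer="adam",
)

# Dummy data: 32 samples with 4 features each, one-hot labels over 3 classes.
x = np.random.rand(32, 4).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(3, size=32), num_classes=3)
model.fit(x, y, epochs=1, verbose=0)
```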
**Convert a trained Keras model into an inference TensorFlow model**

If you use [@amir-abdi's code](https://github.com/amir-abdi/keras_to_tensorflow) to convert a trained Keras model into an inference TensorFlow model, you have to serialize the nested loss functions. In order to serialize them, install dill in your Anaconda environment as follows:

>conda install -c anaconda dill

Then modify **keras_to_tensorflow.py** by adding this piece of code after the imports:

``` python
import dill

# Note: this assumes binary_focal_loss and categorical_focal_loss are importable
# at this point (e.g. copied into keras_to_tensorflow.py or imported from this repo's losses.py).
custom_object = {'binary_focal_loss_fixed': dill.loads(dill.dumps(binary_focal_loss(gamma=2., alpha=.25))),
                 'categorical_focal_loss_fixed': dill.loads(dill.dumps(categorical_focal_loss(gamma=2., alpha=[[.25, .25, .25]]))),
                 'categorical_focal_loss': categorical_focal_loss,
                 'binary_focal_loss': binary_focal_loss}
```

and modify the beginning of the **load_model** method as follows:

``` python
if not Path(input_model_path).exists():
    raise FileNotFoundError(
        'Model file `{}` does not exist.'.format(input_model_path))
try:
    model = keras.models.load_model(input_model_path, custom_objects=custom_object)
    return model
```
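Note that dill is only needed to serialize the nested function for the converter; to simply reload a trained model in Keras, passing the inner loss via `custom_objects` is enough. A minimal sketch, assuming a model trained with the binary focal loss was saved to the hypothetical path `model.h5` (the `gamma`/`alpha` values must match the ones used at training time):

``` python
from tensorflow.keras.models import load_model

from loss_function.losses import binary_focal_loss

# `model.h5` is a placeholder path. Keras stored the loss under the name of
# the inner function, `binary_focal_loss_fixed`, so that is the key to map.
model = load_model(
    "model.h5",
    custom_objects={"binary_focal_loss_fixed": binary_focal_loss(gamma=2., alpha=.25)},
)
```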
## Who is citing this work?

* In the open-source Python library [AUCMEDI](https://frankkramer-lab.github.io/aucmedi/), a high-level API that allows fast setup of medical image classification pipelines with state-of-the-art methods in just a few lines of code.
  * [Loss Functions](https://frankkramer-lab.github.io/aucmedi/reference/neural_network/loss_functions/)

## References
* The binary implementation is based on [@mkocabas's code](https://github.com/mkocabas/focal-loss-keras)
--------------------------------------------------------------------------------
/src/loss_function/losses.py:
--------------------------------------------------------------------------------
"""
Define our custom loss functions.
"""
import numpy as np
from keras import backend as K
import tensorflow as tf

import dill


def binary_focal_loss(gamma=2., alpha=.25):
    """
    Binary form of focal loss.

        FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)

    where p = sigmoid(x) and p_t = p or 1 - p depending on whether the label is 1 or 0, respectively.

    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
        model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        y_true = tf.cast(y_true, tf.float32)
        # Define epsilon so that back-propagation will not result in NaN for a 0 divisor
        epsilon = K.epsilon()
        # Clip the prediction value to prevent NaN's and Inf's
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)
        # Calculate alpha_t
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
        # Calculate cross entropy
        cross_entropy = -K.log(p_t)
        weight = alpha_t * K.pow((1 - p_t), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum over the last axis, then average over the mini-batch
        loss = K.mean(K.sum(loss, axis=1))
        return loss

    return binary_focal_loss_fixed


def categorical_focal_loss(alpha, gamma=2.):
    """
    Softmax version of focal loss.
    When there is a skew between different categories/labels in your data set, you can try to apply this function as a
    loss.

              m
        FL =  ∑  -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
             c=1

    where m = number of classes, c = class and o = observation.

    Parameters:
        alpha -- the same as the weighting factor in balanced cross entropy. Alpha is used to specify the weight of
                 different categories/labels; the size of the array needs to be consistent with the number of classes.
        gamma -- focusing parameter for the modulating factor (1 - p)

    Default values (as mentioned in the paper):
        gamma -- 2.0
        alpha -- 0.25 (note that alpha has no default here and must be passed explicitly)

    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy

    Usage:
        model.compile(loss=[categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2)], metrics=["accuracy"], optimizer=adam)
    """

    alpha = np.array(alpha, dtype=np.float32)

    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a softmax
        :return: Output tensor.
        """

        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate cross entropy
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate focal loss
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

        # Compute the mean loss in the mini-batch
        return K.mean(K.sum(loss, axis=-1))

    return categorical_focal_loss_fixed


if __name__ == '__main__':

    # Test serialization of nested functions
    bin_inner = dill.loads(dill.dumps(binary_focal_loss(gamma=2., alpha=.25)))
    print(bin_inner)

    cat_inner = dill.loads(dill.dumps(categorical_focal_loss(gamma=2., alpha=.25)))
    print(cat_inner)
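    # Quick numeric sanity check (our addition, a hedged illustration):
    # with alpha = 1 and gamma = 0 the categorical focal loss reduces to plain
    # categorical cross-entropy, so for a single one-hot example the value
    # should be -log(p_true) = -log(0.90) ≈ 0.105.
    y_true = np.array([[0., 1., 0.]], dtype=np.float32)
    y_pred = np.array([[0.05, 0.90, 0.05]], dtype=np.float32)
    cfl = categorical_focal_loss(alpha=[[1., 1., 1.]], gamma=0.)
    print(cfl(y_true, y_pred).numpy())  # expected ≈ 0.10536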
--------------------------------------------------------------------------------
/src/test/test_focal_loss.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from keras import backend as K

import unittest

from loss_function.losses import categorical_focal_loss, binary_focal_loss


class TestFocalLoss(unittest.TestCase):

    def test_is_equal_to_binary_cross_entropy(self):
        """When alpha is equal to 1 and gamma is equal to 0, the focal loss must be equal to
        the binary crossentropy loss with `sample_weight` = [1, 0]."""
        y_true = np.array([[0., 1.], [0., 0.]])
        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]], dtype=np.float32)

        print("binary_cross_entropy")
        bce = tf.keras.losses.BinaryCrossentropy()
        bce_value = K.mean(bce(y_true, y_pred, sample_weight=[1, 0])).numpy()
        print(bce_value)

        print("focal_loss")
        bfl = binary_focal_loss(alpha=1, gamma=0.)
        bfl_value = K.mean(bfl(y_true, y_pred)).numpy()
        print(bfl_value)

        self.assertAlmostEqual(bce_value, bfl_value, places=4)

    def test_is_equal_to_categorical_cross_entropy_pixel_based(self):
        """When alpha is equal to 1 and gamma is equal to 0, the focal loss of a pixel-based batch must be equal to
        the categorical crossentropy loss."""
        y_true = np.array([[0, 1, 0], [0, 0, 1]])
        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]], dtype=np.float32)

        print("Pixel-based labelling")
        print("Data dimension as [batch_size (amount of pixels), one_hot_encoding of pixel label]")
        print(y_true.shape)

        print("categorical_cross_entropy")
        cce = tf.keras.losses.categorical_crossentropy
        cce_value = K.mean(cce(y_true, y_pred)).numpy()
        print(cce_value)

        print("focal_loss")
        cfl = categorical_focal_loss(alpha=[[1, 1, 1]], gamma=0.)
        cfl_value = cfl(y_true, y_pred).numpy()
        print(cfl_value)

        self.assertEqual(cce_value, cfl_value)

    def test_is_equal_to_categorical_cross_entropy_image_based(self):
        """When alpha is equal to 1 and gamma is equal to 0, the focal loss of an image-based batch must be equal to
        the categorical crossentropy loss."""
        y_true = np.array([[[[1, 0, 0, 0], [0, 1, 0, 0]], [[0, 0, 0, 1], [0, 0, 1, 0]]],
                           [[[0, 1, 0, 0], [0, 1, 0, 0]], [[1, 0, 0, 0], [0, 0, 0, 1]]]])

        y_pred = np.array(
            [[[[0.8, 0.0, 0.2, 0.0], [0.0, 0.95, 0.0, 0.05]], [[0.1, 0.2, 0.3, 0.4], [0.5, 0.0, 0.5, 0.0]]],
             [[[0.0, 0.6, 0.0, 0.4], [0.1, 0.80, 0.1, 0.00]], [[0.7, 0.0, 0.3, 0.0], [0.2, 0.0, 0.3, 0.5]]]],
            dtype=np.float32)

        print("Image-based labelling")
        print("Data dimension as [batch_size (amount of images), height, width, one_hot_encoding of pixel label]")
        print(y_true.shape)

        print("categorical_cross_entropy")
        cce = tf.keras.losses.categorical_crossentropy
        cce_value = K.mean(cce(y_true, y_pred)).numpy()
        print(cce_value)

        print("focal_loss")
        cfl = categorical_focal_loss(alpha=[[1, 1, 1, 1]], gamma=0.)
        cfl_value = cfl(y_true, y_pred).numpy()
        print(cfl_value)

        self.assertEqual(cce_value, cfl_value)

    def test_focal_loss_effectiveness_of_balancing(self):
        """Verify the effectiveness of α in balancing the weights between categories."""

        y_true = np.array([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0]])
        y_pred = np.array([[0.3, 0.99, 0.8, 0.97, 0.85], [0.9, 0.05, 0.1, 0.09, 0]], dtype=np.float32)

        # Suppose we are dealing with a multi-class prediction problem with five outputs.
        # In the example above, the model predicts the first label poorly compared to the other labels.
        cfl_balanced = categorical_focal_loss(alpha=[[1, 1, 1, 1, 1]], gamma=0.)
        cfl_balanced_value = cfl_balanced(y_true, y_pred).numpy()
        print(cfl_balanced_value)

        # We use α to adjust the weight of the first label, modifying α to [[2, 1, 1, 1, 1]].
        # The loss increases, indicating that the loss function has successfully enlarged the weight of
        # the first category, which makes the model pay more attention to correctly predicting the first label.
        cfl_unbalanced = categorical_focal_loss(alpha=[[2, 1, 1, 1, 1]], gamma=0.)
        cfl_unbalanced_value = cfl_unbalanced(y_true, y_pred).numpy()
        print(cfl_unbalanced_value)

        self.assertGreater(cfl_unbalanced_value, cfl_balanced_value)
--------------------------------------------------------------------------------