├── .gitignore ├── LICENSE ├── README.md ├── focal_loss.py ├── hubconf.py └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | 4 | .vscode/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Adeel Hassan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/292520399.svg)](https://zenodo.org/badge/latestdoi/292520399) 2 | 3 | # Multi-class Focal Loss 4 | 5 | An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case. 6 | 7 | It is essentially an enhancement to cross-entropy loss and is useful for classification tasks when there is a large class imbalance. It has the effect of underweighting easy examples. 8 | 9 | # Usage 10 | - `FocalLoss` is an `nn.Module` and behaves very much like `nn.CrossEntropyLoss()` i.e. 11 | - supports the `reduction` and `ignore_index` params, and 12 | - is able to work with 2D inputs of shape `(N, C)` as well as K-dimensional inputs of shape `(N, C, d1, d2, ..., dK)`. 13 | 14 | - Example usage 15 | ```python3 16 | focal_loss = FocalLoss(alpha, gamma) 17 | ... 18 | inp, targets = batch 19 | out = model(inp) 20 | loss = focal_loss(out, targets) 21 | ``` 22 | 23 | # Loading through torch.hub 24 | This repo supports importing modules through `torch.hub`. 
from typing import Optional, Sequence

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F


class FocalLoss(nn.Module):
    """Multi-class Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Elements whose label equals ``ignore_index`` are dropped before the loss
    is computed, so:
        - 'mean' averages over the non-ignored elements only,
        - 'none' returns a 1-D tensor with one entry per non-ignored element
          (in flattened order), not a tensor shaped like y.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.

        Raises:
            ValueError: if reduction is not 'mean', 'sum' or 'none'.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # The class weights are applied by NLLLoss; reduction is done by hand
        # in forward() because the focal term must be applied per element.
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_strs = [f'{k}={getattr(self, k)!r}' for k in arg_keys]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Compute the focal loss of scores x against labels y.

        Args:
            x (Tensor): raw class scores, (N, C) or (N, C, d1, ..., dK).
            y (Tensor): class labels, (N,) or (N, d1, ..., dK).

        Returns:
            A scalar for 'mean' and 'sum'; for 'none', a 1-D tensor with one
            loss value per non-ignored element. If every label is ignored,
            a scalar zero with the dtype and device of x is returned for
            every reduction mode.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (not view) so non-contiguous label tensors work too
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        x = x[unignored_mask]
        y = y[unignored_mask]
        if y.numel() == 0:
            # Nothing left to score. Summing the empty selection yields a
            # zero that matches x's dtype/device and is still attached to the
            # autograd graph, so loss.backward() keeps working.
            return x.sum()

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # pick the true-class log-probability from each row; gather runs on
        # whatever device log_p lives on, no auxiliary index tensor needed
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss


def focal_loss(alpha: Optional[Sequence] = None,
               gamma: float = 0.,
               reduction: str = 'mean',
               ignore_index: int = -100,
               device='cpu',
               dtype=torch.float32) -> FocalLoss:
    """Factory function for FocalLoss.

    Args:
        alpha (Sequence, optional): Weights for each class. Will be converted
            to a Tensor if not None. Defaults to None.
        gamma (float, optional): A constant, as described in the paper.
            Defaults to 0.
        reduction (str, optional): 'mean', 'sum' or 'none'.
            Defaults to 'mean'.
        ignore_index (int, optional): class label to ignore.
            Defaults to -100.
        device (str, optional): Device to move alpha to. Defaults to 'cpu'.
        dtype (torch.dtype, optional): dtype to cast alpha to.
            Defaults to torch.float32.

    Returns:
        A FocalLoss object
    """
    if alpha is not None:
        if not isinstance(alpha, Tensor):
            alpha = torch.tensor(alpha)
        alpha = alpha.to(device=device, dtype=dtype)

    return FocalLoss(
        alpha=alpha,
        gamma=gamma,
        reduction=reduction,
        ignore_index=ignore_index)