"""Multi-class focal loss for PyTorch."""

import torch


class FocalLoss:
    """Multi-class focal loss computed from raw (unnormalised) logits.

    For every element the loss is ``alpha_t * (1 - p_t) ** gamma * CE``,
    where ``CE`` is the cross entropy of that element and
    ``p_t = exp(-CE)`` is the probability the model assigns to the target
    class.

    NOTE: when class weights are given, the ``gamma == 0`` path delegates to
    ``torch.nn.functional.cross_entropy`` with its default reduction, while
    the ``gamma != 0`` path takes a plain ``.mean()`` over the per-element
    losses. The two paths therefore normalise differently; this is kept
    as-is so existing callers get the same numbers as before.
    """

    def __init__(self, alpha_t=None, gamma=0):
        """
        :param alpha_t: Optional per-class weights (list, tuple or tensor,
            one entry per class). ``None`` or an empty sequence disables
            class weighting.
        :param gamma: Focusing exponent applied to ``(1 - p_t)``. ``0``
            disables the focal modulation (plain cross entropy).
        """
        if alpha_t is None:
            self.alpha_t = None
        else:
            # Compare against None instead of using truthiness: the truth
            # value of a multi-element tensor is ambiguous and raises.
            # Forcing a float dtype keeps integer weights such as [1, 2]
            # usable as a cross-entropy weight.
            weights = torch.as_tensor(alpha_t, dtype=torch.float32)
            # An empty sequence has always meant "no class weights".
            self.alpha_t = weights if weights.numel() > 0 else None
        self.gamma = gamma

    def _weights_for(self, outputs):
        """Return the class weights on ``outputs``' device and dtype, or None."""
        if self.alpha_t is None:
            return None
        if self.alpha_t.device != outputs.device:
            # Cache the device move so it happens only once per device.
            self.alpha_t = self.alpha_t.to(outputs.device)
        # Match the dtype on every call (a no-op when it already matches);
        # the stored copy keeps its own precision.
        return self.alpha_t.to(outputs.dtype)

    def __call__(self, outputs, targets):
        """Compute the focal loss.

        :param outputs: Logits, passed unchanged to
            ``torch.nn.functional.cross_entropy`` as its input.
        :param targets: Targets, passed unchanged to
            ``torch.nn.functional.cross_entropy`` as its target.
        :return: Scalar loss tensor.
        """
        cross_entropy = torch.nn.functional.cross_entropy
        alpha = self._weights_for(outputs)

        if self.gamma == 0:
            # No focal modulation: plain (optionally class-weighted) CE.
            return cross_entropy(outputs, targets, weight=alpha)

        # p_t must come from the *unweighted* cross entropy, otherwise the
        # class weights would distort the probability estimate.
        ce_loss = cross_entropy(outputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)
        if alpha is not None:
            ce_loss = cross_entropy(outputs, targets,
                                    weight=alpha, reduction='none')
        return ((1 - p_t) ** self.gamma * ce_loss).mean()  # mean over the batch


if __name__ == '__main__':
    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    outputs = torch.tensor([[2, 1.],
                            [2.5, 1]], device=device)
    targets = torch.tensor([0, 1], device=device)
    print(torch.nn.functional.softmax(outputs, dim=1))

    fl = FocalLoss([0.5, 0.5], 2)

    print(fl(outputs, targets))
-------------------------------------------------------------------------------- 1 | # pytorch-multi-class-focal-loss 2 | A simple PyTorch implementation of multi-class focal loss. 3 | 4 | Warning: I have been told that this implementation may be wrong, so it is better not to use it directly. 5 | --------------------------------------------------------------------------------