├── images └── focal_loss.png ├── README.md ├── LICENSE ├── focalloss_test.py ├── focalloss.py └── .gitignore /images/focal_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clcarwin/focal_loss_pytorch/HEAD/images/focal_loss.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Focal Loss for Dense Object Detection in PyTorch 2 | ![focal loss](images/focal_loss.png) 3 | [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf) 4 | 5 | # Result 6 | Method | training set | val set | mAP 7 | --- |--- |--- |--- 8 | Cross Entropy Loss |VOC2007 | VOC2007 | 63.36 9 | Focal Loss |VOC2007 | VOC2007 | 65.26 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 carwin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE
# 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# 21 | SOFTWARE.
# 22 |
# --------------------------------------------------------------------------------
# /focalloss_test.py:
# --------------------------------------------------------------------------------
"""Numerical check that ``FocalLoss(gamma=0)`` reproduces PyTorch's cross-entropy.

Run as a script (``python focalloss_test.py``). Uses CUDA when available and
falls back to CPU otherwise.
"""
import argparse
import os
import random
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# NOTE: argparse, os, sys, optim and Variable are unused; kept from the original
# import list.


def _device():
    """Return the CUDA device when available, else the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def max_abs_error(loss_a, loss_b, make_batch, iters):
    """Return the largest ``|loss_a(x, t) - loss_b(x, t)|`` over ``iters`` batches.

    Args:
        loss_a: callable ``(input, target) -> scalar tensor``.
        loss_b: callable ``(input, target) -> scalar tensor``.
        make_batch: zero-argument callable returning a fresh ``(input, target)``.
        iters: number of batches to compare.

    Returns:
        The maximum absolute difference as a Python float (0.0 when ``iters`` is 0).
    """
    worst = 0.0
    for _ in range(iters):
        x, t = make_batch()
        # .item() replaces the old `.data[0]`, which raises on 0-dim tensors.
        worst = max(worst, abs(loss_a(x, t).item() - loss_b(x, t).item()))
    return worst


def _report(iters, loss_a, loss_b, make_batch):
    """Time ``max_abs_error`` over ``iters`` batches and print the result."""
    start = time.time()
    maxe = max_abs_error(loss_a, loss_b, make_batch, iters)
    print('time:', time.time() - start, 'max_error:', maxe)


def main():
    """Compare FocalLoss(gamma=0) with CrossEntropyLoss (2-D) and NLLLoss (4-D)."""
    # Local import keeps this module importable without focalloss on the path.
    from focalloss import FocalLoss

    device = _device()

    def flat_batch():
        # 12800 two-class samples; targets are 1 where rand >= 0.1, else 0.
        x = torch.rand(12800, 2, device=device) * random.randint(1, 10)
        t = torch.rand(12800, device=device).ge(0.1).long()
        return x, t

    def dense_batch():
        # (N, C, H, W) logits with 1000 classes; integer targets in [0, 1000).
        x = torch.rand(128, 1000, 8, 4, device=device) * random.randint(1, 10)
        t = (torch.rand(128, 8, 4, device=device) * 1000).long()
        return x, t

    _report(1000, FocalLoss(gamma=0), nn.CrossEntropyLoss(), flat_batch)

    # nn.NLLLoss accepts (N, C, d1, d2, ...) directly; NLLLoss2d is deprecated.
    nll = nn.NLLLoss()
    _report(100, FocalLoss(gamma=0),
            lambda x, t: nll(F.log_softmax(x, dim=1), t), dense_batch)


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /focalloss.py:
# --------------------------------------------------------------------------------
# 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # no longer used; kept for `from focalloss import *` users


class FocalLoss(nn.Module):
    """Softmax focal loss from "Focal Loss for Dense Object Detection" (Lin et al., 2017).

    For a sample whose target class has predicted probability ``p_t``::

        FL = -alpha_t * (1 - p_t) ** gamma * log(p_t)

    With ``gamma=0`` and ``alpha=None`` this reduces to ordinary cross-entropy.

    Args:
        gamma: focusing exponent applied to ``(1 - p_t)``.
        alpha: optional per-class weights. A float/int ``a`` is expanded to
            ``[a, 1 - a]`` (two classes only); a list or tuple gives one weight per
            class; a tensor is used as given; ``None`` disables weighting.
        size_average: if True (default) return the mean over all samples,
            otherwise the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # The old check named `long`, which does not exist on Python 3 and raised
        # NameError on every construction.
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss reduced by mean (``size_average``) or sum.

        Args:
            input: logits, ``(N, C)`` or ``(N, C, d1, d2, ...)``.
            target: int64 class indices, ``(N,)`` or ``(N, d1, d2, ...)``.

        Note: gradients flow through ``p_t``. For ``0 < gamma < 1``, a sample
        with ``p_t`` exactly 1 can yield a non-finite gradient.
        """
        if input.dim() > 2:
            num_classes = input.size(1)
            # (N, C, d1, ...) -> (N, C, D) -> (N, D, C) -> (N*D, C).
            # reshape (not view) also accepts non-contiguous inputs.
            input = input.reshape(input.size(0), num_classes, -1)
            input = input.transpose(1, 2).reshape(-1, num_classes)
        target = target.reshape(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # pt stays in the autograd graph so the d(1 - pt)**gamma term of the
        # focal-loss gradient is included (the old code detached it via .data).
        pt = logpt.exp()

        if self.alpha is not None:
            if self.alpha.device != input.device or self.alpha.dtype != input.dtype:
                # Cache the converted weights so the transfer happens once.
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.reshape(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

# --------------------------------------------------------------------------------
# /.gitignore:
# --------------------------------------------------------------------------------
# 1 | # Byte-compiled / optimized / DLL files
# 2 | __pycache__/
# 3 | *.py[cod]
# 4 | *$py.class
# 5 |
# 6 | # C extensions
# 7 | *.so
# 8 |
# 9 | # Distribution / packaging
# 10 | .Python
# 11 | env/
# 12 | build/
# 13 | develop-eggs/
# 14 | dist/
# 15 | downloads/
# 16 | eggs/
# 17 | .eggs/
# 18 | lib/
# 19 | lib64/
# 20 | parts/
# 21 | sdist/
# 22 | var/
# 23 | wheels/
# 24 | *.egg-info/
# 25 | .installed.cfg
# 26 | *.egg
# 27 |
# 28 | # PyInstaller
# 29 | # Usually these files are written by a python script from a template
# 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | --------------------------------------------------------------------------------