├── README.md
└── focal_loss.py

/README.md:
--------------------------------------------------------------------------------
# focal-loss

TensorFlow implementation of Kaiming He's Focal Loss, a loss function designed to address class imbalance in classification problems.

focal_loss_sigmoid: binary classification loss

focal_loss_softmax: multi-class classification loss

Reference Paper: Focal Loss for Dense Object Detection

--------------------------------------------------------------------------------
/focal_loss.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import tensorflow as tf

"""
TensorFlow implementation of Kaiming He's Focal Loss, a loss function designed to
address class imbalance in classification problems.
focal_loss_sigmoid: binary classification loss
focal_loss_softmax: multi-class classification loss
Reference Paper: Focal Loss for Dense Object Detection
"""

def focal_loss_sigmoid(labels, logits, alpha=0.25, gamma=2):
    """
    Compute focal loss for binary classification.
    Args:
      labels: An int32 tensor of shape [batch_size].
      logits: A float32 tensor of shape [batch_size].
      alpha: A scalar for the focal loss alpha hyper-parameter. If the number of
        positive samples > the number of negative samples, use alpha < 0.5, and vice versa.
      gamma: A scalar for the focal loss gamma hyper-parameter.
    Returns:
      A tensor of the same shape as `labels`.
    """
    y_pred = tf.nn.sigmoid(logits)
    labels = tf.cast(labels, tf.float32)
    # Clip predictions to avoid log(0).
    eps = 1e-8
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    # Following the reference paper: alpha weights the positive class,
    # (1 - alpha) weights the negative class.
    L = -labels * alpha * ((1 - y_pred) ** gamma) * tf.log(y_pred) \
        - (1 - labels) * (1 - alpha) * (y_pred ** gamma) * tf.log(1 - y_pred)
    return L

def focal_loss_softmax(labels, logits, gamma=2):
    """
    Compute focal loss for multi-class classification.
    Args:
      labels: An int32 tensor of shape [batch_size].
      logits: A float32 tensor of shape [batch_size, num_classes].
      gamma: A scalar for the focal loss gamma hyper-parameter.
    Returns:
      A tensor of the same shape as `labels`.
    """
    y_pred = tf.nn.softmax(logits, axis=-1)  # [batch_size, num_classes]
    labels = tf.one_hot(labels, depth=y_pred.shape[1])
    # Clip predictions to avoid log(0).
    eps = 1e-8
    y_pred = tf.clip_by_value(y_pred, eps, 1.0)
    L = -labels * ((1 - y_pred) ** gamma) * tf.log(y_pred)
    L = tf.reduce_sum(L, axis=1)  # [batch_size]
    return L

if __name__ == '__main__':
    logits = tf.random_uniform(shape=[5], minval=-1, maxval=1, dtype=tf.float32)
    labels = tf.Variable([0, 1, 0, 0, 1])
    loss1 = focal_loss_sigmoid(labels=labels, logits=logits)

    logits2 = tf.random_uniform(shape=[5, 4], minval=-1, maxval=1, dtype=tf.float32)
    labels2 = tf.Variable([1, 0, 2, 3, 1])
    loss2 = focal_loss_softmax(labels=labels2, logits=logits2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(loss1))
        print(sess.run(loss2))
--------------------------------------------------------------------------------
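
Both functions return a per-example loss, so for training you typically reduce it to a scalar before handing it to an optimizer. Below is a minimal TF 1.x usage sketch, not taken from the repository: the `features` placeholder, the `tf.layers.dense` classifier, and all hyper-parameters are illustrative assumptions; only `focal_loss_softmax` comes from focal_loss.py above.

```python
# -*- coding: utf-8 -*-
# Hypothetical training sketch (TF 1.x graph mode) using focal_loss_softmax.
import numpy as np
import tensorflow as tf

from focal_loss import focal_loss_softmax

num_classes = 4
features = tf.placeholder(tf.float32, shape=[None, 8])  # assumed input features
labels = tf.placeholder(tf.int32, shape=[None])         # integer class ids
logits = tf.layers.dense(features, num_classes)         # simple linear classifier

per_example_loss = focal_loss_softmax(labels=labels, logits=logits)  # [batch_size]
loss = tf.reduce_mean(per_example_loss)                 # scalar loss for training
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Random toy data, only to show the feed/run pattern.
    x = np.random.randn(32, 8).astype(np.float32)
    y = np.random.randint(0, num_classes, size=32)
    for _ in range(10):
        _, loss_val = sess.run([train_op, loss], feed_dict={features: x, labels: y})
    print(loss_val)
```

`focal_loss_sigmoid` is used the same way, with logits of shape [batch_size] and 0/1 labels.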