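├── README.md
└── multilabelceloss.py

/README.md:
--------------------------------------------------------------------------------
# MultilabelCrossEntropyLoss-Pytorch
Multilabel categorical cross-entropy

This is a PyTorch implementation of the multilabel categorical cross-entropy loss, adapted from the Keras version described here:

苏剑林 (Su Jianlin). (2020, Apr 25). 将"softmax+交叉熵"推广到多标签分类问题 [Extending "softmax + cross-entropy" to multi-label classification] [Blog post]. Retrieved from https://kexue.fm/archives/7359
--------------------------------------------------------------------------------
/multilabelceloss.py:
--------------------------------------------------------------------------------
import torch


def multilabel_categorical_crossentropy(y_true, y_pred):
    """Cross-entropy for multi-label classification.

    y_true and y_pred have the same shape. Each element of y_true is 0 or 1:
    1 marks the corresponding class as a target class, 0 as a non-target class.
    y_pred holds raw scores (logits); do not apply sigmoid or softmax first.
    At prediction time, classes with y_pred > 0 are taken as targets.
    Returns the per-sample loss with shape y_pred.shape[:-1]; reduce it
    (e.g. with .mean()) before calling backward().
    """
    # Negate target-class scores: -s_i for targets, s_j for non-targets.
    y_pred = (1 - 2 * y_true) * y_pred
    # Push target classes to about -1e12 so only non-target scores remain.
    y_pred_neg = y_pred - y_true * 1e12
    # Push non-target classes to about -1e12 so only (negated) target scores remain.
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    # Append a zero logit to each row; exp(0) = 1 supplies the "1 +" inside each log.
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    # log(1 + sum over non-targets of e^{s_j}) and log(1 + sum over targets of e^{-s_i})
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return neg_loss + pos_loss
--------------------------------------------------------------------------------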
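For a single sample with raw scores s, target classes Ω_pos and non-target classes Ω_neg, the masked `logsumexp` computation in `multilabel_categorical_crossentropy` amounts to log(1 + Σ_{j∈Ω_neg} e^{s_j}) + log(1 + Σ_{i∈Ω_pos} e^{-s_i}). The sketch below is a hypothetical, loop-based reference (`multilabel_ce_reference` is not part of this repo) that spells that formula out term by term, assuming the same 0/1 `y_true` and raw-score `y_pred` conventions. It is meant for checking small inputs, not for training.

```python
import math

import torch


def multilabel_ce_reference(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference for the multilabel categorical cross-entropy.

    Per sample: log(1 + sum_{j in neg} exp(s_j)) + log(1 + sum_{i in pos} exp(-s_i)).
    Plain math.exp can overflow for large scores, so use this only for small checks.
    """
    num_classes = y_pred.shape[-1]
    losses = []
    # Flatten all leading dimensions so each row is one sample.
    for labels, scores in zip(y_true.reshape(-1, num_classes).tolist(),
                              y_pred.reshape(-1, num_classes).tolist()):
        pos = [s for lab, s in zip(labels, scores) if lab == 1]  # target scores
        neg = [s for lab, s in zip(labels, scores) if lab == 0]  # non-target scores
        neg_term = math.log(1 + sum(math.exp(s) for s in neg))
        pos_term = math.log(1 + sum(math.exp(-s) for s in pos))
        losses.append(neg_term + pos_term)
    # Restore the leading (batch) shape, dropping the class dimension.
    return torch.tensor(losses).reshape(y_pred.shape[:-1])


if __name__ == "__main__":
    y_true = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
    y_pred = torch.tensor([[2.0, -1.0, 0.5], [0.0, 3.0, -0.5]])
    print(multilabel_ce_reference(y_true, y_pred))
    # Decision rule from the linked post: a class is predicted when its raw score is > 0.
    print(y_pred > 0)
```

Passing the same tensors to `multilabel_categorical_crossentropy` should give matching per-sample values up to float32 rounding, since the 1e12 masking drives the masked exponentials to zero. A sample with no target classes simply gets a positive term of log(1) = 0, which is why the appended zero logit also lets the loss act as a threshold at 0 between target and non-target scores.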