import torch
from torch import nn
import torch.nn.functional as F


class GHM_Loss(nn.Module):
    """Gradient Harmonizing Mechanism loss, base class (Li et al., AAAI 2019).

    Re-weights each example by the inverse of the local gradient density, so
    the flood of easy examples (tiny gradients) and extreme outliers (huge
    gradients) do not dominate training.

    Args:
        bins:  number of histogram bins used to approximate gradient density.
        alpha: EMA momentum for smoothing per-bin counts across forward
               calls (0 disables smoothing).
    """

    def __init__(self, bins, alpha):
        super(GHM_Loss, self).__init__()
        self._bins = bins
        self._alpha = alpha
        # Exponential moving average of per-bin counts from prior batches.
        self._last_bin_count = None

    def _g2bin(self, g):
        # Map gradient norm g in [0, 1] to a bin index in [0, bins - 1].
        # The small offset keeps g == 1.0 inside the last bin instead of
        # overflowing to index `bins`.
        return torch.floor(g * (self._bins - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        """Per-subclass weighted loss. Must be overridden."""
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        """Per-subclass gradient norm of the loss w.r.t. the logits."""
        raise NotImplementedError

    def forward(self, x, target):
        """Compute the density-harmonized loss.

        Args:
            x:      predictions (logits for GHMC, regression output for GHMR).
            target: ground truth, same shape as ``x``.

        Returns:
            Scalar loss tensor.
        """
        # Gradient norms are detached: density estimation must not
        # participate in autograd.
        g = torch.abs(self._custom_loss_grad(x, target)).detach()
        bin_idx = self._g2bin(g)

        # Vectorized histogram on the same device as the input. (The
        # previous Python loop allocated the counts on CPU, which raised a
        # device-mismatch error at `beta[bin_idx]` for CUDA inputs, and
        # forced one host sync per bin via `.item()`.)
        bin_count = torch.bincount(
            bin_idx.flatten(), minlength=self._bins
        ).float()

        # Total example count; numel() generalizes beyond strictly 2-D
        # inputs while matching size(0) * size(1) for the 2-D case.
        N = x.numel()

        if self._last_bin_count is None:
            self._last_bin_count = bin_count
        else:
            # EMA smoothing of bin statistics across batches.
            bin_count = (self._alpha * self._last_bin_count
                         + (1 - self._alpha) * bin_count)
            self._last_bin_count = bin_count

        nonempty_bins = (bin_count > 0).sum().item()

        # Gradient density GD(g); the clamp guards the division for
        # empty bins.
        gd = bin_count * nonempty_bins
        gd = torch.clamp(gd, min=0.0001)
        beta = N / gd  # harmonizing weight per bin

        return self._custom_loss(x, target, beta[bin_idx])


class GHMC_Loss(GHM_Loss):
    """GHM classification loss: density-weighted BCE-with-logits."""

    def __init__(self, bins, alpha):
        super(GHMC_Loss, self).__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        return F.binary_cross_entropy_with_logits(x, target, weight=weight)

    def _custom_loss_grad(self, x, target):
        # d(BCE)/d(logit) = sigmoid(x) - target; magnitude lies in [0, 1].
        return torch.sigmoid(x).detach() - target


class GHMR_Loss(GHM_Loss):
    """GHM regression loss: density-weighted authentic smooth-L1.

    Args:
        bins:  see ``GHM_Loss``.
        alpha: see ``GHM_Loss``.
        mu:    smoothing constant of the ASL1 loss sqrt(d^2 + mu^2) - mu.
    """

    def __init__(self, bins, alpha, mu):
        super(GHMR_Loss, self).__init__(bins, alpha)
        self._mu = mu

    def _custom_loss(self, x, target, weight):
        d = x - target
        mu = self._mu
        # Authentic smooth L1: behaves like |d| for large d, quadratic
        # near zero; equals 0 when x == target.
        loss = torch.sqrt(d * d + mu * mu) - mu
        # numel() replaces size(0) * size(1): identical for 2-D inputs,
        # valid for any shape.
        N = x.numel()
        return (loss * weight).sum() / N

    def _custom_loss_grad(self, x, target):
        d = x - target
        mu = self._mu
        # d(ASL1)/dd = d / sqrt(d^2 + mu^2); magnitude lies in [0, 1).
        return d / torch.sqrt(d * d + mu * mu)