├── loss.py
└── metrics.py

/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from metrics import bbox_iou  # modified IoU implementations from the sibling metrics.py
# smooth_BCE, FocalLoss and de_parallel are the unchanged upstream YOLOv5
# helpers (utils/loss.py and utils/torch_utils.py); they are assumed to be
# importable from the usual utils package here.
from utils.loss import FocalLoss, smooth_BCE
from utils.torch_utils import de_parallel


class ComputeLoss:
    sort_obj_iou = False

    # Compute losses
    def __init__(self, model, autobalance=False):
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        m = de_parallel(model).model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        self.na = m.na  # number of anchors
        self.nc = m.nc  # number of classes
        self.nl = m.nl  # number of layers
        self.anchors = m.anchors
        self.device = device

    def __call__(self, p, targets, epoch=0):  # predictions, targets, epoch (forwarded to bbox_iou for UIoU)
        lcls = torch.zeros(1, device=self.device)  # class loss
        lbox = torch.zeros(1, device=self.device)  # box loss
        lobj = torch.zeros(1, device=self.device)  # object loss
        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1)  # faster, requires torch 1.8.0
                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions

                # Regression
                pxy = pxy.sigmoid() * 2 - 0.5
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box

                # ----------------------------------------------------------------------------------------------------
                # part of the change: IoU-variant box loss (see bbox_iou in metrics.py)
                iou = bbox_iou(pbox, tbox[i], epoch, CIoU=True, Focal=False, scale=False)
                if isinstance(iou, tuple):
                    if len(iou) == 2:
                        # increase the weight assigned to low/high-IoU samples
                        lbox += ((1 - iou[1].detach().squeeze()) * (1 - iou[0].squeeze())).mean()  # Focal
                        # lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()  # Focal-inv
                        iou = iou[0].squeeze()
                    else:
                        lbox += (iou[0] * iou[1]).mean()
                        iou = iou[2].squeeze()
                else:
                    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
                    iou = iou.squeeze()
                # -----------------------------------------------------------------------------------------------------

                # Objectness
                iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    j = iou.argsort()
                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
                if self.gr < 1:
                    iou = (1.0 - self.gr) + self.gr * iou
                tobj[b, a, gj, gi] = iou  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(pcls, self.cn, device=self.device)  # targets
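                    # Illustrative note (not in the original): smooth_BCE returns
                    # cp = 1 - 0.5 * eps and cn = 0.5 * eps, so `t` starts at the
                    # smoothed negative target everywhere and the true class is
                    # raised to cp on the next line.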
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(pcls, t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor(
            [
                [0, 0],
                [1, 0],
                [0, 1],
                [-1, 0],
                [0, -1],  # j,k,l,m
                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
            ],
            device=self.device).float() * g  # offsets

        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(3,n,7)
            if nt:
                # Matches
                r = t[..., 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
            a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch
--------------------------------------------------------------------------------
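Note: loss.py imports smooth_BCE, FocalLoss and de_parallel, which are not part of
this two-file excerpt; they are the unchanged upstream YOLOv5 helpers from
utils/loss.py and utils/torch_utils.py. For reference, a sketch of the standard
smooth_BCE definition the code above relies on:

def smooth_BCE(eps=0.1):
    # Label smoothing, https://arxiv.org/pdf/1902.04103.pdf eqn 3:
    # returns positive, negative BCE targets
    return 1.0 - 0.5 * eps, 0.5 * eps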
/metrics.py:
--------------------------------------------------------------------------------
import math

import torch


class WIoU_Scale:
    '''monotonous: {
        None: original WIoU v1
        True: monotonic focusing mechanism (v2)
        False: non-monotonic focusing mechanism (v3)
    }
    momentum: momentum of the running IoU-loss mean'''

    iou_mean = 1.
    monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)
    _is_train = True

    def __init__(self, iou):
        self.iou = iou
        self._update(self)

    @classmethod
    def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                           cls._momentum * self.iou.detach().mean().item()

    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                return beta / alpha
        return 1


def bbox_iou(box1, box2, epoch, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, UIoU=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # ----------------------------------------------------------------------------------------------------------------
    # UIoU (note: this branch assumes xywh=True, since x1/y1/x2/y2 are the box centres)
    if UIoU:
        # define the center points for scaling
        bb1_xc = x1
        bb1_yc = y1
        bb2_xc = x2
        bb2_yc = y2
        # attenuation mode of the hyperparameter "ratio"
        linear = True
        cosine = False
        fraction = False
        # assuming 300 total training epochs, "ratio" decays from 2 to 0.5
        if linear:
            ratio = -0.005 * epoch + 2
        elif cosine:
            ratio = 0.75 * math.cos(math.pi * epoch / 300) + 1.25
        elif fraction:
            ratio = 200 / (epoch + 100)
        else:
            ratio = 0.5
        ww1, hh1, ww2, hh2 = w1 * ratio, h1 * ratio, w2 * ratio, h2 * ratio
        bb1_x1, bb1_x2, bb1_y1, bb1_y2 = bb1_xc - (ww1 / 2), bb1_xc + (ww1 / 2), bb1_yc - (hh1 / 2), bb1_yc + (hh1 / 2)
        bb2_x1, bb2_x2, bb2_y1, bb2_y2 = bb2_xc - (ww2 / 2), bb2_xc + (ww2 / 2), bb2_yc - (hh2 / 2), bb2_yc + (hh2 / 2)
        # write the scaled values back for the subsequent calculations
        w1, h1, w2, h2 = ww1, hh1, ww2, hh2
        b1_x1, b1_x2, b1_y1, b1_y2 = bb1_x1, bb1_x2, bb1_y1, bb1_y2
        b2_x1, b2_x2, b2_y1, b2_y2 = bb2_x1, bb2_x2, bb2_y1, bb2_y2
        CIoU = True
    # ---------------------------------------------------------------------------------------------------------------

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union area
    union = w1 * h1 + w2 * h2 - inter + eps

    # WIoU requires scale=True
    if scale:
        self = WIoU_Scale(1 - (inter / union))
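        # Illustrative note (not in the original): with the default momentum
        # 1 - 0.5 ** (1 / 7000), iou_mean is an exponential running mean whose
        # old value's weight halves every 7000 updates, so WIoU v3's outlier
        # degree beta = L_IoU / iou_mean adapts slowly over training.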
    # IoU
    # iou = inter / union  # original IoU
    iou = torch.pow(inter / (union + eps), alpha)  # alpha-IoU
    if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps), gamma)  # Focal_EIoU
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIoU
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                angle_gamma = angle_cost - 2  # renamed from `gamma` so the focal exponent argument is not shadowed
                distance_cost = 2 - torch.exp(angle_gamma * rho_x) - torch.exp(angle_gamma * rho_y)
                omega_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omega_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omega_w), 4) + torch.pow(1 - torch.exp(-1 * omega_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter / (union + eps), gamma)  # Focal_SIoU
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIoU
            elif WIoU:
                if Focal:
                    raise RuntimeError("WIoU does not support Focal.")
                elif scale:
                    return WIoU_Scale._scaled_loss(self), (1 - iou) * torch.exp(rho2 / c2), iou  # WIoU https://arxiv.org/abs/2301.10051
                else:
                    return iou, torch.exp(rho2 / c2)  # WIoU v1
            if Focal:
                return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU
--------------------------------------------------------------------------------
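To sanity-check the return conventions above, a minimal sketch (assuming both files
are on the import path; the box values are made up): plain CIoU returns a tensor,
the Focal variants return a 2-tuple, and WIoU with scale=True returns a 3-tuple,
which is exactly what the tuple handling in ComputeLoss.__call__ keys on.

import torch

from metrics import bbox_iou

b1 = torch.tensor([[50.0, 50.0, 20.0, 20.0]])  # xywh: centre x, centre y, w, h
b2 = torch.tensor([[52.0, 49.0, 22.0, 18.0]])

ciou = bbox_iou(b1, b2, epoch=0, CIoU=True)                # tensor, shape (1, 1)
f_ciou = bbox_iou(b1, b2, epoch=0, CIoU=True, Focal=True)  # (loss term, focal weight)
wiou = bbox_iou(b1, b2, epoch=0, WIoU=True, scale=True)    # (focal scale, wise weight, iou)
print(ciou.shape, len(f_ciou), len(wiou))                  # torch.Size([1, 1]) 2 3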