"""Sample PyTorch implementation of the Probabilistic IoU (ProbIoU) loss.

Each box (x, y, w, h, angle) is modelled as a 2-D Gaussian centred at (x, y);
the loss is derived from the Bhattacharyya distance between the predicted
and the target Gaussian.

Reference:
    Murrugarra-Llerena, J., Kirsten, L. N., Zeni, L. F., Jung, C. R.
    "Probabilistic Intersection-Over-Union for Training and Evaluation of
    Oriented Object Detectors", IEEE Transactions on Image Processing,
    vol. 33, pp. 671-681, 2024. DOI: 10.1109/TIP.2023.3348697

Questions: jeffri.mllerena@inf.ufrgs.br
"""

import torch

_VALID_MODES = ("l1", "l2")


def gbb_form(boxes):
    """Convert boxes to Gaussian-bounding-box form.

    Columns 0-1 (centre) and columns 4 onward (angle) are passed through
    unchanged; columns 2-3 (w, h) are replaced by ``w**2 / 12`` and
    ``h**2 / 12`` (the variance of a uniform distribution of that width).
    """
    centre = boxes[:, :2]
    variances = torch.pow(boxes[:, 2:4], 2) / 12
    return torch.cat((centre, variances, boxes[:, 4:]), 1)


def rotated_form(a_, b_, angles):
    """Rotate the axis-aligned variances ``(a_, b_)`` by ``angles`` (radians).

    Returns the entries ``(a, b, c)`` of the covariance matrix
    ``[[a, c], [c, b]]``.
    """
    cos_t = torch.cos(angles)
    sin_t = torch.sin(angles)
    a = a_ * cos_t * cos_t + b_ * sin_t * sin_t
    b = a_ * sin_t * sin_t + b_ * cos_t * cos_t
    c = (a_ - b_) * cos_t * sin_t
    return a, b, c


def _gaussian_params(boxes):
    """Return centre (x, y) and covariance entries (a, b, c) for each box."""
    gbb = gbb_form(boxes)
    a, b, c = rotated_form(gbb[:, 2], gbb[:, 3], gbb[:, 4])
    return gbb[:, 0], gbb[:, 1], a, b, c


def probiou_loss(pred, target, eps=1e-3, mode='l1'):
    """Compute the ProbIoU loss between matching rows of ``pred`` and ``target``.

    Args:
        pred: matrix [N, 5] of predicted boxes (x, y, w, h, angle in radians);
            for horizontal boxes the angle is 0.
        target: matrix [N, 5] of target boxes, same layout as ``pred``.
        eps: small constant added to denominators / logarithm arguments to
            avoid infinite values; also the lower clamp of the distance.
        mode: ``'l1'`` returns ``sqrt(1 - exp(-B_d) + eps)``; ``'l2'`` returns
            ``-log(1 - l1**2 + eps)``. ``B_d`` is clamped to ``[eps, 100]``,
            so ``'l1'`` is at most ``sqrt(1 + eps)``.

    Returns:
        Tensor of shape [N] with one loss value per box pair.

    Raises:
        ValueError: if ``mode`` is neither ``'l1'`` nor ``'l2'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError at the
    # end of the function.
    if mode not in _VALID_MODES:
        raise ValueError(
            "mode must be one of %r, got %r" % (_VALID_MODES, mode)
        )

    x1, y1, a1, b1, c1 = _gaussian_params(pred)
    x2, y2, a2, b2, c2 = _gaussian_params(target)

    sum_a = a1 + a2
    sum_b = b1 + b2
    sum_c = c1 + c2
    # Determinant of the summed covariance; shared by all three terms.
    det_sum = sum_a * sum_b - torch.pow(sum_c, 2)

    t1 = ((sum_a * torch.pow(y1 - y2, 2) + sum_b * torch.pow(x1 - x2, 2))
          / (det_sum + eps)) * 0.25
    t2 = ((sum_c * (x2 - x1) * (y1 - y2)) / (det_sum + eps)) * 0.5

    # Floating-point cancellation can push these quantities slightly below
    # zero for very elongated rotated boxes; flooring at 0 keeps sqrt/log
    # from returning NaN and changes nothing when they are non-negative.
    det_prod = torch.clamp(
        (a1 * b1 - torch.pow(c1, 2)) * (a2 * b2 - torch.pow(c2, 2)), min=0.0
    )
    ratio = torch.clamp(det_sum / (4 * torch.sqrt(det_prod) + eps), min=0.0)
    t3 = torch.log(ratio + eps) * 0.5

    B_d = torch.clamp(t1 + t2 + t3, eps, 100.0)
    l1 = torch.sqrt(1.0 - torch.exp(-B_d) + eps)
    if mode == 'l1':
        return l1
    return -torch.log(1.0 - torch.pow(l1, 2.0) + eps)


def main():
    """Run the loss on random boxes and print the mean value."""
    pred = torch.rand(8, 5)
    target = torch.rand(8, 5)
    loss = probiou_loss(pred, target)
    print(torch.mean(loss).item())


if __name__ == '__main__':
    main()
"""Sample TensorFlow implementation of the Probabilistic IoU (ProbIoU) loss."""

import tensorflow as tf
import numpy as np


def _box_columns(boxes):
    """Split a [N, 5] box matrix into five [N, 1] column tensors."""
    return [tf.reshape(col, [-1, 1]) for col in tf.unstack(boxes, axis=1)]


def _gaussian_covariance(w, h, theta):
    """Return the covariance entries (a, b, c) of each box's Gaussian.

    The axis-aligned variances ``w**2 / 12`` and ``h**2 / 12`` are rotated by
    ``theta`` (radians), giving the matrix ``[[a, c], [c, b]]``.
    """
    var_w = w ** 2 / 12
    var_h = h ** 2 / 12
    cos_sq = tf.math.pow(tf.math.cos(theta), 2.)
    sin_sq = tf.math.pow(tf.math.sin(theta), 2.)
    a = var_w * cos_sq + var_h * sin_sq
    b = var_w * sin_sq + var_h * cos_sq
    c = 0.5 * (var_w - var_h) * tf.math.sin(2. * theta)
    return a, b, c


def probiou_loss(boxes_pred, target_boxes_, EPS=1e-3, mode='l2'):
    """Compute the ProbIoU loss between matching rows of the two box matrices.

    Args:
        boxes_pred: matrix [N, 5] of predicted boxes (x, y, w, h, angle in
            radians); for horizontal boxes the angle is 0.
        target_boxes_: matrix [N, 5] of target boxes, same layout.
        EPS: small constant used as a lower bound / additive term to avoid
            infinite values.
        mode: ``'l2'`` returns ``-log(1 - l1**2 + EPS)``; any other value
            returns ``l1 = sqrt(1 - exp(-Bd) + EPS)``.

    Returns:
        Tensor of shape [N, 1] with one loss value per box pair.
    """
    x1, y1, w1, h1, theta1 = _box_columns(boxes_pred)
    x2, y2, w2, h2, theta2 = _box_columns(target_boxes_)

    a1, b1, c1 = _gaussian_covariance(w1, h1, theta1)
    a2, b2, c2 = _gaussian_covariance(w2, h2, theta2)

    sum_a = a1 + a2
    sum_b = b1 + b2
    sum_c = c1 + c2
    # Determinant of the summed covariance; shared by both terms.
    det_sum = sum_a * sum_b - sum_c ** 2.

    B1 = 1 / 4. * (sum_a * (y1 - y2) ** 2. + sum_b * (x1 - x2) ** 2.) \
        + 1 / 2. * (sum_c * (x2 - x1) * (y1 - y2))
    B1 = B1 / (det_sum + EPS)

    # A plain lower bound at EPS. The original clip used "batch max + EPS" as
    # its upper bound, which never clipped any element, so the reduction over
    # the batch was unnecessary.
    det_prod = tf.maximum((a1 * b1 - c1 ** 2) * (a2 * b2 - c2 ** 2), EPS)
    B2 = det_sum / (4. * tf.math.sqrt(det_prod) + EPS)
    B2 = 1 / 2. * tf.math.log(tf.maximum(B2, EPS))

    Bd = tf.clip_by_value(B1 + B2, EPS, 100.)
    l1 = tf.math.sqrt(1 - tf.math.exp(-Bd) + EPS)

    if mode == 'l2':
        return - tf.math.log(1. - tf.math.pow(l1, 2.) + EPS)
    return l1


def main():
    """Run the loss on seeded random boxes and print the mean value."""
    pred = tf.random.Generator.from_seed(1).normal(shape=[8, 5])
    target = tf.random.Generator.from_seed(2).normal(shape=[8, 5])

    loss = probiou_loss(pred, target, mode='l1')
    print(tf.reduce_mean(loss))


if __name__ == '__main__':
    main()