├── Logger.py ├── Makefile ├── README.md ├── configs ├── retinanet_r50_fpn_hrsc.yml └── retinanet_r50_fpn_ssdd.yml ├── datasets ├── HRSC_dataset.py ├── SSDD_dataset.py ├── __pycache__ │ ├── HRSC_dataset.cpython-37.pyc │ ├── SSDD_dataset.cpython-37.pyc │ └── collater.cpython-37.pyc ├── collater.py ├── convert.py ├── prepare_dataset.py └── test_collater.py ├── detect.py ├── eval.py ├── models ├── __pycache__ │ ├── anchors.cpython-37.pyc │ ├── fpn.cpython-37.pyc │ ├── heads.cpython-37.pyc │ ├── losses.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── resnet.cpython-37.pyc ├── anchors.py ├── fpn.py ├── heads.py ├── losses.py ├── model.py └── resnet.py ├── requirements.txt ├── resnet_pretrained_pth ├── .gitignore └── README.md ├── resource ├── HRSC_Result.png └── RSSDD_Result.png ├── setup.py ├── show.py ├── show_result ├── HRSC │ ├── demo1.jpg │ ├── demo2.jpg │ └── demo3.jpg └── RSSDD │ ├── demo1.jpg │ ├── demo2.jpg │ └── demo3.jpg ├── train.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── augment.cpython-37.pyc │ ├── bbox_transforms.cpython-37.pyc │ ├── box_coder.cpython-37.pyc │ ├── map.cpython-37.pyc │ └── utils.cpython-37.pyc ├── augment.py ├── bbox_transforms.py ├── box_coder.py ├── map.py ├── rotation_nms │ ├── .gitignore │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ └── cpu_nms.pyx ├── rotation_overlaps │ ├── .gitignore │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ └── rbox_overlaps.pyx └── utils.py └── warmup.py /Logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import colorlog 3 | 4 | """ Logger Rank: (low -> high) 5 | # 1. DEBUG 6 | # 2. INFO 7 | # 3. WARNING 8 | # 4. ERROR 9 | # 5. CRITICAL 10 | """ 11 | 12 | 13 | class Logger(object): 14 | def __init__(self, log_path, logging_name): 15 | self.log_path = log_path 16 | self.logging_name = logging_name 17 | self.dash_line = '-' * 60 + '\n' 18 | self.level_color = {'DEBUG': 'cyan', 19 | 'INFO': 'bold_white', 20 | 'WARNING': 'yellow', 21 | 'ERROR': 'red', 22 | 'CRITICAL': 'red'} 23 | 24 | def logger_config(self): 25 | logger = logging.getLogger(self.logging_name) 26 | logger.setLevel(level=logging.DEBUG) 27 | handler = logging.FileHandler(self.log_path, encoding='UTF-8') 28 | handler.setLevel(logging.DEBUG) 29 | file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', 30 | datefmt="%Y-%m-%d %H:%M:%S") 31 | handler.setFormatter(file_formatter) 32 | 33 | console_formatter = colorlog.ColoredFormatter( 34 | '%(log_color)s[%(asctime)s] - [%(name)s] - [%(levelname)s]:\n%(message)s', datefmt="%Y-%m-%d %H:%M:%S", 35 | log_colors=self.level_color) 36 | 37 | console = logging.StreamHandler() 38 | console.setFormatter(console_formatter) 39 | console.setLevel(logging.INFO) 40 | 41 | logger.addHandler(handler) 42 | logger.addHandler(console) 43 | return logger 44 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | rm -rf build 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## :rocket:RetinaNet Oriented Detector Based PyTorch 2 | This is an oriented detector **Rotation-RetinaNet** implementation on Optical and SAR **ship dataset**. 
3 | - SAR ship dataset (SSDD): [SSDD Dataset link](https://github.com/TianwenZhang0825/Official-SSDD) 4 | - Optical ship dataset (HRSC): [HRSC Dataset link](https://www.kaggle.com/guofeng/hrsc2016) 5 | - RetinaNet Detector original paper link is [here](https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf). 6 | ## :star2:Performance of the implemented Rotation-RetinaNet Detector 7 | 8 | ### Detection Performance on HRSC Dataset. 9 | 10 | 11 | ### Detection Performance on SSDD Dataset. 12 | 13 | 14 | ## :dart:Experiment 15 | 16 | | Dataset | Backbone | Input Size | bs | Trick | mAP.5 | Config | 17 | |:-------:|:--------:|:----------:|:--:|:-----:|:-----:|:------:| 18 | | SSDD | ResNet-50| 512 x 512 | 16 | N | 78.96 |[config file](/configs/retinanet_r50_fpn_ssdd.yml)| 19 | | SSDD | ResNet-50| 512 x 512 | 16 |Augment| 85.6 |[config file](/configs/retinanet_r50_fpn_ssdd.yml)| 20 | | HRSC | ResNet-50| 512 x 512 | 16 | N | 70.71 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)| 21 | | HRSC | ResNet-50| 512 x 512 | 4 | N | 74.22 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)| 22 | | HRSC | ResNet-50| 512 x 512 | 16 |Augment| 80.20 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)| 23 | 24 | ## :boom:Get Started 25 | ### Installation 26 | #### A. Install requirements: 27 | ``` 28 | conda create -n rotate python=3.7 29 | conda activate rotate 30 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch 31 | pip install -r requirements.txt 32 | 33 | Note: the opencv version must > 4.5.1 34 | ``` 35 | #### B. Install rotation\_nms and rotation\_overlaps module: 36 | ``` 37 | Only need one Step: 38 | make 39 | ``` 40 | ## Demo 41 | ### A. Set project's data path 42 | you should set project's data path in `yml` file first. 43 | ``` 44 | # .yml file 45 | # Note: all the path should be absolute path. 46 | data_path = r'/$ROOT_PATH/SSDD_data/' # absolute data root path 47 | output_path = r'/$ROOT_PATH/Output/' # absolute model output path 48 | 49 | # For example 50 | $ROOT_PATH 51 | -HRSC/ 52 | -train/ # train set 53 | -Annotations/ 54 | -*.xml 55 | -images/ 56 | -*.jpg 57 | -test/ # test set 58 | -Annotations/ 59 | -*.xml 60 | -images/ 61 | -*.jpg 62 | -ground-truth/ 63 | -*.txt # gt label in txt format (for voc evaluation method) 64 | 65 | -SSDD/ 66 | -train/ # train set 67 | -Annotations/ 68 | -*.xml 69 | -images/ 70 | -*.jpg 71 | -test/ # test set 72 | -Annotations/ 73 | -*.xml 74 | -images/ 75 | -*.jpg 76 | -ground-truth/ 77 | -*.txt # gt label in txt format (for voc evaluation method) 78 | 79 | 80 | -Output/ 81 | -checkpoints/ 82 | - the path of saving chkpt files 83 | -tensorboard/ 84 | - the path of saving tensorboard event files 85 | -evaluate/ 86 | - the path of saving model detection results for evaluate (voc method method) 87 | -log.log (save the loss and eval result) 88 | -yml file (config file) 89 | ``` 90 | ### B. Run the show.py 91 | ``` 92 | # for SSDD dataset 93 | python show.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --chkpt {chkpt.file} --result_path show_result/RSSDD --pic_name demo1.jpg 94 | 95 | # for HRSC dataset 96 | python show.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --chkpt {chkpt.file} --result_path show_result/HRSC --pic_name demo1.jpg 97 | ``` 98 | ## Train 99 | ### A. Prepare dataset 100 | you should structure your dataset files as shown above. 101 | ### B. 
Manually set the project's hyper-parameters 102 | You should manually set the project's hyper-parameters in the `config` file. 103 | ``` 104 | 1. data file structure (Must Be Set!) 105 | as shown above. 106 | 107 | 2. Other settings (Optional) 108 | if you want to follow my experiment, don't change anything. 109 | ``` 110 | ### C. Train Rotation-RetinaNet on SSDD or HRSC dataset with ResNet-50 from scratch 111 | #### C.1 Download the pre-trained ResNet-50 pth file 112 | You should download the pre-trained ResNet-50 pth file first and put it in the `resnet_pretrained_pth/` folder. 113 | #### C.2 Train Rotation-RetinaNet Detector on SSDD or HRSC Dataset with pre-trained pth file 114 | ``` 115 | # train model on SSDD dataset from scratch 116 | python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume None 117 | 118 | # train model on HRSC dataset from scratch 119 | python train.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --resume None 120 | 121 | ``` 122 | ### D. Resume training Rotation-RetinaNet detector on SSDD or HRSC dataset 123 | ``` 124 | # train model on SSDD dataset from a specific epoch 125 | python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume {epoch}_{step}.pth 126 | 127 | # train model on HRSC dataset from a specific epoch 128 | python train.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --resume {epoch}_{step}.pth 129 | 130 | ``` 131 | ## Evaluation 132 | ### A. Evaluate model performance on the SSDD or HRSC val set. 133 | ``` 134 | python eval.py --Dataset SSDD --config_file ./configs/retinanet_r50_fpn_ssdd.yml --evaluate True --chkpt {epoch}_{step}.pth 135 | python eval.py --Dataset HRSC --config_file ./configs/retinanet_r50_fpn_hrsc.yml --evaluate True --chkpt {epoch}_{step}.pth 136 | ``` 137 | ## :bulb:References 138 | Thanks to these great works. 139 | [https://github.com/open-mmlab/mmrotate](https://github.com/open-mmlab/mmrotate) 140 | [https://github.com/ming71/Rotated-RetinaNet](https://github.com/ming71/Rotated-RetinaNet) 141 | 142 | ## :fast\_forward:Zhihu Link 143 | [zhihu article](https://zhuanlan.zhihu.com/p/490422549?) 144 | -------------------------------------------------------------------------------- /configs/retinanet_r50_fpn_hrsc.yml: -------------------------------------------------------------------------------- 1 | backbone: {'type': 'resnet50', 2 | 'pretrained': True} 3 | 4 | neck: {'type': 'fpn', 5 | 'init_method': 'xavier_init', 6 | 'extra_conv_init_method': 'xavier_init'} 7 | 8 | head: {'type': 'retinanet', 9 | 'num_stacked': 4, 10 | 'cls_branch_init_method': 'normal_init', 11 | 'reg_branch_init_method': 'normal_init'} 12 | 13 | loss: {'cls': {'alpha': 0.25, 'gamma': 2.0}, 14 | 'reg': {'type': 'smooth'}} 15 | 16 | 17 | assigner: {'pos_iou_thr': 0.3, 18 | 'neg_iou_thr': 0.2, 19 | 'min_pos_iou': 0.0, 20 | 'low_quality_match': True} 21 | 22 | # warmup settings 23 | warm_up: True 24 | warmup_epoch: 2 25 | warmup_lr: 0.00001 26 | 27 | # data settings 28 | dataset: HRSC 29 | classes: ['ship'] 30 | image_size: 512 31 | keep_ratio: False 32 | batch_size: 16 33 | augment: False 34 | 35 | data_path: '{Data Root Path.}' 36 | output_path: '{Project Output Path.}' 37 | # weight the delta values.
38 | 39 | optimizer: adam 40 | lr: 0.0001 41 | epoch: 100 42 | evaluation_train_start: 101 43 | evaluation_val_start: 44 44 | save_interval: 4 45 | val_interval: 4 46 | eval_method: voc 47 | freeze_bn: True 48 | device: [3] 49 | 50 | # anchor settings 51 | base_size: 4 52 | ratios: [0.2, 0.5, 1.0, 2.0, 5.0] 53 | #ratios: [0.5, 1.0, 2.0] 54 | scales_per_octave: 3 55 | angle: 0 # opencv version > 4.5.1 56 | 57 | rotation_nms_thr: 0.5 58 | score_thr: 0.05 59 | 60 | tensorboard: 'tensorboard' 61 | checkpoint: 'checkpoints' 62 | log: 'log.log' 63 | -------------------------------------------------------------------------------- /configs/retinanet_r50_fpn_ssdd.yml: -------------------------------------------------------------------------------- 1 | backbone: {'type': 'resnet50', 2 | 'pretrained': True} 3 | 4 | neck: {'type': 'fpn', 5 | 'init_method': 'xavier_init', 6 | 'extra_conv_init_method': 'xavier_init'} 7 | 8 | head: {'type': 'retinanet', 9 | 'num_stacked': 4, 10 | 'cls_branch_init_method': 'normal_init', 11 | 'reg_branch_init_method': 'normal_init'} 12 | 13 | loss: {'cls': {'alpha': 0.25, 'gamma': 2.0}, 14 | 'reg': {'type': 'smooth'}} 15 | 16 | 17 | assigner: {'pos_iou_thr': 0.3, 18 | 'neg_iou_thr': 0.2, 19 | 'min_pos_iou': 0.0, 20 | 'low_quality_match': True} 21 | 22 | # warmup settings 23 | warm_up: True 24 | warmup_epoch: 2 25 | warmup_lr: 0.00001 26 | 27 | # data settings 28 | dataset: SSDD 29 | classes: ['ship'] 30 | image_size: 512 31 | keep_ratio: False 32 | batch_size: 16 33 | augment: False 34 | 35 | data_path: '{Data Root Path.}' 36 | output_path: '{Project Root Path.}' 37 | 38 | optimizer: adam 39 | lr: 0.0001 40 | epoch: 100 41 | evaluation_train_start: 101 42 | evaluation_val_start: 101 43 | save_interval: 4 44 | val_interval: 4 45 | eval_method: voc 46 | freeze_bn: True 47 | device: [2] 48 | 49 | # anchor settings 50 | base_size: 4 51 | ratios: [0.2, 0.5, 1.0, 2.0, 5.0] 52 | #ratios: [0.5, 1.0, 2.0] 53 | scales_per_octave: 3 54 | angle: 0 # opencv version > 4.5.1 55 | 56 | rotation_nms_thr: 0.5 57 | score_thr: 0.05 58 | 59 | tensorboard: 'tensorboard' 60 | checkpoint: 'checkpoints' 61 | log: 'log.log' 62 | -------------------------------------------------------------------------------- /datasets/HRSC_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import matplotlib.pyplot as plt 4 | from utils.bbox_transforms import * 5 | import cv2 6 | from utils.augment import * 7 | 8 | 9 | class HRSCDataset(data.Dataset): 10 | def __init__(self, root_path, set_name, augment=False, classes=None): 11 | self.root_path = root_path 12 | self.set_name = set_name 13 | self.augment = augment 14 | self.image_lists = self._load_image_names() 15 | self.classes = classes 16 | self.num_classes = len(self.classes) 17 | self.class_to_ind = dict(zip(self.classes, range(self.num_classes))) 18 | if self.augment is True: 19 | print(f'[Info]: Using the data augmentation.') 20 | else: 21 | print(f'[Info]: Not using the data augmentation.') 22 | 23 | def __len__(self): 24 | return len(self.image_lists) 25 | 26 | def __getitem__(self, index): 27 | imagename = self.image_lists[index] 28 | img_path = os.path.join(self.root_path, self.set_name, "images", imagename) 29 | image = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | roidb = self._load_annotation(imagename) 31 | gt_inds = np.where(roidb['gt_classes'] != 0)[0] 32 | num_gt = len(roidb['boxes']) 33 | gt_boxes = 
np.zeros((len(gt_inds), 9), dtype=np.float32) # [x1,y1,x2,y2,x3,y3,x4,y4,class_index] 34 | if num_gt: 35 | # get the bboxes and classes info from the self._load_annotation() result. 36 | bboxes = roidb['boxes'][gt_inds, :] 37 | classes = roidb['gt_classes'][gt_inds] - 1 38 | 39 | # perform the data augmentation 40 | if self.augment is True: 41 | transforms = Augment([ 42 | HSV(0.5, 0.5, p=0.5), 43 | HorizontalFlip(p=0.5), 44 | VerticalFlip(p=0.5) 45 | ]) 46 | image, bboxes = transforms(image, bboxes) 47 | gt_boxes[:, :-1] = bboxes 48 | 49 | for i, bbox in enumerate(bboxes): 50 | gt_boxes[i, 8] = classes[i] 51 | 52 | return {'image': image, 'boxes': gt_boxes, 'imagename': imagename} 53 | 54 | def _load_image_names(self): 55 | return os.listdir(os.path.join(self.root_path, self.set_name, 'images')) 56 | 57 | def _load_annotation(self, imagename): 58 | filename = os.path.join(self.root_path, self.set_name, "Annotations", imagename.replace('jpg', 'xml')) 59 | boxes, gt_classes = [], [] 60 | with open(filename, 'r', encoding='utf-8-sig') as f: 61 | content = f.read() 62 | objects = content.split('') 63 | info = objects.pop(0) 64 | for obj in objects: 65 | cls_id = obj[obj.find('') + 10: obj.find('')] 66 | cx = float(eval(obj[obj.find('') + 9: obj.find('')])) 67 | cy = float(eval(obj[obj.find('') + 9: obj.find('')])) 68 | w = float(eval(obj[obj.find('') + 8: obj.find('')])) 69 | h = float(eval(obj[obj.find('') + 8: obj.find('')])) 70 | angle = float(obj[obj.find('') + 10: obj.find('')]) # radian 71 | 72 | # add extra score parameter to use obb2poly_up 73 | bbox = np.array([[cx, cy, w, h, angle, 0]], dtype=np.float32) 74 | polygon = obb2poly_np(bbox, 'le90')[0, :-1].astype(np.float32) 75 | boxes.append(polygon) 76 | label_index = 1 77 | gt_classes.append(label_index) 78 | return {'boxes': np.array(boxes), 'gt_classes': np.array(gt_classes)} 79 | 80 | 81 | if __name__ == '__main__': 82 | hrsc = HRSCDataset(root_path='/data/fzh/HRSC/', 83 | set_name='train', 84 | augment=True, 85 | classes=['ship', ]) 86 | for idx in range(len(hrsc)): 87 | a = hrsc[idx] 88 | bboxes = a['boxes'] # polygon format [x1, y1, x2, y2, x3, y3, x4, y4] 89 | img = a['image'] 90 | image_name = a['imagename'] 91 | for gt_bbox in bboxes: 92 | ps = gt_bbox[:-1].reshape(1, 4, 2).astype(np.int32) 93 | cv2.drawContours(img, [ps], -1, [0, 255, 0], thickness=2) 94 | plt.imshow(img) 95 | plt.title(image_name) 96 | plt.show() 97 | -------------------------------------------------------------------------------- /datasets/SSDD_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import matplotlib.pyplot as plt 4 | from utils.bbox_transforms import * 5 | import cv2 6 | import xml.etree.ElementTree as ET 7 | from utils.augment import * 8 | 9 | 10 | class SSDDataset(data.Dataset): 11 | def __init__(self, root_path, set_name, augment=False, classes=None): 12 | self.root_path = root_path 13 | self.set_name = set_name 14 | self.augment = augment 15 | self.image_lists = self._load_image_names() 16 | self.classes = classes 17 | self.num_classes = len(self.classes) 18 | self.class_to_ind = dict(zip(self.classes, range(self.num_classes))) 19 | if self.augment is True: 20 | print(f'[Info]: Using the data augmentation.') 21 | else: 22 | print(f'[Info]: Not using the data augmentation.') 23 | 24 | def __len__(self): 25 | return len(self.image_lists) 26 | 27 | def __getitem__(self, index): 28 | imagename = self.image_lists[index] 29 | img_path = 
os.path.join(self.root_path, self.set_name, "images", imagename) 30 | image = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 31 | roidb = self._load_annotation(imagename) 32 | gt_inds = np.where(roidb['gt_classes'] != 0)[0] 33 | num_gt = len(roidb['boxes']) 34 | gt_boxes = np.zeros((len(gt_inds), 9), dtype=np.float32) # [x1,y1,x2,y2,x3,y3,x4,y4,class_index] 35 | if num_gt: 36 | # get the bboxes and classes info from the self._load_annotation() result. 37 | bboxes = roidb['boxes'][gt_inds, :] 38 | classes = roidb['gt_classes'][gt_inds] - 1 39 | 40 | # perform the data augmentation 41 | if self.augment is True: 42 | transforms = Augment([ 43 | # HSV(0.5, 0.5, p=0.5), 44 | HorizontalFlip(p=0.5), 45 | VerticalFlip(p=0.5) 46 | ]) 47 | image, bboxes = transforms(image, bboxes) 48 | 49 | gt_boxes[:, :-1] = bboxes 50 | 51 | for i, bbox in enumerate(bboxes): 52 | gt_boxes[i, 8] = classes[i] 53 | 54 | return {'image': image, 'boxes': gt_boxes, 'imagename': imagename} 55 | 56 | def _load_image_names(self): 57 | return os.listdir(os.path.join(self.root_path, self.set_name, 'images')) 58 | 59 | def _load_annotation(self, imagename): 60 | filename = os.path.join(self.root_path, self.set_name, "Annotations", imagename.replace('jpg', 'xml')) 61 | boxes, gt_classes = [], [] 62 | infile = open(os.path.join(filename)) 63 | tree = ET.parse(infile) 64 | root = tree.getroot() 65 | for obj in root.iter('object'): 66 | rbox = obj.find('rotated_bndbox') 67 | x1 = float(rbox.find('x1').text) 68 | y1 = float(rbox.find('y1').text) 69 | x2 = float(rbox.find('x2').text) 70 | y2 = float(rbox.find('y2').text) 71 | x3 = float(rbox.find('x3').text) 72 | y3 = float(rbox.find('y3').text) 73 | x4 = float(rbox.find('x4').text) 74 | y4 = float(rbox.find('y4').text) 75 | polygon = np.array([x1, y1, x2, y2, x3, y3, x4, y4], dtype=np.int32) 76 | boxes.append(polygon) 77 | label_index = 1 78 | gt_classes.append(label_index) 79 | return {'boxes': np.array(boxes), 'gt_classes': np.array(gt_classes)} 80 | 81 | 82 | if __name__ == '__main__': 83 | rssdd = SSDDataset(root_path='/data/fzh/RSSDD/', 84 | set_name='train', 85 | augment=True, 86 | classes=['ship', ]) 87 | for idx in range(len(rssdd)): 88 | idx = 0 89 | a = rssdd[idx] 90 | bboxes = a['boxes'] 91 | img = a['image'] 92 | for gt_bbox in bboxes: 93 | 94 | ps = gt_bbox[:-1].reshape(-1, 4, 2).astype(np.int32) 95 | cv2.drawContours(img, [ps], -1, [0, 255, 0], thickness=2) 96 | 97 | plt.imshow(img) 98 | plt.show() -------------------------------------------------------------------------------- /datasets/__pycache__/HRSC_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/HRSC_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/SSDD_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/SSDD_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/collater.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/collater.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/collater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.utils import Rescale, Normalize, Reshape 4 | from torchvision.transforms import Compose 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import math 8 | from utils.bbox_transforms import poly2obb_np 9 | 10 | 11 | class Collater(object): 12 | def __init__(self, scales, keep_ratio=False, multiple=32): 13 | self.scales = scales 14 | self.keep_ratio = keep_ratio 15 | self.multiple = multiple 16 | 17 | def __call__(self, batch): 18 | scales = int(np.floor(float(self.scales) / self.multiple) * self.multiple) 19 | rescale = Rescale(target_size=scales, keep_ratio=self.keep_ratio) 20 | transform = Compose([Normalize(), Reshape(unsqueeze=False)]) 21 | 22 | images = [sample['image'] for sample in batch] 23 | bboxes = [sample['boxes'] for sample in batch] 24 | image_names = [sample['imagename'] for sample in batch] 25 | 26 | max_height, max_width = -1, -1 27 | 28 | for index in range(len(batch)): 29 | im, _ = rescale(images[index]) 30 | height, width = im.shape[0], im.shape[1] 31 | max_height = height if height > max_height else max_height 32 | max_width = width if width > max_width else max_width 33 | 34 | padded_ims = torch.zeros(len(batch), 3, max_height, max_width) 35 | 36 | # ready to save the openCV format info [xc, yc, w, h, theta, class_index] 37 | num_params = 6 38 | max_num_boxes = max(bbox.shape[0] for bbox in bboxes) 39 | padded_boxes = torch.ones(len(batch), max_num_boxes, num_params) * -1 40 | 41 | for i in range(len(batch)): 42 | im, bbox = images[i], bboxes[i] 43 | 44 | # rescale the image 45 | im, im_scale = rescale(im) 46 | height, width = im.shape[0], im.shape[1] 47 | padded_ims[i, :, :height, :width] = transform(im) # transform is similar to the pipeline in mmdet 48 | 49 | # rescale the bounding box 50 | oc_bboxes = [] 51 | labels = [] 52 | for single in bbox: 53 | 54 | # rescale the bounding box 55 | single[0::2] *= im_scale[0] 56 | single[1::2] *= im_scale[1] 57 | 58 | # polygons to the opencv format, opencv version > 4.5.1 59 | oc_bbox = poly2obb_np(single[:-1], 'oc') # oc_bbox: [xc, yc, h, w, angle(radian)] 60 | assert 0 < oc_bbox[4] <= np.pi / 2 61 | oc_bboxes.append(np.array(oc_bbox, dtype=np.float32)) 62 | labels.append(single[-1]) 63 | 64 | if bbox.shape[0] != 0: 65 | padded_boxes[i, :bbox.shape[0], :-1] = torch.from_numpy(np.array(oc_bboxes)) 66 | padded_boxes[i, :bbox.shape[0], -1] = torch.from_numpy(np.array(labels)) 67 | 68 | # # visualize rescale result 69 | # vis_im = images[i] 70 | # vis_im, _ = rescale(vis_im) 71 | # for gt_bbox in oc_bboxes: 72 | # xc, yc, h, w, ag = gt_bbox[:5] 73 | # print(f'GT Annotation: xc:{xc} yc:{yc} h:{h} w:{w} ag:{ag}') 74 | # wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag) 75 | # hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag) 76 | # p1 = (xc - wx - hx, yc - wy - hy) 77 | # p2 = (xc - wx + hx, yc - wy + hy) 78 | # p3 = (xc + wx + hx, yc + wy + hy) 79 | # p4 = (xc + wx - hx, yc + wy - hy) 80 | # ps = np.int0(np.array([p1, p2, p3, p4])) 81 | # cv2.drawContours(vis_im, [ps], -1, [0, 255, 0], thickness=2) 82 | # plt.imshow(vis_im) 83 | # plt.title(image_names[i]) 84 | # plt.show() 85 | 86 | return {'image': padded_ims, 'bboxes': padded_boxes, 
'image_name': image_names} 87 | -------------------------------------------------------------------------------- /datasets/convert.py: -------------------------------------------------------------------------------- 1 | """This script is used to convert xml format to txt format for HRSC Dataset for evaluation.""" 2 | 3 | import os 4 | import numpy as np 5 | import math 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | class Convert(object): 11 | def __init__(self, xml_path, txt_path, image_path): 12 | self.xml_path = xml_path 13 | self.txt_path = txt_path 14 | self.image_path = image_path 15 | self.xml_lists = os.listdir(xml_path) 16 | self._makedir() 17 | 18 | def _makedir(self): 19 | if not os.path.exists(self.txt_path): 20 | os.makedirs(self.txt_path) 21 | 22 | def _readXml(self, single_xml): 23 | with open(os.path.join(self.xml_path, single_xml), 'r', encoding='utf-8-sig') as f: 24 | content = f.read() 25 | objects = content.split('') 26 | info = objects.pop(0) 27 | 28 | results = [] 29 | for obj in objects: 30 | cls_name = 'ship' 31 | cx = round(eval(obj[obj.find('') + 9: obj.find('')])) 32 | cy = round(eval(obj[obj.find('') + 9: obj.find('')])) 33 | w = round(eval(obj[obj.find('') + 8: obj.find('')])) 34 | h = round(eval(obj[obj.find('') + 8: obj.find('')])) 35 | angle = eval(obj[obj.find('') + 10: obj.find('')]) / math.pi * 180 36 | rbox = np.array([cx, cy, w, h, angle]) 37 | quad_box = rbox_2_quad(rbox, 'xywha').squeeze() 38 | line = cls_name + ' ' + str(quad_box[0]) + ' ' + str(quad_box[1]) + ' ' + str(quad_box[2]) + ' ' +\ 39 | str(quad_box[3]) + ' ' + str(quad_box[4]) + ' ' + str(quad_box[5]) + ' ' + str(quad_box[6]) +\ 40 | ' ' + str(quad_box[7]) + '\n' 41 | results.append(line) 42 | return results 43 | 44 | def writeTxt(self): 45 | for single_xml in self.xml_lists: 46 | lines = self._readXml(single_xml) 47 | txt_file = single_xml.replace('xml', 'txt') 48 | with open(os.path.join(self.txt_path, txt_file), 'w') as f: 49 | for single_line in lines: 50 | f.write(single_line) 51 | 52 | def plotgt(self): 53 | for single_xml in self.xml_lists: 54 | single_image = single_xml.replace('xml', 'jpg') 55 | image = cv2.cvtColor(cv2.imread(os.path.join(self.image_path, single_image), cv2.IMREAD_COLOR), 56 | cv2.COLOR_BGR2RGB) 57 | lines = self._readXml(single_xml) 58 | for single_line in lines: 59 | single_line = single_line.strip().split(' ') 60 | box = np.array(list(map(float, single_line[1:]))) 61 | cv2.polylines(image, [box.reshape(-1, 2).astype(np.int32)], True, (255, 0, 0), 3) 62 | plt.imshow(image) 63 | plt.show() 64 | 65 | 66 | if __name__ == '__main__': 67 | convert = Convert(xml_path='/data/fzh/HRSC/train/Annotations/', 68 | txt_path='/data/fzh/HRSC/train/train-ground-truth/', 69 | image_path='/data/fzh/HRSC/train/images/') 70 | 71 | convert.writeTxt() 72 | -------------------------------------------------------------------------------- /datasets/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | """This script is used to convert .bmp format to .jpg format.""" 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import shutil 6 | import os 7 | import cv2 8 | 9 | 10 | class Convert(object): 11 | def __init__(self, root_path): 12 | self.root_path = root_path 13 | self.convert_image_folder = 'images' 14 | self.image_folder = 'AllImages' 15 | self._mkdir() 16 | 17 | def _mkdir(self): 18 | self.train_image_path = os.path.join(self.root_path, 'train', self.convert_image_folder) 19 | self.val_image_path = 
os.path.join(self.root_path, 'test', self.convert_image_folder) 20 | 21 | if not os.path.exists(self.train_image_path): 22 | os.makedirs(self.train_image_path) 23 | 24 | if not os.path.exists(self.val_image_path): 25 | os.makedirs(self.val_image_path) 26 | 27 | def convert(self, set_name): 28 | image_lists = os.listdir(os.path.join(self.root_path, set_name, self.image_folder)) 29 | for single_image in image_lists: 30 | image = cv2.imread(os.path.join(self.root_path, set_name, self.image_folder, single_image)) 31 | converted_single_image = single_image.replace('bmp', 'jpg') 32 | cv2.imwrite(os.path.join(self.root_path, set_name, self.convert_image_folder, converted_single_image), 33 | image) 34 | 35 | 36 | if __name__ == '__main__': 37 | convert = Convert(root_path='/home/fzh/Data/HRSC/') 38 | convert.convert(set_name='test') 39 | 40 | -------------------------------------------------------------------------------- /datasets/test_collater.py: -------------------------------------------------------------------------------- 1 | from datasets.HRSC_dataset import HRSCDataset 2 | from datasets.SSDD_dataset import SSDDataset 3 | from datasets.collater import Collater 4 | 5 | if __name__ == '__main__': 6 | training_set = SSDDataset(root_path='/data/fzh/RSSDD/', 7 | set_name='train', 8 | augment=True, 9 | classes=['ship']) 10 | 11 | """Check some outputs from custom collater. 12 | 1. User can specify the test_idx manually. 13 | 2. User can visualize scale image result to cancel annotation line (57-65) in collater.py""" 14 | test_idxs = [0, 1, 2, 3, 4, 5, 6] 15 | batch = [training_set[idx] for idx in test_idxs] 16 | collater = Collater(scales=512, keep_ratio=False, multiple=32) 17 | result = collater(batch) 18 | print(result) 19 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.transforms import Compose 4 | from utils.utils import Rescale, Normalize, Reshape 5 | from utils.rotation_nms.cpu_nms import cpu_nms 6 | from utils.bbox_transforms import * 7 | 8 | 9 | def im_detect(model, src, target_sizes, params, use_gpu=True, conf=None, device=None): 10 | if isinstance(target_sizes, int): 11 | target_sizes = [target_sizes] 12 | if len(target_sizes) == 1: 13 | return single_scale_detect(model, src, target_size=target_sizes[0], params=params, 14 | use_gpu=use_gpu, conf=conf, device=device) 15 | 16 | 17 | def single_scale_detect(model, src, target_size, params=None, 18 | use_gpu=True, conf=None, device=None): 19 | im, im_scales = Rescale(target_size=target_size, keep_ratio=params.keep_ratio)(src) 20 | im = Compose([Normalize(), Reshape(unsqueeze=True)])(im) 21 | if use_gpu and torch.cuda.is_available(): 22 | model, im = model.cuda(device=device), im.cuda(device=device) 23 | with torch.no_grad(): # bboxes: [x, y, x, y, a, a_x, a_y, a_x, a_y, a_a] 24 | scores, classes, boxes = model(im, test_conf=conf) 25 | scores = scores.data.cpu().numpy() 26 | classes = classes.data.cpu().numpy() 27 | boxes = boxes.data.cpu().numpy() 28 | 29 | # convert oc format to polygon for rescale predict box coordinate 30 | predicted_bboxes = [] 31 | for idx in range(len(boxes)): 32 | single_box = boxes[idx] # single box: [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian)] 33 | single_box = np.array([[single_box[0], single_box[1], single_box[2], single_box[3], single_box[4], 0]], 34 | dtype=np.float32) # add extra score 0 35 | predicted_polygon = 
obb2poly_np_oc(single_box)[0, :-1].astype(np.float32) 36 | predicted_polygon[0::2] /= im_scales[0] 37 | predicted_polygon[1::2] /= im_scales[1] 38 | predicted_bbox = poly2obb_np(predicted_polygon, 'oc') # polygon 2 rbboxes (oc format: [xc, yc, h, w, angle(radian)] 39 | predicted_bboxes.append(predicted_bbox) 40 | 41 | if boxes.shape[1] > 5: 42 | # [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian), 43 | # anchor_xc, anchor_yc, anchor_w, anchor_h, anchor_angle(radian)] 44 | boxes[:, 5:9] = boxes[:, 5:9] / im_scales 45 | scores = np.reshape(scores, (-1, 1)) 46 | classes = np.reshape(classes, (-1, 1)) 47 | for id in range(len(predicted_bboxes)): 48 | boxes[id, :5] = predicted_bboxes[id] 49 | cls_dets = np.concatenate([classes, scores, boxes], axis=1) 50 | keep = np.where(classes < model.num_class)[0] 51 | return cls_dets[keep, :] 52 | # cls, score, x,y,w,h,a, a_x,a_y,a_w,a_h,a_a 53 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detect import im_detect 3 | import shutil 4 | from tqdm import tqdm 5 | from utils.map import eval_mAP 6 | from utils.bbox_transforms import * 7 | 8 | 9 | # evaluate by rotation detection result 10 | def evaluate(model=None, 11 | target_size=None, 12 | test_path=None, 13 | conf=None, 14 | device=None, 15 | mode=None, 16 | params=None): 17 | evaluate_dir = 'voc_evaluate' 18 | _dir = mode + '_evaluate' 19 | out_dir = os.path.join(params.output_path, evaluate_dir, _dir, 'detection-results') 20 | if os.path.exists(out_dir): 21 | shutil.rmtree(out_dir) 22 | os.makedirs(out_dir) 23 | 24 | # Step1. Collect detect result for per image or get predict result 25 | for image_name in tqdm(os.listdir(os.path.join(params.data_path, mode, 'images'))): 26 | image_path = os.path.join(params.data_path, mode, 'images', image_name) 27 | image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 28 | dets = im_detect(model=model, 29 | src=image, 30 | params=params, 31 | target_sizes=target_size, 32 | use_gpu=True, 33 | conf=conf, # score threshold 34 | device=device) 35 | 36 | # Step2. Write per image detect result into per txt file 37 | # line = cls_name score x1 y1 x2 y2 x3 y3 x4 y4 38 | img_ext = image_name.split('.')[-1] 39 | with open(os.path.join(out_dir, image_name.replace(img_ext, 'txt')), 'w') as f: 40 | for det in dets: 41 | cls_ind = int(det[0]) 42 | cls_socre = det[1] 43 | rbox = det[2:7] # [xc, yc, h, w, angle(radian)] 44 | 45 | if np.isnan(rbox[0]) or np.isnan(rbox[1]) or np.isnan(rbox[2]) or np.isnan(rbox[3]) or np.isnan(rbox[4]): 46 | line = '' 47 | else: 48 | # add extra score 49 | rbbox = np.array([[rbox[0], rbox[1], rbox[2], rbox[3], rbox[4], 0]], dtype=np.float32) 50 | polygon = obb2poly_np(rbbox, 'oc')[0, :-1].astype(np.float32) 51 | line = str(params.classes[cls_ind]) + ' ' + str(cls_socre) + ' ' + str(polygon[0]) + ' ' + str(polygon[1]) +\ 52 | ' ' + str(polygon[2]) + ' ' + str(polygon[3]) + ' ' + str(polygon[4]) + ' ' + str(polygon[5]) +\ 53 | ' ' + str(polygon[6]) + ' ' + str(polygon[7]) + '\n' 54 | f.write(line) 55 | 56 | # Step3. 
Calculate Precision, Recall, mAP, plot PR Curve 57 | mAP, Precision, Recall = eval_mAP(gt_root_dir=params.data_path, 58 | test_path=test_path, # test_path = ground-truth 59 | eval_root_dir=os.path.join(params.output_path, evaluate_dir, _dir), 60 | use_07_metric=False, 61 | thres=0.5) # rotation nms threshold 62 | print(f'mAP: {mAP}\tPrecision: {Precision}\tRecall: {Recall}') 63 | return mAP, Precision, Recall 64 | 65 | 66 | if __name__ == '__main__': 67 | import argparse 68 | import torch 69 | from train import Params 70 | import time 71 | from models.model import RetinaNet 72 | 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--device', type=int, default=0) 75 | parser.add_argument('--Dataset', type=str, default='SSDD') 76 | parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_ssdd.yml') 77 | parser.add_argument('--target_size', type=int, default=512) 78 | parser.add_argument('--chkpt', type=str, default='best/best.pth', help='the checkpoint file of the trained model.') 79 | parser.add_argument('--score_thr', type=float, default=0.05) 80 | 81 | parser.add_argument('--evaluate', type=bool, default=True) 82 | parser.add_argument('--FPS', type=bool, default=False, help='Check the FPS of the Model.') # todo: Ready to Support 83 | args = parser.parse_args() 84 | params = Params(args.config_file) 85 | params.backbone['pretrained'] = False 86 | model = RetinaNet(params) 87 | 88 | checkpoint = os.path.join(params.output_path, 'checkpoints', args.chkpt) 89 | 90 | # from checkpoint load model weight file 91 | # model weight 92 | chkpt = torch.load(checkpoint, map_location='cpu') 93 | pth = chkpt['model'] 94 | model.load_state_dict(pth) 95 | model.cuda(device=args.device) 96 | 97 | """The following codes is used to Debug eval() function.""" 98 | if args.evaluate: 99 | model.eval() 100 | mAP, Precision, Recall = evaluate( 101 | model=model, 102 | target_size=[args.target_size], 103 | test_path='ground-truth', 104 | conf=args.score_thr, # score threshold 105 | device=args.device, 106 | mode='test', 107 | params=params) 108 | print(f'mAP: {mAP}\nPrecision: {Precision}\nRecall: {Recall}\n') 109 | 110 | """The following codes are used to calculate FPS of model.""" 111 | if args.FPS: 112 | times = 50 # 50 is enough to balance some additional times for IO 113 | image_path = os.path.join(params.data_path, args.single_image) 114 | image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 115 | model.eval() 116 | t1 = time.time() 117 | for _ in range(times): 118 | dets = im_detect(model=model, 119 | image=image, 120 | target_sizes=[args.target_size], 121 | use_gpu=True, 122 | conf=0.25, 123 | device=args.device, 124 | params=params) 125 | t2 = time.time() 126 | tact_time = (t2 - t1) / times 127 | print(f'{tact_time} seconds, {1 / tact_time} FPS, Batch_size = 1') 128 | -------------------------------------------------------------------------------- /models/__pycache__/anchors.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/anchors.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/fpn.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/fpn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/heads.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/heads.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Anchors(nn.Module): 7 | def __init__(self, 8 | params=None, 9 | pyramid_levels=None, 10 | strides=None, 11 | rotations=None): 12 | super(Anchors, self).__init__() 13 | self.pyramid_levels = pyramid_levels 14 | self.strides = strides 15 | self.base_size = params.base_size 16 | self.ratios = params.ratios 17 | self.scales = params.scales 18 | self.rotations = rotations 19 | 20 | if pyramid_levels is None: 21 | self.pyramid_levels = [3, 4, 5, 6, 7] 22 | 23 | if strides is None: 24 | self.strides = [2 ** x for x in self.pyramid_levels] 25 | 26 | self.base_size = params.base_size 27 | self.ratios = params.ratios 28 | self.scales = np.array([2**(i / 3) for i in range(params.scales_per_octave)]) 29 | self.rotations = np.array([params.angle / 180 * np.pi]) 30 | 31 | self.num_anchors = len(self.scales) * len(self.ratios) * len(self.rotations) 32 | 33 | print(f'[Info]: anchor ratios: {self.ratios}\tanchor scales: {self.scales}\tbase_size: {self.base_size}\t' 34 | f'angle: {self.rotations}') 35 | print(f'[Info]: number of anchors: {self.num_anchors}') 36 | 37 | @staticmethod 38 | def generate_anchors(base_size, ratios, scales, rotations): 39 | """ 40 | Generate anchor (reference) windows by enumerating aspect ratios X 41 | scales w.r.t. a reference window. 
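        For a concrete feel of the ratio handling (illustrative numbers only):
        with base_size = 32, scale = 1.0 and ratio = 2.0 the starting square
        anchor has w = h = 32 and area = 1024; the correction below keeps that
        area, giving w = sqrt(1024 / 2.0) ~= 22.6 and h = w * 2.0 ~= 45.3, so
        the anchor becomes taller while its area stays (roughly) the same.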
42 | 43 | anchors: [xc, yc, w, h, angle(radian)] 44 | """ 45 | num_anchors = len(ratios) * len(scales) * len(rotations) 46 | # initialize output anchors 47 | anchors = np.zeros((num_anchors, 5)) 48 | # scale base_size 49 | anchors[:, 2:4] = base_size * np.tile(scales, (2, len(ratios) * len(rotations))).T 50 | # compute areas of anchors 51 | areas = anchors[:, 2] * anchors[:, 3] 52 | # correct for ratios 53 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales) * len(rotations))) 54 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales) * len(rotations)) 55 | # add rotations 56 | anchors[:, 4] = np.tile(np.repeat(rotations, len(scales)), (1, len(ratios))).T[:, 0] 57 | # # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) 58 | # anchors[:, 0:3:2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T 59 | # anchors[:, 1:4:2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T 60 | return anchors # [x_ctr, y_ctr, w, h, angle(radian)] 61 | 62 | @staticmethod 63 | def shift(shape, stride, anchors): 64 | shift_x = np.arange(0, shape[1]) * stride 65 | shift_y = np.arange(0, shape[0]) * stride 66 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 67 | shifts = np.vstack(( 68 | shift_x.ravel(), shift_y.ravel(), 69 | np.zeros(shift_x.ravel().shape), np.zeros(shift_y.ravel().shape), 70 | np.zeros(shift_x.ravel().shape) 71 | )).transpose() 72 | # add A anchors (1, A, 5) to 73 | # cell K shifts (K, 1, 5) to get 74 | # shift anchors (K, A, 5) 75 | # reshape to (K*A, 5) shifted anchors 76 | A = anchors.shape[0] 77 | K = shifts.shape[0] 78 | all_anchors = (anchors.reshape((1, A, 5)) + shifts.reshape((1, K, 5)).transpose((1, 0, 2))) 79 | all_anchors = all_anchors.reshape((K * A, 5)) 80 | return all_anchors 81 | 82 | def forward(self, images): 83 | image_shape = np.array(images.shape[2:]) 84 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels] 85 | 86 | # compute anchors over all pyramid levels 87 | all_anchors = np.zeros((0, 5)).astype(np.float32) 88 | num_level_anchors = [] 89 | for idx, p in enumerate(self.pyramid_levels): 90 | base_anchors = self.generate_anchors( 91 | base_size=self.base_size * self.strides[idx], 92 | ratios=self.ratios, 93 | scales=self.scales, 94 | rotations=self.rotations) 95 | shifted_anchors = self.shift(image_shapes[idx], self.strides[idx], base_anchors) 96 | num_level_anchors.append(shifted_anchors.shape[0]) 97 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0) 98 | all_anchors = np.expand_dims(all_anchors, axis=0) 99 | all_anchors = np.tile(all_anchors, (images.size(0), 1, 1)) 100 | all_anchors = torch.from_numpy(all_anchors.astype(np.float32)) 101 | if torch.is_tensor(images) and images.is_cuda: 102 | device = images.device 103 | all_anchors = all_anchors.cuda(device=device) 104 | return all_anchors, torch.from_numpy(np.array(num_level_anchors)).cuda(device=device) 105 | 106 | 107 | if __name__ == '__main__': 108 | from train import Params 109 | params = Params('/home/fzh/Pictures/Rotation-RetinaNet-PyTorch/configs/retinanet_r50_fpn_hrsc.yml') 110 | anchors = Anchors(params) 111 | feature_map_sizes = [(128, 128), (64, 64), (32, 32), (16, 16), (8, 8)] 112 | for level_idx in range(5): 113 | # print(f'# ============================base_anchor{level_idx}========================================= #') 114 | base_anchor = anchors.generate_anchors( 115 | base_size=anchors.base_size * anchors.strides[level_idx], 116 | ratios=anchors.ratios, 117 | scales=anchors.scales, 118 | rotations=anchors.rotations 119 | ) 120 | # print(base_anchor) 
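        # Rough expected shapes for this demo (assuming the HRSC config loaded
        # above: ratios = [0.2, 0.5, 1.0, 2.0, 5.0], scales_per_octave = 3 and
        # a single angle):
        #   base_anchor  -> (5 * 3 * 1, 5) = (15, 5), rows of [xc, yc, w, h, angle]
        #   shift_anchor -> (H * W * 15, 5), e.g. 128 * 128 * 15 = 245760 rows
        #                   for the first entry of feature_map_sizes.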
121 | print(f'# ============================shift_anchor{level_idx}========================================= #') 122 | shift_anchor = anchors.shift(feature_map_sizes[level_idx], anchors.strides[level_idx], base_anchor) 123 | print(shift_anchor) 124 | -------------------------------------------------------------------------------- /models/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | from utils.utils import kaiming_init, xavier_init 4 | 5 | init_method_list = ['random_init', 'kaiming_init', 'xavier_init', 'normal_init'] 6 | 7 | 8 | class FPN(nn.Module): 9 | def __init__(self, 10 | in_channel_list, 11 | out_channels, 12 | top_blocks, 13 | init_method=None): 14 | """ 15 | Args: 16 | out_channels(int): number of channels of the FPN feature. 17 | top_blocks(nn.Module or None): if provided, an extra op will be 18 | performed on the FPN output, and the result will extend the result list. 19 | init_method: which method to init lateral_conv and fpn_conv. 20 | kaiming_init: kaiming_init() 21 | xavier_init: xavier_init() 22 | random_init: PyTorch_init() 23 | """ 24 | super(FPN, self).__init__() 25 | self.inner_blocks = [] 26 | self.layer_blocks = [] 27 | self.init_method = init_method 28 | print('[Info]: ===== Neck Using FPN =====') 29 | 30 | assert init_method is not None, f'init_method in class FPN needs to be set.' 31 | assert init_method in init_method_list, f'init_method in class FPN is wrong.' 32 | if init_method is 'kaiming_init': 33 | print('[Info]: Using kaiming_init() to init lateral_conv and fpn_conv.') 34 | if init_method is 'xavier_init': 35 | print('[Info]: Using xavier_init() to init lateral_conv and fpn_conv.') 36 | if init_method is 'random_init': 37 | print('[Info]: Using PyTorch_init() to init lateral_conv and fpn_conv.') 38 | 39 | for idx, in_channels in enumerate(in_channel_list, 1): 40 | inner_block = "fpn_inner{}".format(idx) 41 | layer_block = "fpn_layer{}".format(idx) 42 | 43 | if in_channels == 0: 44 | continue 45 | 46 | # lateral conv 1x1 47 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1) # with bias, without BN Layer 48 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, 1, 1) # with bias, without BN Layer 49 | 50 | if self.init_method is 'kaiming_init': 51 | kaiming_init(inner_block_module, a=0, nonlinearity='relu') 52 | kaiming_init(layer_block_module, a=0, nonlinearity='relu') 53 | 54 | if self.init_method is 'xavier_init': 55 | xavier_init(inner_block_module, gain=1, bias=0, distribution='uniform') 56 | xavier_init(layer_block_module, gain=1, bias=0, distribution='uniform') 57 | 58 | # if self.init_method is 'random_init': 59 | # Don't do anything 60 | 61 | self.add_module(inner_block, inner_block_module) 62 | self.add_module(layer_block, layer_block_module) 63 | 64 | self.inner_blocks.append(inner_block) 65 | self.layer_blocks.append(layer_block) 66 | self.top_blocks = top_blocks 67 | 68 | def forward(self, x): 69 | """ 70 | Arguments: 71 | x : feature maps for each feature level. 72 | Returns: 73 | results (tuple[Tensor]): feature maps after FPN layers. 74 | They are ordered from highest resolution first. 
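        In short: every backbone level passes through its 1x1 lateral conv
        (fpn_inner*), the coarser merged map is upsampled with nearest-neighbour
        interpolation and added to it, and a 3x3 conv (fpn_layer*) produces that
        level's output; LastLevelP6_P7 then appends the extra P6/P7 maps at the
        end of the returned tuple.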
75 | """ 76 | last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) 77 | results = [] 78 | results.append(getattr(self, self.layer_blocks[-1])(last_inner)) 79 | for feature, inner_block, layer_block in zip( 80 | x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] 81 | ): 82 | if not inner_block: 83 | continue 84 | inner_lateral = getattr(self, inner_block)(feature) 85 | inner_top_down = F.interpolate( 86 | last_inner, size= 87 | (int(inner_lateral.shape[-2]), int(inner_lateral.shape[-1])), 88 | mode='nearest') 89 | last_inner = inner_lateral + inner_top_down 90 | results.insert(0, getattr(self, layer_block)(last_inner)) 91 | 92 | if isinstance(self.top_blocks, LastLevelP6_P7): 93 | last_results = self.top_blocks(x[-1], results[-1]) 94 | results.extend(last_results) 95 | else: 96 | raise NotImplementedError 97 | 98 | return tuple(results) 99 | 100 | 101 | class LastLevelP6_P7(nn.Module): 102 | """This module is used in RetinaNet to generate extra layers, P6 and P7. 103 | Args: 104 | init_method: which method to init P6_conv and P7_conv, 105 | support methods: kaiming_init:kaiming_init, 106 | xavier_init: xavier_init, 107 | random_init: PyTorch_init 108 | """ 109 | def __init__(self, in_channels, 110 | out_channels, 111 | init_method=None): 112 | super(LastLevelP6_P7, self).__init__() 113 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) # with bias without BN Layer 114 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) # with bias without BN Layer 115 | 116 | assert init_method is not None, f'init_method in class LastLevelP6_P7 needs to be set.' 117 | assert init_method in init_method_list, f'init_method in class LastLevelP6_P7 is wrong.' 118 | 119 | if init_method is 'kaiming_init': 120 | print('[Info]: Using kaiming_init() to init P6_conv and P7_conv') 121 | for layer in [self.p6, self.p7]: 122 | kaiming_init(layer, a=0, nonlinearity='relu') 123 | 124 | if init_method is 'xavier_init': 125 | print('[Info]: Using xavier_init() to init P6_conv and P7_conv') 126 | for layer in [self.p6, self.p7]: 127 | xavier_init(layer, gain=1, bias=0, distribution='uniform') 128 | 129 | if init_method is 'random_init': 130 | print('[Info]: Using PyTorch_init() to init P6_conv and P7_conv') 131 | # Don't do anything 132 | 133 | self.use_p5 = in_channels == out_channels 134 | 135 | def forward(self, c5, p5): 136 | x = p5 if self.use_p5 else c5 137 | p6 = self.p6(x) 138 | p7 = self.p7(p6) 139 | return [p6, p7] 140 | -------------------------------------------------------------------------------- /models/heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.utils import kaiming_init, constant_init, normal_init 4 | import math 5 | 6 | init_method_list = ['random_init', 'kaiming_init', 'xavier_init', 'normal_init'] 7 | 8 | 9 | class CLSBranch(nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | feat_channels, 13 | num_stacked, 14 | init_method=None): 15 | super(CLSBranch, self).__init__() 16 | 17 | assert init_method is not None, f'init_method in class CLSBranch needs to be set.' 18 | assert init_method in init_method_list, f'init_method in class CLSBranch is wrong.' 
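        # The tower built below stacks num_stacked (Conv3x3 + ReLU) blocks; with
        # the provided configs (num_stacked = 4, 256 channels) this matches the
        # standard RetinaNet classification subnet, and CLSHead then maps the
        # features to num_anchors * num_classes sigmoid scores per location.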
19 | 20 | self.convs = nn.ModuleList() 21 | for i in range(num_stacked): 22 | chns = in_channels if i == 0 else feat_channels 23 | # : Conv(wo bias) + BN + Relu() 24 | # self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=False)) # conv_weight -> bn -> relu() 25 | # self.convs.append(nn.BatchNorm2d(feat_channels, affine=True)) # add BN layer 26 | # self.convs.append(nn.ReLU(inplace=True)) 27 | # self.init_weights() 28 | 29 | # : Conv(bias) + Relu() and using kaiming_init_weight / mmdet_init_weight 30 | self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=True)) # conv with bias -> relu() 31 | self.convs.append(nn.ReLU(inplace=True)) 32 | 33 | if init_method is 'kaiming_init': 34 | self.kaiming_init_weights() 35 | if init_method is 'normal_init': 36 | self.mmdet_init_weights() 37 | 38 | def mmdet_init_weights(self): 39 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Cls Branch.') 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | normal_init(m, mean=0, std=0.01, bias=0) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | constant_init(m, 1, bias=0) 45 | 46 | def kaiming_init_weights(self): 47 | print('[Info]: Using kaiming_init_weights() to init Cls Branch.') 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | kaiming_init(m, a=0, nonlinearity='relu') 51 | elif isinstance(m, nn.BatchNorm2d): 52 | constant_init(m, 1, bias=0) 53 | 54 | def forward(self, x): 55 | for conv in self.convs: 56 | x = conv(x) 57 | return x 58 | 59 | 60 | class CLSHead(nn.Module): 61 | def __init__(self, 62 | feat_channels, 63 | num_anchors, 64 | num_classes): 65 | super(CLSHead, self).__init__() 66 | self.num_anchors = num_anchors 67 | self.num_classes = num_classes 68 | self.feat_channels = feat_channels 69 | self.head = nn.Conv2d(self.feat_channels, self.num_anchors * self.num_classes, 3, 1, 1) # with bias 70 | self.head_init_weights() 71 | 72 | def head_init_weights(self): 73 | print('[Info]: Using RetinaNet Paper Init Method to init Cls Head.') 74 | prior = 0.01 75 | self.head.weight.data.fill_(0) 76 | self.head.bias.data.fill_(-math.log((1.0 - prior) / prior)) 77 | 78 | def forward(self, x): 79 | x = torch.sigmoid(self.head(x)) 80 | x = x.permute(0, 2, 3, 1) 81 | n, h, w, c = x.shape 82 | x = x.reshape(n, h, w, self.num_anchors, self.num_classes) 83 | return x.reshape(x.shape[0], -1, self.num_classes) 84 | 85 | 86 | class REGBranch(nn.Module): 87 | def __init__(self, 88 | in_channels, 89 | feat_channels, 90 | num_stacked, 91 | init_method=None): 92 | super(REGBranch, self).__init__() 93 | 94 | assert init_method is not None, f'init_method in class RegBranch needs to be set.' 95 | assert init_method in init_method_list, f'init_method in class RegBranch is wrong.' 
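        # Same tower pattern as the classification branch; its features feed
        # REGHead, which predicts num_regress = 5 offsets per anchor
        # (xc, yc, h, w, angle), the targets for which are produced by
        # BoxCoder.encode() in losses.py.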
96 | 97 | self.convs = nn.ModuleList() 98 | 99 | for i in range(num_stacked): 100 | chns = in_channels if i == 0 else feat_channels 101 | 102 | # : Conv(wo bias) + BN + Relu() 103 | # self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=False)) # conv_weight -> bn -> relu() 104 | # self.convs.append(nn.BatchNorm2d(feat_channels, affine=True)) 105 | # self.convs.append(nn.ReLU(inplace=True)) 106 | # self.init_weights() 107 | 108 | # : Conv(bias) + Relu() and using kaiming_init_weight / mmdet_init_weight 109 | self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=True)) # conv with bias -> relu() 110 | self.convs.append(nn.ReLU(inplace=True)) 111 | if init_method is 'kaiming_init': 112 | self.kaiming_init_weights() 113 | if init_method is 'normal_init': 114 | self.mmdet_init_weights() 115 | 116 | def mmdet_init_weights(self): 117 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Reg Branch.') 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | normal_init(m, mean=0, std=0.01, bias=0) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | constant_init(m, 1, bias=0) 123 | 124 | def kaiming_init_weights(self): 125 | print('[Info]: Using kaiming_init_weights() to init Reg Branch.') 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | kaiming_init(m, a=0, nonlinearity='relu') 129 | elif isinstance(m, nn.BatchNorm2d): 130 | constant_init(m, 1, bias=0) 131 | 132 | def forward(self, x): 133 | for conv in self.convs: 134 | x = conv(x) 135 | return x 136 | 137 | 138 | class REGHead(nn.Module): 139 | def __init__(self, 140 | feat_channels, 141 | num_anchors, 142 | num_regress): 143 | super(REGHead, self).__init__() 144 | self.num_anchors = num_anchors 145 | self.num_regress = num_regress 146 | self.feat_channels = feat_channels 147 | self.head = nn.Conv2d(self.feat_channels, self.num_anchors * self.num_regress, 3, 1, 1) # with bias 148 | self.mmdet_init_weights() 149 | 150 | def mmdet_init_weights(self): 151 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Reg Head.') 152 | normal_init(self.head, mean=0, std=0.01, bias=0) 153 | 154 | def forward(self, x, with_deform=False): 155 | x = self.head(x) 156 | if with_deform is False: 157 | x = x.permute(0, 2, 3, 1) 158 | return x.reshape(x.shape[0], -1, self.num_regress) 159 | else: 160 | return x 161 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.utils import bbox_overlaps 3 | from utils.bbox_transforms import * 4 | from utils.box_coder import BoxCoder 5 | from utils.rotation_overlaps.rbox_overlaps import rbox_overlaps 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | class IntegratedLoss(nn.Module): 10 | def __init__(self, params): 11 | super(IntegratedLoss, self).__init__() 12 | loss_dict = params.loss 13 | self.alpha = loss_dict['cls']['alpha'] 14 | self.gamma = loss_dict['cls']['gamma'] 15 | func = loss_dict['reg']['type'] 16 | 17 | assign_dict = params.assigner 18 | self.pos_iou_thr = assign_dict['pos_iou_thr'] 19 | self.neg_iou_thr = assign_dict['neg_iou_thr'] 20 | self.min_pos_iou = assign_dict['min_pos_iou'] 21 | self.low_quality_match = assign_dict['low_quality_match'] 22 | 23 | self.box_coder = BoxCoder() 24 | 25 | if func == 'smooth': 26 | self.criteron = smooth_l1_loss 27 | print(f'[Info]: Using {func} Loss.') 28 | 29 | def forward(self, classifications, regressions, anchors, annotations, 
image_names): 30 | cls_losses = [] 31 | reg_losses = [] 32 | batch_size = classifications.shape[0] 33 | device = classifications[0].device 34 | for j in range(batch_size): 35 | image_name = image_names[j] 36 | anchor = anchors[j] # [xc, yc, w, h, angle(radian)] 37 | classification = classifications[j, :, :] 38 | regression = regressions[j, :, :] # [xc_offset, yc_offset, h_offset, w_offset, angle_offset] 39 | bbox_annotation = annotations[j, :, :] # [xc, yc, h, w, angle(radian)] 40 | bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1] 41 | num_gt = len(bbox_annotation) 42 | if bbox_annotation.shape[0] == 0: 43 | cls_losses.append(torch.tensor(0).float().cuda(device=device)) 44 | reg_losses.append(torch.tensor(0).float().cuda(device=device)) 45 | continue 46 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4) 47 | 48 | # get minimum circumscribed rectangle of the rotated ground-truth box and 49 | # calculate the horizontal overlaps between minimum circumscribed rectangles and anchor boxes 50 | 51 | horizontal_overlaps = bbox_overlaps( 52 | anchor.clone(), # generate anchor data copy 53 | obb2hbb_oc(bbox_annotation[:, :-1])) 54 | 55 | # obb_rect = [xc, yc, h, w, angle(radian)] 56 | ious = rbox_overlaps( 57 | swap_axis(anchor[:, :]).cpu().numpy(), 58 | bbox_annotation[:, :-1].cpu().numpy(), 59 | horizontal_overlaps.cpu().numpy(), 60 | thresh=1e-1 61 | ) 62 | 63 | if not torch.is_tensor(ious): 64 | ious = torch.from_numpy(ious).cuda(device=device) 65 | 66 | iou_max, iou_argmax = torch.max(ious, dim=1) 67 | 68 | positive_indices = torch.ge(iou_max, self.pos_iou_thr) 69 | 70 | if self.low_quality_match is True: 71 | max_gt, argmax_gt = ious.max(dim=0) 72 | for idx in range(num_gt): 73 | if max_gt[idx] >= self.min_pos_iou: 74 | positive_indices[argmax_gt[idx]] = 1 75 | 76 | # calculate classification loss 77 | cls_targets = (torch.ones(classification.shape) * -1).cuda(device=device) 78 | cls_targets[torch.lt(iou_max, self.neg_iou_thr), :] = 0 79 | num_positive_anchors = positive_indices.sum() 80 | assigned_annotations = bbox_annotation[iou_argmax, :] 81 | cls_targets[positive_indices, :] = 0 82 | cls_targets[positive_indices, assigned_annotations[positive_indices, 5].long()] = 1 83 | alpha_factor = torch.ones(cls_targets.shape).cuda(device=device) * self.alpha 84 | alpha_factor = torch.where(torch.eq(cls_targets, 1.), alpha_factor, 1. - alpha_factor) 85 | focal_weight = torch.where(torch.eq(cls_targets, 1.), 1. 
- classification, classification) 86 | focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma) 87 | # bin_cross_entropy = -(cls_targets * torch.log(classification + 1e-6) + (1.0 - cls_targets) * torch.log( 88 | # 1.0 - classification + 1e-6)) 89 | bin_cross_entropy = -(cls_targets * torch.log(classification) + (1.0 - cls_targets) * torch.log( 90 | 1.0 - classification)) 91 | cls_loss = focal_weight * bin_cross_entropy 92 | cls_loss = torch.where(torch.ne(cls_targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda(device=device)) 93 | cls_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0)) 94 | 95 | # calculate regression loss 96 | if positive_indices.sum() > 0: 97 | all_rois = anchor[positive_indices, :] 98 | gt_boxes = assigned_annotations[positive_indices, :] 99 | reg_targets = self.box_coder.encode(all_rois, gt_boxes) 100 | reg_loss = self.criteron(regression[positive_indices, :], reg_targets) 101 | reg_losses.append(reg_loss) 102 | else: 103 | reg_losses.append(torch.tensor(0).float().cuda(device=device)) 104 | loss_cls = torch.stack(cls_losses).mean(dim=0, keepdim=True) 105 | loss_reg = torch.stack(reg_losses).mean(dim=0, keepdim=True) 106 | return loss_cls, loss_reg 107 | 108 | 109 | def smooth_l1_loss(inputs, 110 | targets, 111 | beta=1. / 9, 112 | size_average=True, 113 | weight=None): 114 | """https://github.com/facebookresearch/maskrcnn-benchmark""" 115 | diff = torch.abs(inputs - targets) 116 | if weight is None: 117 | loss = torch.where( 118 | diff < beta, 119 | 0.5 * diff ** 2 / beta, 120 | diff - 0.5 * beta 121 | ) 122 | else: 123 | loss = torch.where( 124 | diff < beta, 125 | 0.5 * diff ** 2 / beta, 126 | diff - 0.5 * beta 127 | ) * weight.max(1)[0].unsqueeze(1).repeat(1,5) 128 | if size_average: 129 | return loss.mean() 130 | return loss.sum() 131 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.anchors import Anchors 4 | from models.fpn import FPN, LastLevelP6_P7 5 | from models import resnet 6 | from models.heads import CLSBranch, REGBranch, CLSHead, REGHead 7 | from models.losses import IntegratedLoss 8 | from utils.utils import clip_boxes 9 | from utils.box_coder import BoxCoder 10 | from utils.rotation_nms.cpu_nms import cpu_nms 11 | import math 12 | import cv2 13 | import numpy as np 14 | 15 | 16 | class RetinaNet(nn.Module): 17 | def __init__(self, params): 18 | super(RetinaNet, self).__init__() 19 | self.num_class = len(params.classes) 20 | self.num_regress = 5 21 | self.anchor_generator = Anchors(params) 22 | self.num_anchors = self.anchor_generator.num_anchors 23 | self.pretrained = params.backbone['pretrained'] 24 | self.init_backbone(params.backbone['type']) 25 | self.cls_branch_num_stacked = params.head['num_stacked'] 26 | self.rotation_nms_thr = params.rotation_nms_thr 27 | self.score_thr = params.score_thr 28 | 29 | self.fpn = FPN( 30 | in_channel_list=self.fpn_in_channels, 31 | out_channels=256, 32 | top_blocks=LastLevelP6_P7(in_channels=256, 33 | out_channels=256, 34 | init_method=params.neck['extra_conv_init_method']), # in_channels: 1) 2048 on C5, 2) 256 on P5 35 | init_method=params.neck['init_method']) 36 | 37 | self.cls_branch = CLSBranch( 38 | in_channels=256, 39 | feat_channels=256, 40 | num_stacked=self.cls_branch_num_stacked, 41 | init_method=params.head['cls_branch_init_method'] 42 | ) 43 | 44 | self.cls_head = 
CLSHead( 45 | feat_channels=256, 46 | num_anchors=self.num_anchors, 47 | num_classes=self.num_class 48 | ) 49 | 50 | self.reg_branch = REGBranch( 51 | in_channels=256, 52 | feat_channels=256, 53 | num_stacked=self.cls_branch_num_stacked, 54 | init_method=params.head['reg_branch_init_method'] 55 | ) 56 | 57 | self.reg_head = REGHead( 58 | feat_channels=256, 59 | num_anchors=self.num_anchors, 60 | num_regress=self.num_regress # x, y, w, h, angle 61 | ) 62 | 63 | self.loss = IntegratedLoss(params) 64 | 65 | self.box_coder = BoxCoder() 66 | 67 | def init_backbone(self, backbone): 68 | if backbone == 'resnet34': 69 | print(f'[Info]: Use Backbone is {backbone}.') 70 | self.backbone = resnet.resnet34(pretrained=self.pretrained) 71 | self.fpn_in_channels = [128, 256, 512] 72 | 73 | elif backbone == 'resnet50': 74 | print(f'[Info]: Use Backbone is {backbone}.') 75 | self.backbone = resnet.resnet50(pretrained=self.pretrained) 76 | self.fpn_in_channels = [512, 1024, 2048] 77 | 78 | elif backbone == 'resnet101': 79 | print(f'[Info]: Use Backbone is {backbone}.') 80 | self.backbone = resnet.resnet101(pretrained=self.pretrained) 81 | self.fpn_in_channels = [512, 1024, 2048] 82 | 83 | elif backbone == 'resnet152': 84 | print(f'[Info]: Use Backbone is {backbone}.') 85 | self.backbone = resnet.resnet101(pretrained=self.pretrained) 86 | self.fpn_in_channels = [512, 1024, 2048] 87 | else: 88 | raise NotImplementedError 89 | 90 | del self.backbone.avgpool 91 | del self.backbone.fc 92 | 93 | def backbone_output(self, imgs): 94 | feature = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(imgs))) 95 | c2 = self.backbone.layer1(self.backbone.maxpool(feature)) 96 | c3 = self.backbone.layer2(c2) 97 | c4 = self.backbone.layer3(c3) 98 | c5 = self.backbone.layer4(c4) 99 | return [c3, c4, c5] 100 | 101 | def forward(self, images, annots=None, image_names=None, test_conf=None): 102 | anchors_list, offsets_list = [], [] 103 | original_anchors, num_level_anchors = self.anchor_generator(images) 104 | anchors_list.append(original_anchors) 105 | 106 | features = self.fpn(self.backbone_output(images)) 107 | 108 | cls_score = torch.cat([self.cls_head(self.cls_branch(feature)) for feature in features], dim=1) 109 | bbox_pred = torch.cat([self.reg_head(self.reg_branch(feature), with_deform=False) 110 | for feature in features], dim=1) 111 | 112 | # get the predicted bboxes 113 | # predicted_boxes = torch.cat( 114 | # [self.box_coder.decode(anchors_list[-1][index], bbox_pred[index]).unsqueeze(0) 115 | # for index in range(len(bbox_pred))], dim=0).detach() 116 | 117 | if self.training: 118 | # Max IoU Assigner with Focal Loss and Smooth L1 loss 119 | loss_cls, loss_reg = self.loss(cls_score, # cls_score with all levels 120 | bbox_pred, # bbox_pred with all levels 121 | anchors_list[-1], 122 | annots, 123 | image_names) 124 | 125 | return loss_cls, loss_reg 126 | 127 | else: # for model eval() 128 | return self.decoder(images, anchors_list[-1], cls_score, bbox_pred, 129 | thresh=self.score_thr, nms_thresh=self.rotation_nms_thr, test_conf=test_conf) 130 | 131 | def decoder(self, ims, anchors, cls_score, bbox_pred, 132 | thresh=0.6, nms_thresh=0.1, test_conf=None): 133 | """ 134 | Args: 135 | thresh: equal to score_thr. 136 | nms_thresh: nms_thr. 137 | test_conf: equal to thresh. 
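            Pipeline (as implemented below): decode the (dx, dy, dh, dw, dtheta) offsets
            against the anchors with BoxCoder.decode, keep predictions whose best class
            score reaches `thresh`, run rotated NMS (cpu_nms) on the survivors, and
            return [nms_scores, nms_class, cat(pred_boxes, matched_anchors)].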
138 | """ 139 | if test_conf is not None: 140 | thresh = test_conf 141 | bboxes = self.box_coder.decode(anchors, bbox_pred) # bboxes: [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian)] 142 | # bboxes = clip_boxes(bboxes, ims) 143 | scores = torch.max(cls_score, dim=2, keepdim=True)[0] 144 | keep = (scores >= thresh)[0, :, 0] 145 | if keep.sum() == 0: 146 | return [torch.zeros(1), torch.zeros(1), torch.zeros(1, 5)] 147 | scores = scores[:, keep, :] 148 | anchors = anchors[:, keep, :] 149 | cls_score = cls_score[:, keep, :] 150 | bboxes = bboxes[:, keep, :] 151 | 152 | # NMS 153 | anchors_nms_idx = cpu_nms(torch.cat([bboxes, scores], dim=2)[0, :, :].cpu().detach().numpy(), nms_thresh) 154 | nms_scores, nms_class = cls_score[0, anchors_nms_idx, :].max(dim=1) 155 | output_boxes = torch.cat([ 156 | bboxes[0, anchors_nms_idx, :], 157 | anchors[0, anchors_nms_idx, :]], 158 | dim=1 159 | ) 160 | return [nms_scores, nms_class, output_boxes] 161 | 162 | def freeze_bn(self): 163 | """Set BN.eval(), BN is in the model's Backbone. """ 164 | for layer in self.backbone.modules(): 165 | if isinstance(layer, nn.BatchNorm2d): 166 | # is only used to make the bn.running_mean and running_var not change in training phase. 167 | layer.eval() 168 | 169 | # freeze the bn.weight and bn.bias which are two learnable params in BN Layer. 170 | # layer.weight.requires_grad = False 171 | # layer.bias.requires_grad = False 172 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import os 6 | 7 | model_urls = { 8 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 13 | } 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, dilation=1, norm_layer=None): 30 | super(BasicBlock, self).__init__() 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = norm_layer(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = norm_layer(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | 
identity = self.downsample(x) 57 | 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 68 | base_width=64, dilation=1, norm_layer=None): 69 | super(Bottleneck, self).__init__() 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | width = int(planes * (base_width / 64.)) * groups 73 | 74 | self.conv1 = conv1x1(inplanes, width) 75 | self.bn1 = norm_layer(width) 76 | 77 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 78 | self.bn2 = norm_layer(width) 79 | 80 | self.conv3 = conv1x1(width, planes * self.expansion) 81 | self.bn3 = norm_layer(planes * self.expansion) 82 | 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | def __init__(self, block, layers, num_classes=1000): 112 | 113 | self.inplanes = 64 114 | super(ResNet, self).__init__() 115 | 116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 117 | self.bn1 = nn.BatchNorm2d(64) 118 | self.relu = nn.ReLU(inplace=True) 119 | 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 121 | 122 | self.layer1 = self._make_layer(block, 64, layers[0]) 123 | 124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 125 | 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 127 | 128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 129 | 130 | self.avgpool = nn.AvgPool2d(7) 131 | self.fc = nn.Linear(512 * block.expansion, num_classes) 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 136 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def forward(self, x): 159 | x = self.conv1(x) 160 | x = self.bn1(x) 161 | x = self.relu(x) 162 | x = self.maxpool(x) 163 | 164 | x = self.layer1(x) 165 | x = self.layer2(x) 166 | x = self.layer3(x) 167 | x = self.layer4(x) 168 | 169 | x = self.avgpool(x) 170 | x = x.view(x.size(0), -1) 171 | x = self.fc(x) 172 | 173 | return x 174 | 175 | 176 | def resnet18(pretrained=False, **kwargs): 177 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 178 | if pretrained: 179 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir='model_data'), strict=False) 180 | return model 181 | 182 | 183 | def resnet34(pretrained=False, **kwargs): 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir='model_data'), strict=False) 187 | return model 188 | 189 | 190 | def resnet50(pretrained=False, **kwargs): 191 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 192 | if pretrained: 193 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 194 | weight_path = dir_path + '/resnet_pretrained_pth/resnet50-0676ba61.pth' 195 | if os.path.exists(weight_path): 196 | model.load_state_dict(torch.load(weight_path), strict=False) 197 | else: 198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='model_data'), strict=False) 199 | return model 200 | 201 | 202 | def resnet101(pretrained=False, **kwargs): 203 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 204 | if pretrained: 205 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 206 | weight_path = dir_path + '/resnet_pretrained_pth/resnet101-5d3b4d8f.pth' 207 | if os.path.exists(weight_path): 208 | model.load_state_dict(torch.load(weight_path), strict=False) 209 | else: 210 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir='model_data'), strict=False) 211 | return model 212 | 213 | 214 | def resnet152(pretrained=False, **kwargs): 215 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 216 | if pretrained: 217 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 218 | weight_path = dir_path + '/resnet_pretrained_pth/resnet152-b121ed2d.pth' 219 | if os.path.exists(weight_path): 220 | model.load_state_dict(torch.load(weight_path), strict=False) 221 | else: 222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir='model_data'), strict=False) 223 | return model 224 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorlog==6.6.0 2 | Cython==0.29.28 3 | matplotlib==3.5.1 4 | numpy==1.21.2 5 | opencv_python==4.5.5.64 6 | Pillow==9.0.1 7 | PyYAML==6.0 8 
| setuptools==58.0.4 9 | Shapely==1.8.1.post1 10 | tensorboardX==2.5 11 | torch==1.7.0 12 | torchvision==0.8.0 13 | tqdm==4.63.0 14 | -------------------------------------------------------------------------------- /resnet_pretrained_pth/.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | -------------------------------------------------------------------------------- /resnet_pretrained_pth/README.md: -------------------------------------------------------------------------------- 1 | ### Put the pretrained resnet-50/101/152 weight file here. 2 | -------------------------------------------------------------------------------- /resource/HRSC_Result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/resource/HRSC_Result.png -------------------------------------------------------------------------------- /resource/RSSDD_Result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/resource/RSSDD_Result.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from setuptools import Extension, setup 3 | import os 4 | import numpy as np 5 | from Cython.Build import cythonize 6 | from torch.utils.cpp_extension import BuildExtension 7 | 8 | 9 | def make_cython_ext(name, module, sources): 10 | extra_compile_args = None 11 | if platform.system() != 'Windows': 12 | extra_compile_args = { 13 | 'cxx': ['-Wno-unused-function', '-Wno-write-strings'] 14 | } 15 | 16 | extension = Extension( 17 | '{}.{}'.format(module, name), 18 | [os.path.join(*module.split('.'), p) for p in sources], 19 | include_dirs=[np.get_include()], 20 | language='c++', 21 | extra_compile_args=extra_compile_args) 22 | extension, = cythonize(extension) 23 | return extension 24 | 25 | 26 | if __name__ == '__main__': 27 | setup( 28 | name='extension', 29 | ext_modules=[ 30 | make_cython_ext( 31 | name='rbox_overlaps', 32 | module='utils.rotation_overlaps', 33 | sources=['rbox_overlaps.pyx']), 34 | 35 | make_cython_ext( 36 | name='cpu_nms', 37 | module='utils.rotation_nms', 38 | sources=['cpu_nms.pyx']), 39 | ], 40 | cmdclass={'build_ext': BuildExtension}, 41 | ) -------------------------------------------------------------------------------- /show.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from models.model import RetinaNet 3 | import os 4 | import cv2 5 | import torch 6 | from detect import im_detect 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import math 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--backbone', type=str, default='resnet50') 15 | parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_ssdd.yml') 16 | parser.add_argument('--target_sizes', type=list, default=[512], help='the size of the input image.') 17 | parser.add_argument('--chkpt', type=str, default='best/best.pth', help='the chkpt file name') 18 | parser.add_argument('--result_path', type=str, default='show_result', help='the relative path for saving' 19 | 'ori pic and predicted pic') 20 | parser.add_argument('--score_thresh', 
type=float, default=0.05, help='score threshold') 21 | parser.add_argument('--pic_name', type=str, default='demo6.jpg', help='relative path') 22 | parser.add_argument('--device', type=int, default=1) 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def plot_box(image, coord, label_index=None, score=None, color=None, line_thickness=None): 28 | bbox_color = [226, 43, 138] if color is None else color 29 | text_color = [255, 255, 255] 30 | line_thickness = 1 if line_thickness is None else line_thickness 31 | xc, yc, h, w, ag = coord[:5] 32 | wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag) 33 | hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag) 34 | p1 = (xc - wx - hx, yc - wy - hy) 35 | p2 = (xc - wx + hx, yc - wy + hy) 36 | p3 = (xc + wx + hx, yc + wy + hy) 37 | p4 = (xc + wx - hx, yc + wy - hy) 38 | ps = np.int0(np.array([p1, p2, p3, p4])) 39 | cv2.drawContours(image, [ps], -1, bbox_color, thickness=3) 40 | if label_index is not None: 41 | label_text = params.classes[label_index] 42 | label_text += '|{:.02f}'.format(score) 43 | font = cv2.FONT_HERSHEY_COMPLEX 44 | text_size = cv2.getTextSize(label_text, font, fontScale=0.25, thickness=line_thickness) 45 | text_width = text_size[0][0] 46 | text_height = text_size[0][1] 47 | try: 48 | cv2.rectangle(image, (int(xc), int(yc) - text_height -2), 49 | (int(xc) + text_width, int(yc) + 3), (0, 128, 0), -1) 50 | cv2.putText(image, label_text, (int(xc), int(yc)), font, 0.25, text_color, thickness=1) 51 | except: 52 | print(f'{coord} is wrong!') 53 | 54 | 55 | def show_pred_box(args, params): 56 | # create folder 57 | if not os.path.exists(args.result_path): 58 | os.makedirs(args.result_path) 59 | 60 | model = RetinaNet(params) 61 | chkpt_path = os.path.join(params.output_path, 'checkpoints', args.chkpt) 62 | chkpt = torch.load(chkpt_path, map_location='cpu') 63 | print(f"The current model training {chkpt['epoch']} epoch(s)") 64 | print(f"The current model mAP: {chkpt['best_fitness']} based on test_conf={params.score_thr} & nms_thr={params.nms_thr}") 65 | 66 | model.load_state_dict(chkpt['model']) 67 | model.cuda(device=args.device) 68 | model.eval() 69 | 70 | image = cv2.cvtColor(cv2.imread(os.path.join(args.result_path, args.pic_name), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 71 | 72 | dets = im_detect(model, 73 | image, 74 | target_sizes=args.target_sizes, 75 | params=params, 76 | use_gpu=True, 77 | conf=args.score_thresh, 78 | device=args.device) 79 | 80 | # dets: list[class_index, 0 81 | # score, 1 82 | # pred_xc, pred_yc, pred_w, pred_h, pred_angle(radian), 2 - 6 83 | # anchor_xc, anchor_yc, anchor_w, anchor_h, anchor_angle(radian)] 7 - 11 84 | for det in dets: 85 | cls_index = int(det[0]) 86 | score = float(det[1]) 87 | pred_box = det[2:7] 88 | anchor = det[7:12] 89 | 90 | # plot predict box 91 | plot_box(image, coord=pred_box, label_index=cls_index, score=score, color=None, 92 | line_thickness=4) 93 | 94 | # plot which anchor to create predict box 95 | # plot_box(image, coord=anchor, color=[0, 0, 255]) 96 | 97 | plt.imsave(os.path.join(args.result_path, f"{args.pic_name.split('.')[0]}_predict.png"), image) 98 | plt.imshow(image) 99 | plt.show() 100 | 101 | 102 | if __name__ == '__main__': 103 | from train import Params 104 | 105 | args = get_args() 106 | params = Params(args.config_file) 107 | if args.score_thresh != params.score_thr: 108 | print('[Info]: score_thresh is not equal to cfg.score_thr') 109 | params.backbone['pretrained'] = False 110 | show_pred_box(args, params) 111 | 
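# Rough programmatic usage sketch mirroring show_pred_box() above; the config path,
# checkpoint path and device id are placeholders, adjust them to your own setup:
#
#   from train import Params
#   params = Params('./configs/retinanet_r50_fpn_ssdd.yml')
#   params.backbone['pretrained'] = False
#   model = RetinaNet(params)
#   chkpt = torch.load('/path/to/checkpoints/best/best.pth', map_location='cpu')
#   model.load_state_dict(chkpt['model'])
#   model.cuda(device=0).eval()
#   dets = im_detect(model, image, target_sizes=[512], params=params,
#                    use_gpu=True, conf=0.05, device=0)
#   for det in dets:
#       cls_index, score, pred_box = int(det[0]), float(det[1]), det[2:7]
#       plot_box(image, coord=pred_box, label_index=cls_index, score=score)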
-------------------------------------------------------------------------------- /show_result/HRSC/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo1.jpg -------------------------------------------------------------------------------- /show_result/HRSC/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo2.jpg -------------------------------------------------------------------------------- /show_result/HRSC/demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo3.jpg -------------------------------------------------------------------------------- /show_result/RSSDD/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo1.jpg -------------------------------------------------------------------------------- /show_result/RSSDD/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo2.jpg -------------------------------------------------------------------------------- /show_result/RSSDD/demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo3.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | # from datasets.HRSC_dataset import HRSCDataset 4 | from datasets.SSDD_dataset import SSDDataset 5 | from datasets.collater import Collater 6 | import torch.utils.data as data 7 | from utils.utils import set_random_seed, count_param 8 | from models.model import RetinaNet 9 | import torch.optim as optim 10 | from tqdm import tqdm 11 | import os 12 | from tensorboardX import SummaryWriter 13 | import datetime 14 | import torch.nn as nn 15 | from warmup import WarmupLR 16 | import yaml 17 | from pprint import pprint 18 | from eval import evaluate 19 | from Logger import Logger 20 | 21 | 22 | class Params: 23 | def __init__(self, project_file): 24 | self.filename = os.path.basename(project_file) 25 | self.params = yaml.safe_load(open(project_file).read()) 26 | 27 | def __getattr__(self, item): 28 | return self.params.get(item, None) 29 | 30 | def info(self): 31 | return '\n'.join([(f'{key}: {value}') for key, value in self.params.items()]) 32 | 33 | def save(self): 34 | with open(os.path.join(self.params.get('output_path'), f'{self.filename}'), 'w') as f: 35 | yaml.dump(self.params, f, sort_keys=False) 36 | 37 | def show(self): 38 | print('=================== Show Params =====================') 39 | pprint(self.params) 40 | 41 | 42 | def get_args(): 43 | parser = argparse.ArgumentParser('A Rotation Detector based on RetinaNet by PyTorch.') 44 | 
parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_{Dataset Name}.yml') 45 | parser.add_argument('--resume', type=str, 46 | # default='{epoch}_{step}.pth', 47 | default=None, # train from scratch 48 | help='the last checkpoint file.') 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def train(args, params): 54 | epochs = params.epoch 55 | if torch.cuda.is_available(): 56 | if len(params.device) == 1: 57 | device = params.device[0] 58 | else: 59 | print(f'[Info]: Traing with {params.device} GPUs') 60 | 61 | weight = '' 62 | if args.resume: 63 | weight = params.output_path + os.sep + params.checkpoint + os.sep + args.resume 64 | 65 | start_epoch = 0 66 | best_fitness = 0 67 | fitness = 0 68 | last_step = 0 69 | 70 | # create folder 71 | tensorboard_path = os.path.join(params.output_path, params.tensorboard) 72 | if not os.path.exists(tensorboard_path): 73 | os.makedirs(tensorboard_path) 74 | 75 | checkpoint_path = os.path.join(params.output_path, params.checkpoint) 76 | if not os.path.exists(checkpoint_path): 77 | os.makedirs(checkpoint_path) 78 | 79 | best_checkpoint_path = os.path.join(checkpoint_path, 'best') 80 | if not os.path.exists(best_checkpoint_path): 81 | os.makedirs(best_checkpoint_path) 82 | 83 | log_file_path = os.path.join(params.output_path, params.log) 84 | if os.path.isfile(log_file_path): 85 | os.remove(log_file_path) 86 | 87 | log = Logger(log_path=os.path.join(params.output_path, params.log), logging_name='R-RetinaNet') 88 | logger = log.logger_config() 89 | env_info = params.info() 90 | logger.info('Config info:\n' + log.dash_line + env_info + '\n' + log.dash_line) 91 | 92 | # save config yaml file 93 | params.save() 94 | 95 | train_dataset = SSDDataset(root_path=params.data_path, set_name='train', augment=params.augment, 96 | classes=params.classes) 97 | collater = Collater(scales=params.image_size, keep_ratio=params.keep_ratio, multiple=32) 98 | train_generator = data.DataLoader( 99 | dataset=train_dataset, 100 | batch_size=params.batch_size, 101 | num_workers=8, # 4 * number of the GPU 102 | collate_fn=collater, 103 | shuffle=True, 104 | pin_memory=True, 105 | drop_last=True) 106 | 107 | # Initialize model & set random seed 108 | set_random_seed(seed=42, deterministic=False) 109 | model = RetinaNet(params) 110 | count_param(model) 111 | 112 | # init tensorboardX 113 | writer = SummaryWriter(tensorboard_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/') 114 | 115 | # Optimizer Option 116 | optimizer = optim.Adam(model.parameters(), lr=params.lr) 117 | 118 | # Scheduler Option 119 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) 120 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[round(epochs * x) for x in [0.6, 0.8]], gamma=0.1) 121 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.94) 122 | 123 | # Warm-up 124 | is_warmup = False 125 | if params.warm_up and args.resume is None: 126 | print('[Info]: Launching Warmup.') 127 | scheduler = WarmupLR(scheduler, init_lr=params.warmup_lr, num_warmup=params.warmup_epoch, warmup_strategy='cos') 128 | is_warmup = True 129 | if is_warmup is False: 130 | print('[Info]: Not Launching Warmup.') 131 | 132 | if torch.cuda.is_available() and len(params.device) == 1: 133 | model = model.cuda(device=device) 134 | else: 135 | model = nn.DataParallel(model, device_ids=[0, 1], output_device=0) 136 | model.cuda() # put the model on the main card in the condition of the multi-gpus 137 | 138 | if args.resume: 139 | 
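        # Checkpoints written by this script are dicts with the keys 'epoch', 'step',
        # 'best_fitness', 'model', 'optimizer' and 'scheduler' (see the `chkpt = {...}`
        # block further down); that layout is what the resume logic below expects.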
if weight.endswith('.pth'): 140 | chkpt = torch.load(weight) 141 | last_step = chkpt['step'] 142 | 143 | # Load model 144 | if 'model' in chkpt.keys(): 145 | model.load_state_dict(chkpt['model']) 146 | else: 147 | model.load_state_dict(chkpt) 148 | 149 | # Load optimizer 150 | if 'optimizer' in chkpt.keys() and chkpt['optimizer'] is not None: 151 | optimizer.load_state_dict(chkpt['optimizer']) 152 | best_fitness = chkpt['best_fitness'] 153 | for state in optimizer.state.values(): 154 | for k, v in state.items(): 155 | if isinstance(v, torch.Tensor): 156 | state[k] = v.cuda(device=device) 157 | 158 | # Load scheduler 159 | if 'scheduler' in chkpt.keys() and chkpt['scheduler'] is not None: 160 | scheduler_state = chkpt['scheduler'] 161 | scheduler._step_count = scheduler_state['step_count'] 162 | scheduler.last_epoch = scheduler_state['last_epoch'] 163 | 164 | start_epoch = chkpt['epoch'] + 1 165 | 166 | del chkpt 167 | 168 | # start training 169 | step = max(0, last_step) 170 | num_iter_per_epoch = len(train_generator) 171 | 172 | head_line = ('%10s' * 8) % ('Epoch', 'Steps', 'gpu_mem', 'cls', 'reg', 'total', 'targets', 'img_size') 173 | print(('\n' + '%10s' * 8) % ('Epoch', 'Steps', 'gpu_mem', 'cls', 'reg', 'total', 'targets', 'img_size')) 174 | logger.debug(head_line) 175 | 176 | if is_warmup: 177 | scheduler.step() 178 | for epoch in range(start_epoch, epochs): 179 | last_epoch = step // num_iter_per_epoch 180 | if epoch < last_epoch: 181 | continue 182 | pbar = tqdm(enumerate(train_generator), total=len(train_generator)) # progress bar 183 | 184 | # for each epoch, we set model.eval() to model.train() 185 | # and freeze backbone BN Layers parameters 186 | model.train() 187 | 188 | if params.freeze_bn and len(params.device) == 1: 189 | model.freeze_bn() 190 | else: 191 | model.module.freeze_bn() 192 | 193 | for iter, (ni, batch) in enumerate(pbar): 194 | 195 | if iter < step - last_epoch * num_iter_per_epoch: 196 | pbar.update() 197 | continue 198 | 199 | optimizer.zero_grad() 200 | images, annots, image_names = batch['image'], batch['bboxes'], batch['image_name'] 201 | if torch.cuda.is_available(): 202 | if len(params.device) == 1: 203 | images, annots = images.cuda(device=device), annots.cuda(device=device) 204 | else: 205 | images, annots = images.cuda(), annots.cuda() 206 | loss_cls, loss_reg = model(images, annots, image_names) 207 | 208 | # Using .mean() is following Ming71 and Zylo117 repo 209 | loss_cls = loss_cls.mean() 210 | loss_reg = loss_reg.mean() 211 | 212 | total_loss = loss_cls + loss_reg 213 | 214 | if not torch.isfinite(total_loss): 215 | print('[Warning]: loss is nan') 216 | break 217 | 218 | if bool(total_loss == 0): 219 | continue 220 | 221 | total_loss.backward() 222 | 223 | # Update parameters 224 | 225 | # if loss is not nan not using grad clip 226 | # nn.utils.clip_grad_norm_(model.parameters(), 0.1) 227 | 228 | optimizer.step() 229 | 230 | # print batch result 231 | if len(params.device) == 1: 232 | mem = torch.cuda.memory_reserved(device=device) / 1E9 if torch.cuda.is_available() else 0 233 | else: 234 | mem = sum(torch.cuda.memory_reserved(device=idx) for idx in range(len(params.device))) / 1E9 235 | 236 | s = ('%10s' * 3 + '%10.3g' * 4 + '%10s' * 1) % ( 237 | '%g/%g' % (epoch, epochs - 1), 238 | '%g' % iter, 239 | '%.3gG' % mem, loss_cls.item(), loss_reg.item(), total_loss.item(), annots.shape[1], 240 | '%gx%g' % (int(images.shape[2]), int(images.shape[3]))) 241 | 242 | pbar.set_description(s) 243 | 244 | # write loss info into tensorboard 245 | 
writer.add_scalars('Loss', {'train': total_loss}, step) 246 | writer.add_scalars('Regression_loss', {'train': loss_reg}, step) 247 | writer.add_scalars('Classfication_loss', {'train': loss_cls}, step) 248 | 249 | # write lr info into tensorboard 250 | current_lr = optimizer.param_groups[0]['lr'] 251 | writer.add_scalar('lr_per_step', current_lr, step) 252 | step = step + 1 253 | 254 | # Update scheduler / learning rate 255 | scheduler.step() 256 | logger.debug(s) 257 | 258 | final_epoch = epoch + 1 == epochs 259 | 260 | # # check the mAP on training set begin ------------------------------------------------ 261 | # if epoch >= params.evaluate_train_start and epoch % params.val_interval == 0: 262 | # test_path = 'train-ground-truth' 263 | # train_results = evaluate( 264 | # target_size=[params.image_size], 265 | # test_path=test_path, 266 | # eval_method=args.eval_method, 267 | # model=model, 268 | # conf=params.score_thr, 269 | # device=args.device, 270 | # mode='train') 271 | # 272 | # train_fitness = train_results[0] # Update best mAP 273 | # writer.add_scalar('train_mAP', train_fitness, epoch) 274 | # --------------------------end 275 | 276 | # save model 277 | # create checkpoint 278 | chkpt = {'epoch': epoch, 279 | 'step': step, 280 | 'best_fitness': best_fitness, 281 | 'model': model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel 282 | else model.state_dict(), 283 | 'optimizer': None if final_epoch else optimizer.state_dict(), 284 | 'scheduler': {'step_count': scheduler._step_count, 285 | 'last_epoch': scheduler.last_epoch} 286 | } 287 | 288 | # save interval checkpoint 289 | if epoch % params.save_interval == 0 and epoch >= 30: 290 | torch.save(chkpt, os.path.join(checkpoint_path, f'{epoch}_{step}.pth')) 291 | 292 | if epoch >= params.evaluation_val_start and epoch % params.val_interval == 0: 293 | test_path = 'ground-truth' 294 | model.eval() 295 | val_mAP, val_Precision, val_Recall = evaluate(model=model, 296 | target_size=params.image_size, 297 | test_path=test_path, 298 | conf=params.score_thr, 299 | device=device, 300 | mode='test', 301 | params=params) 302 | 303 | eval_line = ('%10s' * 7) % ('[%g/%g]' % (epoch, epochs - 1), 'Val mAP:', '%10.3f' % val_mAP, 304 | 'Precision:', '%10.3f' % val_Precision, 305 | 'Recall:', '%10.3f' % val_Recall) 306 | logger.debug(eval_line) 307 | 308 | fitness = val_mAP # Update best mAP 309 | 310 | if fitness > best_fitness: 311 | best_fitness = fitness 312 | 313 | # write mAP info into tensorboard 314 | writer.add_scalar('val_mAP', fitness, epoch) 315 | 316 | # save best checkpoint 317 | if best_fitness == fitness: 318 | torch.save(chkpt, os.path.join(best_checkpoint_path, 'best.pth')) 319 | 320 | # TensorboardX writer close 321 | writer.close() 322 | 323 | 324 | if __name__ == '__main__': 325 | # os.environ["CUDA_VISIBLE_DEVICES"] = '3, 2' # for multi-GPU 326 | from utils.utils import show_args 327 | args = get_args() 328 | params = Params(args.config_file) 329 | show_args(args) 330 | train(args, params) 331 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/augment.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/bbox_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/bbox_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/box_coder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/box_coder.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/map.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/map.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | class HorizontalFlip(object): 7 | """ 8 | Args: 9 | p: the probability of the horizontal flip 10 | """ 11 | def __init__(self, p=0.5): 12 | self.p = p 13 | 14 | def __call__(self, image, bboxes): 15 | """ 16 | Args: 17 | image: array([C, H, W]) 18 | bboxes: array (N, 8) :[[x1, y1, x2, y2, x3, y3, x4, y4] ... ] 19 | """ 20 | if random.random() < self.p: 21 | h, w, _ = image.shape 22 | image = np.array(np.fliplr(image)) 23 | for idx, single_box in enumerate(bboxes): 24 | bboxes[idx, 0::2] = w - single_box[0::2] 25 | return image, bboxes 26 | 27 | 28 | class VerticalFlip(object): 29 | """ 30 | Args: 31 | p: the probability of the vertical flip 32 | """ 33 | def __init__(self, p=0.5): 34 | self.p = p 35 | 36 | def __call__(self, image, bboxes): 37 | """ 38 | Args: 39 | image: array([C, H, W]) 40 | bboxes: list (N, 9) :[[x1, y1, x2, y2, x3, y3, x4, y4, class_index] ... 
] 41 | """ 42 | if random.random() < self.p: 43 | h, w, _ = image.shape 44 | image = np.array(np.flipud(image)) 45 | for idx, single_box in enumerate(bboxes): 46 | bboxes[idx, 1::2] = h - single_box[1::2] 47 | return image, bboxes 48 | 49 | 50 | class HSV(object): 51 | def __init__(self, saturation=0, brightness=0, p=0.): 52 | self.saturation = saturation 53 | self.brightness = brightness 54 | self.p = p 55 | 56 | def __call__(self, image, bboxes, mode=None): 57 | if random.random() < self.p: 58 | img_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) # hue, sat, val 59 | S = img_hsv[:, :, 1].astype(np.float32) # saturation 60 | V = img_hsv[:, :, 2].astype(np.float32) # value 61 | a = random.uniform(-1, 1) * self.saturation + 1 62 | b = random.uniform(-1, 1) * self.brightness + 1 63 | S *= a 64 | V *= b 65 | img_hsv[:, :, 1] = S if a < 1 else S.clip(None, 255) 66 | img_hsv[:, :, 2] = V if b < 1 else V.clip(None, 255) 67 | cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=image) 68 | return image, bboxes 69 | 70 | 71 | class Augment(object): 72 | def __init__(self, transforms): 73 | self.transforms = transforms 74 | 75 | def __call__(self, image, bboxes): 76 | for transform in self.transforms: 77 | image, bboxes = transform(image, bboxes) 78 | return image, bboxes 79 | -------------------------------------------------------------------------------- /utils/bbox_transforms.py: -------------------------------------------------------------------------------- 1 | """ Original code is from: 2 | `https://github.com/open-mmlab/mmrotate/blob/main/mmrotate/core/bbox/transforms.py`""" 3 | 4 | import numpy as np 5 | import math 6 | import torch 7 | import cv2 8 | 9 | 10 | def swap_axis(tensor): 11 | if torch.is_tensor(tensor): 12 | swap_bbox = torch.zeros_like(tensor) 13 | swap_bbox[:, 0] = tensor[:, 0] 14 | swap_bbox[:, 1] = tensor[:, 1] 15 | swap_bbox[:, 2] = tensor[:, 3] 16 | swap_bbox[:, 3] = tensor[:, 2] 17 | swap_bbox[:, 4] = tensor[:, 4] 18 | else: 19 | swap_bbox = np.zeros_like(tensor) 20 | swap_bbox[:, 0] = tensor[:, 0] 21 | swap_bbox[:, 1] = tensor[:, 1] 22 | swap_bbox[:, 2] = tensor[:, 3] 23 | swap_bbox[:, 3] = tensor[:, 2] 24 | swap_bbox[:, 4] = tensor[:, 4] 25 | 26 | return swap_bbox 27 | 28 | 29 | def norm_angle(angle, angle_range): 30 | """Limit the range of angles. 31 | 32 | Args: 33 | angle (ndarray): shape(n, ). 34 | angle_range (Str): angle representations. 35 | Returns: 36 | angle (ndarray): shape(n, ). 37 | """ 38 | if angle_range == 'oc': 39 | return angle 40 | elif angle_range == 'le135': 41 | return (angle + np.pi / 4) % np.pi - np.pi / 4 42 | elif angle_range == 'le90': 43 | return (angle + np.pi / 2) % np.pi - np.pi / 2 44 | else: 45 | print('Not yet implemented.') 46 | 47 | 48 | def obb2poly_oc(rboxes): 49 | """Convert oriented bounding boxes to polygons. 50 | Args: 51 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] 52 | Returns: 53 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3] 54 | """ 55 | x = rboxes[:, 0] 56 | y = rboxes[:, 1] 57 | w = rboxes[:, 2] 58 | h = rboxes[:, 3] 59 | a = rboxes[:, 4] 60 | cosa = torch.cos(a) 61 | sina = torch.sin(a) 62 | wx, wy = w / 2 * cosa, w / 2 * sina 63 | hx, hy = -h / 2 * sina, h / 2 * cosa 64 | p1x, p1y = x - wx - hx, y - wy - hy 65 | p2x, p2y = x + wx - hx, y + wy - hy 66 | p3x, p3y = x + wx + hx, y + wy + hy 67 | p4x, p4y = x - wx + hx, y - wy + hy 68 | return torch.stack([p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y], dim=-1) 69 | 70 | 71 | def obb2poly_le135(rboxes): 72 | """Convert oriented bounding boxes to polygons. 
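    In the 'le135' convention the angle is normalised to [-pi/4, 3*pi/4)
    (see norm_angle above).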
73 | Args: 74 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] 75 | Returns: 76 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3] 77 | """ 78 | N = rboxes.shape[0] 79 | if N == 0: 80 | return rboxes.new_zeros((rboxes.size(0), 8)) 81 | x_ctr, y_ctr, width, height, angle = rboxes.select(1, 0), rboxes.select( 82 | 1, 1), rboxes.select(1, 2), rboxes.select(1, 3), rboxes.select(1, 4) 83 | tl_x, tl_y, br_x, br_y = \ 84 | -width * 0.5, -height * 0.5, \ 85 | width * 0.5, height * 0.5 86 | rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], 87 | dim=0).reshape(2, 4, N).permute(2, 0, 1) 88 | sin, cos = torch.sin(angle), torch.cos(angle) 89 | M = torch.stack([cos, -sin, sin, cos], dim=0).reshape(2, 2, 90 | N).permute(2, 0, 1) 91 | polys = M.matmul(rects).permute(2, 1, 0).reshape(-1, N).transpose(1, 0) 92 | polys[:, ::2] += x_ctr.unsqueeze(1) 93 | polys[:, 1::2] += y_ctr.unsqueeze(1) 94 | return polys.contiguous() 95 | 96 | 97 | def obb2poly_le90(rboxes): 98 | """Convert oriented bounding boxes to polygons. 99 | Args: 100 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] 101 | Returns: 102 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3] 103 | """ 104 | N = rboxes.shape[0] 105 | if N == 0: 106 | return rboxes.new_zeros((rboxes.size(0), 8)) 107 | x_ctr, y_ctr, width, height, angle = rboxes.select(1, 0), rboxes.select( 108 | 1, 1), rboxes.select(1, 2), rboxes.select(1, 3), rboxes.select(1, 4) 109 | tl_x, tl_y, br_x, br_y = \ 110 | -width * 0.5, -height * 0.5, \ 111 | width * 0.5, height * 0.5 112 | rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], 113 | dim=0).reshape(2, 4, N).permute(2, 0, 1) 114 | sin, cos = torch.sin(angle), torch.cos(angle) 115 | M = torch.stack([cos, -sin, sin, cos], dim=0).reshape(2, 2, 116 | N).permute(2, 0, 1) 117 | polys = M.matmul(rects).permute(2, 1, 0).reshape(-1, N).transpose(1, 0) 118 | polys[:, ::2] += x_ctr.unsqueeze(1) 119 | polys[:, 1::2] += y_ctr.unsqueeze(1) 120 | return polys.contiguous() 121 | 122 | 123 | def cal_line_length(point1, point2): 124 | """Calculate the length of line. 125 | Args: 126 | point1 (List): [x,y] 127 | point2 (List): [x,y] 128 | Returns: 129 | length (float) 130 | """ 131 | return math.sqrt( 132 | math.pow(point1[0] - point2[0], 2) + 133 | math.pow(point1[1] - point2[1], 2)) 134 | 135 | 136 | def get_best_begin_point_single(coordinate): 137 | """Get the best begin point of the single polygon. 
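    The four cyclic orderings of the corners are compared against the axis-aligned
    reference corners [xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]; the
    ordering with the smallest summed corner distance is returned, so the polygon
    starts from (approximately) its top-left corner.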
138 | Args: 139 | coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score] 140 | Returns: 141 | reorder coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score] 142 | """ 143 | x1, y1, x2, y2, x3, y3, x4, y4, score = coordinate 144 | xmin = min(x1, x2, x3, x4) 145 | ymin = min(y1, y2, y3, y4) 146 | xmax = max(x1, x2, x3, x4) 147 | ymax = max(y1, y2, y3, y4) 148 | combine = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], 149 | [[x2, y2], [x3, y3], [x4, y4], [x1, y1]], 150 | [[x3, y3], [x4, y4], [x1, y1], [x2, y2]], 151 | [[x4, y4], [x1, y1], [x2, y2], [x3, y3]]] 152 | dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] 153 | force = 100000000.0 154 | force_flag = 0 155 | for i in range(4): 156 | temp_force = cal_line_length(combine[i][0], dst_coordinate[0]) \ 157 | + cal_line_length(combine[i][1], dst_coordinate[1]) \ 158 | + cal_line_length(combine[i][2], dst_coordinate[2]) \ 159 | + cal_line_length(combine[i][3], dst_coordinate[3]) 160 | if temp_force < force: 161 | force = temp_force 162 | force_flag = i 163 | if force_flag != 0: 164 | pass 165 | return np.hstack( 166 | (np.array(combine[force_flag]).reshape(8), np.array(score))) 167 | 168 | 169 | def get_best_begin_point(coordinates): 170 | """Get the best begin points of polygons. 171 | Args: 172 | coordinate (ndarray): shape(n, 9). 173 | Returns: 174 | reorder coordinate (ndarray): shape(n, 9). 175 | """ 176 | coordinates = list(map(get_best_begin_point_single, coordinates.tolist())) 177 | coordinates = np.array(coordinates) 178 | return coordinates 179 | 180 | 181 | def obb2poly_np_oc(rbboxes): 182 | """ Modified ! 183 | Convert oriented bounding boxes to polygons. 184 | Args: 185 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score] modify-> [x_ctr, y_ctr, h, w, angle(radian), score] 186 | Returns: 187 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score] 188 | """ 189 | x = rbboxes[:, 0] 190 | y = rbboxes[:, 1] 191 | h = rbboxes[:, 2] 192 | w = rbboxes[:, 3] 193 | a = rbboxes[:, 4] 194 | score = rbboxes[:, 5] 195 | 196 | cosa = np.cos(a) 197 | sina = np.sin(a) 198 | wx, wy = -w / 2 * sina, w / 2 * cosa 199 | hx, hy = h / 2 * cosa, h / 2 * sina 200 | p1x, p1y = x - wx - hx, y - wy - hy 201 | p2x, p2y = x - wx + hx, y - wy + hy 202 | p3x, p3y = x + wx + hx, y + wy + hy 203 | p4x, p4y = x + wx - hx, y + wy - hy 204 | polys = np.stack([p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y, score], axis=-1) 205 | polys = get_best_begin_point(polys) 206 | return polys 207 | 208 | 209 | def obb2poly_np_le135(rrects): 210 | """Convert oriented bounding boxes to polygons. 211 | Args: 212 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score] 213 | Returns: 214 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score] 215 | """ 216 | polys = [] 217 | for rrect in rrects: 218 | x_ctr, y_ctr, width, height, angle, score = rrect[:6] 219 | tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 220 | rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) 221 | R = np.array([[np.cos(angle), -np.sin(angle)], 222 | [np.sin(angle), np.cos(angle)]]) 223 | poly = R.dot(rect) 224 | x0, x1, x2, x3 = poly[0, :4] + x_ctr 225 | y0, y1, y2, y3 = poly[1, :4] + y_ctr 226 | poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3, score], 227 | dtype=np.float32) 228 | polys.append(poly) 229 | polys = np.array(polys) 230 | polys = get_best_begin_point(polys) 231 | return polys 232 | 233 | 234 | def obb2poly_np_le90(obboxes): 235 | """Convert oriented bounding boxes to polygons. 
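    In the 'le90' convention w >= h and the angle is normalised to [-pi/2, pi/2)
    (see poly2obb_np_le90 below).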
236 | Args: 237 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score] 238 | Returns: 239 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score] 240 | """ 241 | try: 242 | center, w, h, theta, score = np.split(obboxes, (2, 3, 4, 5), axis=-1) 243 | except: # noqa: E722 244 | results = np.stack([0., 0., 0., 0., 0., 0., 0., 0., 0.], axis=-1) 245 | return results.reshape(1, -1) 246 | Cos, Sin = np.cos(theta), np.sin(theta) 247 | vector1 = np.concatenate([w / 2 * Cos, w / 2 * Sin], axis=-1) 248 | vector2 = np.concatenate([-h / 2 * Sin, h / 2 * Cos], axis=-1) 249 | point1 = center - vector1 - vector2 250 | point2 = center + vector1 - vector2 251 | point3 = center + vector1 + vector2 252 | point4 = center - vector1 + vector2 253 | polys = np.concatenate([point1, point2, point3, point4, score], axis=-1) 254 | polys = get_best_begin_point(polys) 255 | return polys 256 | 257 | 258 | def obb2poly(rbboxes, version='oc'): 259 | """Convert oriented bounding boxes to polygons. 260 | Args: 261 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] 262 | version (Str): angle representations. 263 | Returns: 264 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3] 265 | """ 266 | if version == 'oc': 267 | results = obb2poly_oc(rbboxes) 268 | elif version == 'le135': 269 | results = obb2poly_le135(rbboxes) 270 | elif version == 'le90': 271 | results = obb2poly_le90(rbboxes) 272 | else: 273 | raise NotImplementedError 274 | return results 275 | 276 | 277 | def obb2poly_np(rbboxes, version='oc'): 278 | """Convert oriented bounding boxes to polygons. 279 | Args: 280 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] 281 | version (Str): angle representations. 282 | Returns: 283 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] 284 | """ 285 | if version == 'oc': 286 | results = obb2poly_np_oc(rbboxes) 287 | elif version == 'le135': 288 | results = obb2poly_np_le135(rbboxes) 289 | elif version == 'le90': 290 | results = obb2poly_np_le90(rbboxes) 291 | else: 292 | raise NotImplementedError 293 | return results 294 | 295 | 296 | def poly2obb_np(polys, version='oc'): 297 | """Convert polygons to oriented bounding boxes. 298 | Args: 299 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] 300 | version (Str): angle representations. 301 | Returns: 302 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] 303 | """ 304 | if version == 'oc': 305 | results = poly2obb_np_oc(polys) 306 | elif version == 'le135': 307 | results = poly2obb_np_le135(polys) 308 | elif version == 'le90': 309 | results = poly2obb_np_le90(polys) 310 | else: 311 | raise NotImplementedError 312 | return results 313 | 314 | 315 | def poly2obb_np_oc(poly): 316 | """ Modified !! 317 | Convert polygons to oriented bounding boxes. 318 | Args: 319 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] 320 | Returns: 321 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] modified -> [x_ctr, y_ctr, h, w, angle(radian)] 322 | """ 323 | bboxps = np.array(poly).reshape((4, 2)) 324 | rbbox = cv2.minAreaRect(bboxps) 325 | x, y, h, w, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2] 326 | # assert 0 < a <= 90, f'error from poly2obb_np_oc function.' 327 | if w < 2 or h < 2: 328 | return 329 | while not 0 < a <= 90: 330 | if a == -90: 331 | a += 180 332 | else: 333 | a += 90 334 | w, h = h, w 335 | a = a / 180 * np.pi 336 | assert 0 < a <= np.pi / 2 337 | return x, y, h, w, a 338 | 339 | 340 | def poly2obb_np_le135(poly): 341 | """Convert polygons to oriented bounding boxes. 
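    Width is taken as the longer edge of the quadrilateral and height as the shorter
    one; the angle is normalised with norm_angle(..., 'le135') to [-pi/4, 3*pi/4).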
342 | Args: 343 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] 344 | Returns: 345 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] 346 | """ 347 | poly = np.array(poly[:8], dtype=np.float32) 348 | pt1 = (poly[0], poly[1]) 349 | pt2 = (poly[2], poly[3]) 350 | pt3 = (poly[4], poly[5]) 351 | pt4 = (poly[6], poly[7]) 352 | edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[1]) * 353 | (pt1[1] - pt2[1])) 354 | edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[1]) * 355 | (pt2[1] - pt3[1])) 356 | if edge1 < 2 or edge2 < 2: 357 | return 358 | width = max(edge1, edge2) 359 | height = min(edge1, edge2) 360 | angle = 0 361 | if edge1 > edge2: 362 | angle = np.arctan2(float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0])) 363 | elif edge2 >= edge1: 364 | angle = np.arctan2(float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0])) 365 | angle = norm_angle(angle, 'le135') 366 | x_ctr = float(pt1[0] + pt3[0]) / 2 367 | y_ctr = float(pt1[1] + pt3[1]) / 2 368 | return x_ctr, y_ctr, width, height, angle 369 | 370 | 371 | def poly2obb_np_le90(poly): 372 | """Convert polygons to oriented bounding boxes. 373 | Args: 374 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] 375 | Returns: 376 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] 377 | """ 378 | bboxps = np.array(poly).reshape((4, 2)) 379 | rbbox = cv2.minAreaRect(bboxps) 380 | x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[ 381 | 2] 382 | if w < 2 or h < 2: 383 | return 384 | a = a / 180 * np.pi 385 | if w < h: 386 | w, h = h, w 387 | a += np.pi / 2 388 | while not np.pi / 2 > a >= -np.pi / 2: 389 | if a >= np.pi / 2: 390 | a -= np.pi 391 | else: 392 | a += np.pi 393 | assert np.pi / 2 > a >= -np.pi / 2 394 | return x, y, w, h, a 395 | 396 | 397 | def obb2hbb_oc(rbboxes): 398 | """ Modified ! 399 | Convert oriented bounding boxes to horizontal bounding boxes. 400 | 401 | Args: 402 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] modify -> [x_ctr, y_ctr, h, w, angle(radian)] 403 | Returns: 404 | hbbs (torch.Tensor): [x_ctr,y_ctr,w,h,pi/2] 405 | """ 406 | h = rbboxes[:, 2::5] 407 | w = rbboxes[:, 3::5] 408 | a = rbboxes[:, 4::5] 409 | cosa = torch.cos(a) 410 | sina = torch.sin(a) 411 | hbbox_h = cosa * w + sina * h 412 | hbbox_w = sina * w + cosa * h 413 | hbboxes = rbboxes.clone().detach() 414 | hbboxes[:, 2::5] = hbbox_w 415 | hbboxes[:, 3::5] = hbbox_h 416 | hbboxes[:, 4::5] = np.pi / 2 417 | return hbboxes 418 | -------------------------------------------------------------------------------- /utils/box_coder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class BoxCoder(object): 6 | """ 7 | This class encodes and decodes a set of bounding boxes into 8 | the representation used for training the regressors. 
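    Targets are (dx, dy, dh, dw, dtheta), normalised by means (0, 0, 0, 0, 0) and
    stds (0.1, 0.1, 0.1, 0.1, 0.05) after encoding; decode() applies the inverse
    transform before reconstructing the rotated boxes.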
9 | 10 | Args: 11 | encode() function: 12 | ex_rois: positive anchors: [xc, yc, w, h, angle(radian)] 13 | gt_rois: positive anchor ground-truth box: [xc, yc, h, w, angle(radian)] 14 | 15 | decode() function: 16 | boxes: anchors: [xc, yc, w, h, angle(radian)] 17 | deltas: offset: [xc_offset, yc_offset, h_offset, w_offset, angle_offset(radian)] 18 | """ 19 | def __init__(self, means=(0., 0., 0., 0., 0.), stds=(0.1, 0.1, 0.1, 0.1, 0.05)): 20 | self.means = means 21 | self.stds = stds 22 | 23 | def encode(self, ex_rois, gt_rois): 24 | ex_widths = ex_rois[:, 2] 25 | ex_heights = ex_rois[:, 3] 26 | ex_widths = torch.clamp(ex_widths, min=1) 27 | ex_heights = torch.clamp(ex_heights, min=1) 28 | ex_ctr_x = ex_rois[:, 0] 29 | ex_ctr_y = ex_rois[:, 1] 30 | ex_thetas = ex_rois[:, 4] 31 | 32 | gt_widths = gt_rois[:, 3] 33 | gt_heights = gt_rois[:, 2] 34 | gt_widths = torch.clamp(gt_widths, min=1) 35 | gt_heights = torch.clamp(gt_heights, min=1) 36 | gt_ctr_x = gt_rois[:, 0] 37 | gt_ctr_y = gt_rois[:, 1] 38 | gt_thetas = gt_rois[:, 4] 39 | 40 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths # t_x = (x - x_a) / w_a 41 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights # t_y = (y - y_a) / h_a 42 | targets_dw = torch.log(gt_widths / ex_widths) # t_w = log(w / w_a) 43 | targets_dh = torch.log(gt_heights / ex_heights) # t_h = log(h / h_a) 44 | targets_dt = gt_thetas - ex_thetas 45 | 46 | targets = torch.stack( 47 | (targets_dx, targets_dy, targets_dh, targets_dw, targets_dt), dim=1) 48 | 49 | means = targets.new_tensor(self.means).unsqueeze(0) 50 | stds = targets.new_tensor(self.stds).unsqueeze(0) 51 | targets = targets.sub_(means).div_(stds) 52 | return targets 53 | 54 | def decode(self, boxes, deltas): 55 | means = deltas.new_tensor(self.means).view(1, 1, -1).repeat(1, deltas.size(1), 1) 56 | stds = deltas.new_tensor(self.stds).view(1, 1, -1).repeat(1, deltas.size(1), 1) 57 | denorm_deltas = deltas * stds + means 58 | 59 | dx = denorm_deltas[:, :, 0] 60 | dy = denorm_deltas[:, :, 1] 61 | dh = denorm_deltas[:, :, 2] 62 | dw = denorm_deltas[:, :, 3] 63 | dt = denorm_deltas[:, :, 4] 64 | 65 | widths = boxes[:, :, 2] 66 | heights = boxes[:, :, 3] 67 | widths = torch.clamp(widths, min=1) 68 | heights = torch.clamp(heights, min=1) 69 | ctr_x = boxes[:, :, 0] 70 | ctr_y = boxes[:, :, 1] 71 | thetas = boxes[:, :, 4] 72 | 73 | pred_ctr_x = ctr_x + dx * widths 74 | pred_ctr_y = ctr_y + dy * heights 75 | pred_w = torch.exp(dw) * widths 76 | pred_h = torch.exp(dh) * heights 77 | pred_t = thetas + dt 78 | 79 | pred_boxes = torch.stack([ 80 | pred_ctr_x, 81 | pred_ctr_y, 82 | pred_h, 83 | pred_w, 84 | pred_t], dim=2) 85 | return pred_boxes 86 | -------------------------------------------------------------------------------- /utils/map.py: -------------------------------------------------------------------------------- 1 | # from shapely.geometry import Polygon 2 | import glob 3 | import json 4 | import os 5 | import shutil 6 | import operator 7 | import sys 8 | import argparse 9 | import math 10 | # import shapely 11 | import cv2 12 | from shapely.geometry import Polygon, MultiPoint 13 | 14 | import numpy as np 15 | # from tqdm import tqdm 16 | 17 | 18 | def skewiou(box1, box2): 19 | box1=np.asarray(box1).reshape(4,2) 20 | box2=np.asarray(box2).reshape(4,2) 21 | # ---------------- original code ---------------------------- 22 | poly1 = Polygon(box1).convex_hull 23 | poly2 = Polygon(box2).convex_hull 24 | if not poly1.is_valid or not poly2.is_valid : 25 | print('formatting errors for boxes!!!! 
') 26 | return 0 27 | if poly1.area == 0 or poly2.area == 0 : 28 | return 0, 0 29 | inter = Polygon(poly1).intersection(Polygon(poly2)).area 30 | union = poly1.area + poly2.area - inter 31 | if union == 0: 32 | return 0, 0 33 | else: 34 | return inter/union, inter 35 | 36 | # ------------------ cv2 implementation ----------------------- 37 | 38 | 39 | 40 | def log_average_miss_rate(precision, fp_cumsum, num_images): 41 | """ 42 | log-average miss rate: 43 | Calculated by averaging miss rates at 9 evenly spaced FPPI points 44 | between 10e-2 and 10e0, in log-space. 45 | 46 | output: 47 | lamr | log-average miss rate 48 | mr | miss rate 49 | fppi | false positives per image 50 | 51 | references: 52 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the 53 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE 54 | Transactions on 34.4 (2012): 743 - 761. 55 | """ 56 | 57 | # if there were no detections of that class 58 | if precision.size == 0: 59 | lamr = 0 60 | mr = 1 61 | fppi = 0 62 | return lamr, mr, fppi 63 | 64 | fppi = fp_cumsum / float(num_images) 65 | mr = (1 - precision) 66 | 67 | fppi_tmp = np.insert(fppi, 0, -1.0) 68 | mr_tmp = np.insert(mr, 0, 1.0) 69 | 70 | # Use 9 evenly spaced reference points in log-space 71 | ref = np.logspace(-2.0, 0.0, num = 9) 72 | for i, ref_i in enumerate(ref): 73 | # np.where() will always find at least 1 index, since min(ref) = 0.01 and min(fppi_tmp) = -1.0 74 | j = np.where(fppi_tmp <= ref_i)[-1][-1] 75 | ref[i] = mr_tmp[j] 76 | 77 | # log(0) is undefined, so we use the np.maximum(1e-10, ref) 78 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref)))) 79 | 80 | return lamr, mr, fppi 81 | 82 | """ 83 | throw error and exit 84 | """ 85 | def error(msg): 86 | print(msg) 87 | sys.exit(0) 88 | 89 | """ 90 | check if the number is a float between 0.0 and 1.0 91 | """ 92 | def is_float_between_0_and_1(value): 93 | try: 94 | val = float(value) 95 | if val > 0.0 and val < 1.0: 96 | return True 97 | else: 98 | return False 99 | except ValueError: 100 | return False 101 | 102 | """ 103 | Calculate the AP given the recall and precision array 104 | 1st) We compute a version of the measured precision/recall curve with 105 | precision monotonically decreasing 106 | 2nd) We compute the AP as the area under this curve by numerical integration. 107 | """ 108 | def voc_ap(rec, prec, use_07_metric=False): 109 | """ ap = voc_ap(rec, prec, [use_07_metric]) 110 | Compute VOC AP given precision and recall. 111 | If use_07_metric is true, uses the 112 | VOC 07 11 point method (default:False). 113 | """ 114 | if use_07_metric: 115 | mrec = np.concatenate(([0.], rec, [1.])) 116 | mpre = np.concatenate(([0.], prec, [0.])) 117 | # 11 point metric 118 | ap = 0. 119 | for t in np.arange(0., 1.1, 0.1): 120 | if np.sum(rec >= t) == 0: 121 | p = 0 122 | else: 123 | p = np.max(np.array(prec)[rec >= t]) 124 | ap = ap + p / 11. 
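        # Illustrative example of the 11-point rule (hypothetical values, not from this repo):
        # the interpolated precision is max(prec[rec >= t]) for t in {0.0, 0.1, ..., 1.0};
        # if that value is 1.0 for t <= 0.5 (6 points) and 0.5 for t >= 0.6 (5 points),
        # then ap = (6 * 1.0 + 5 * 0.5) / 11 ~= 0.77.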
125 | else: 126 | # correct AP calculation 127 | # first append sentinel values at the end 128 | mrec = np.concatenate(([0.], rec, [1.])) 129 | mpre = np.concatenate(([0.], prec, [0.])) 130 | 131 | # compute the precision envelope 132 | for i in range(mpre.size - 1, 0, -1): 133 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 134 | 135 | # to calculate area under PR curve, look for points 136 | # where X axis (recall) changes value 137 | i = np.where(mrec[1:] != mrec[:-1])[0] 138 | 139 | # and sum (\Delta recall) * prec 140 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 141 | return ap, mrec, mpre 142 | 143 | 144 | """ 145 | Convert the lines of a file to a list 146 | """ 147 | def file_lines_to_list(path): 148 | # open txt file lines to a list 149 | with open(path) as f: 150 | content = f.readlines() 151 | # remove whitespace characters like `\n` at the end of each line 152 | content = [x.strip() for x in content] 153 | return content 154 | 155 | """ 156 | Draws text in image 157 | """ 158 | def draw_text_in_image(img, text, pos, color, line_width): 159 | font = cv2.FONT_HERSHEY_PLAIN 160 | fontScale = 1 161 | lineType = 1 162 | bottomLeftCornerOfText = pos 163 | cv2.putText(img, text, 164 | bottomLeftCornerOfText, 165 | font, 166 | fontScale, 167 | color, 168 | lineType) 169 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] 170 | return img, (line_width + text_width) 171 | 172 | """ 173 | Plot - adjust axes 174 | """ 175 | def adjust_axes(r, t, fig, axes): 176 | # get text width for re-scaling 177 | bb = t.get_window_extent(renderer=r) 178 | text_width_inches = bb.width / fig.dpi 179 | # get axis width in inches 180 | current_fig_width = fig.get_figwidth() 181 | new_fig_width = current_fig_width + text_width_inches 182 | propotion = new_fig_width / current_fig_width 183 | # get axis limit 184 | x_lim = axes.get_xlim() 185 | axes.set_xlim([x_lim[0], x_lim[1]*propotion]) 186 | 187 | """ 188 | Draw plot using Matplotlib 189 | """ 190 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): 191 | # sort the dictionary by decreasing value, into a list of tuples 192 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) 193 | # unpacking the list of tuples into two lists 194 | sorted_keys, sorted_values = zip(*sorted_dic_by_value) 195 | # 196 | import matplotlib.pyplot as plt 197 | if true_p_bar != "": 198 | """ 199 | Special case to draw in: 200 | - green -> TP: True Positives (object detected and matches ground-truth) 201 | - red -> FP: False Positives (object detected but does not match ground-truth) 202 | - pink -> FN: False Negatives (object not detected but present in the ground-truth) 203 | """ 204 | fp_sorted = [] 205 | tp_sorted = [] 206 | for key in sorted_keys: 207 | fp_sorted.append(dictionary[key] - true_p_bar[key]) 208 | tp_sorted.append(true_p_bar[key]) 209 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive') 210 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted) 211 | # add legend 212 | plt.legend(loc='lower right') 213 | """ 214 | Write number on side of bar 215 | """ 216 | fig = plt.gcf() # gcf - get current figure 217 | axes = plt.gca() 218 | r = fig.canvas.get_renderer() 219 | for i, val in enumerate(sorted_values): 220 | fp_val = fp_sorted[i] 221 | tp_val = tp_sorted[i] 222 | fp_str_val = " " + str(fp_val) 223 | tp_str_val = fp_str_val + " " + 
str(tp_val) 224 | # trick to paint multicolor with offset: 225 | # first paint everything and then repaint the first number 226 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') 227 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') 228 | if i == (len(sorted_values)-1): # largest bar 229 | adjust_axes(r, t, fig, axes) 230 | else: 231 | plt.barh(range(n_classes), sorted_values, color=plot_color) 232 | """ 233 | Write number on side of bar 234 | """ 235 | fig = plt.gcf() # gcf - get current figure 236 | axes = plt.gca() 237 | r = fig.canvas.get_renderer() 238 | for i, val in enumerate(sorted_values): 239 | str_val = " " + str(val) # add a space before 240 | if val < 1.0: 241 | str_val = " {0:.2f}".format(val) 242 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') 243 | # re-set axes to show number inside the figure 244 | if i == (len(sorted_values)-1): # largest bar 245 | adjust_axes(r, t, fig, axes) 246 | # set window title 247 | fig.canvas.set_window_title(window_title) 248 | # write classes in y axis 249 | tick_font_size = 12 250 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) 251 | """ 252 | Re-scale height accordingly 253 | """ 254 | init_height = fig.get_figheight() 255 | # comput the matrix height in points and inches 256 | dpi = fig.dpi 257 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) 258 | height_in = height_pt / dpi 259 | # compute the required figure height 260 | top_margin = 0.15 # in percentage of the figure height 261 | bottom_margin = 0.05 # in percentage of the figure height 262 | figure_height = height_in / (1 - top_margin - bottom_margin) 263 | # set new height 264 | if figure_height > init_height: 265 | fig.set_figheight(figure_height) 266 | 267 | # set plot title 268 | plt.title(plot_title, fontsize=14) 269 | # set axis titles 270 | # plt.xlabel('classes') 271 | plt.xlabel(x_label, fontsize='large') 272 | # adjust size of window 273 | fig.tight_layout() 274 | # save the plot 275 | fig.savefig(output_path) 276 | # show image 277 | if to_show: 278 | plt.show() 279 | # close the plot 280 | plt.close() 281 | 282 | 283 | def eval_mAP(gt_root_dir=None, test_path=None, eval_root_dir=None, use_07_metric=False, thres=0.5): 284 | """ 285 | Args: 286 | thres: rotation nms threshold 287 | """ 288 | MINOVERLAP = thres # default value (defined in the PASCAL VOC2012 challenge) 289 | 290 | # parser = argparse.ArgumentParser() 291 | # parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true") 292 | # parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true") 293 | # parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true") 294 | # # argparse receiving list of classes to be ignored (e.g., python map.py --ignore person book) 295 | # parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.") 296 | # # argparse receiving list of classes with specific IoU (e.g., python map.py --set-class-iou person 0.7) 297 | # parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.") 298 | # args = parser.parse_args() 299 | 300 | no_animation = False 301 | no_plot = False 302 | quiet = False 303 | ignore = None 304 | set_class_iou = None 305 | 306 | # if there are no classes to ignore then replace None by empty list 307 | if ignore is None: 308 | ignore = [] 309 | 310 | specific_iou_flagged = 
False 311 | if set_class_iou is not None: 312 | specific_iou_flagged = True 313 | 314 | # make sure that the cwd() is the location of the python script (so that every path makes sense) 315 | # os.chdir(os.path.dirname(os.path.abspath(__file__))) 316 | 317 | GT_PATH = os.path.join(gt_root_dir, test_path) 318 | DR_PATH = os.path.join(eval_root_dir, 'detection-results') 319 | # if there are no images then no animation can be shown 320 | IMG_PATH = os.path.join(gt_root_dir, 'images-optional') 321 | if os.path.exists(IMG_PATH): 322 | for dirpath, dirnames, files in os.walk(IMG_PATH): 323 | if not files: 324 | # no image files found 325 | no_animation = True 326 | else: 327 | no_animation = True 328 | 329 | # try to import OpenCV if the user didn't choose the option --no-animation 330 | show_animation = False 331 | if not no_animation: 332 | try: 333 | import cv2 334 | show_animation = True 335 | except ImportError: 336 | print("\"opencv-python\" not found, please install to visualize the results.") 337 | no_animation = True 338 | 339 | # try to import Matplotlib if the user didn't choose the option --no-plot 340 | draw_plot = False 341 | if not no_plot: 342 | try: 343 | import matplotlib.pyplot as plt 344 | draw_plot = True 345 | except ImportError: 346 | print("\"matplotlib\" not found, please install it to get the resulting plots.") 347 | no_plot = True 348 | 349 | 350 | """ 351 | Create a ".temp_files/" and "output/" directory 352 | """ 353 | TEMP_FILES_PATH = os.path.join(eval_root_dir, ".temp_files") 354 | if not os.path.exists(TEMP_FILES_PATH): # if it doesn't exist already 355 | os.makedirs(TEMP_FILES_PATH) 356 | output_files_path = os.path.join(eval_root_dir, "output") 357 | if os.path.exists(output_files_path): # if it exist already 358 | # reset the output directory 359 | shutil.rmtree(output_files_path) 360 | 361 | os.makedirs(output_files_path) 362 | if draw_plot: # plot some curves 363 | os.makedirs(os.path.join(output_files_path, "classes")) 364 | if show_animation: 365 | os.makedirs(os.path.join(output_files_path, "images", "detections_one_by_one")) 366 | 367 | """ 368 | ground-truth 369 | Load each of the ground-truth files into a temporary ".json" file. 370 | Create a list of all the class names present in the ground-truth (gt_classes). 371 | """ 372 | # get a list with the ground-truth files 373 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt') 374 | if len(ground_truth_files_list) == 0: 375 | error("Error: No ground-truth files found!") 376 | ground_truth_files_list.sort() 377 | # dictionary with counter per class 378 | gt_counter_per_class = {} # save the number of per ground-truth 379 | counter_images_per_class = {} 380 | 381 | gt_files = [] 382 | for txt_file in ground_truth_files_list: 383 | #print(txt_file) 384 | file_id = txt_file.split(".txt", 1)[0] 385 | file_id = os.path.basename(os.path.normpath(file_id)) 386 | # check if there is a correspondent detection-results file 387 | temp_path = os.path.join(DR_PATH, (file_id + ".txt")) 388 | if not os.path.exists(temp_path): 389 | error_msg = "Error. 
File not found: {}\n".format(temp_path) 390 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 391 | error(error_msg) 392 | lines_list = file_lines_to_list(txt_file) 393 | # create ground-truth dictionary 394 | bounding_boxes = [] 395 | is_difficult = False 396 | already_seen_classes = [] 397 | for line in lines_list: 398 | try: 399 | if "difficult" in line: 400 | class_name, x1, y1, x2, y2, x3, y3, x4, y4, _difficult = line.split() 401 | is_difficult = True 402 | else: 403 | class_name, x1, y1, x2, y2, x3, y3, x4, y4 = line.split() 404 | except ValueError: 405 | error_msg = "Error: File " + txt_file + " in the wrong format.\n" 406 | error_msg += " Expected: ['difficult']\n" 407 | error_msg += " Received: " + line 408 | error_msg += "\n\nIf you have a with spaces between words you should remove them\n" 409 | error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder." 410 | error(error_msg) 411 | # check if class is in the ignore list, if yes skip 412 | if class_name in ignore: 413 | continue 414 | bbox = x1 + " " + y1 + " " + x2 + " " + y2 + " " + x3 + " " + y3 + " " + x4 + " " + y4 415 | if is_difficult: 416 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True}) 417 | is_difficult = False 418 | else: 419 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False}) 420 | # count that object 421 | if class_name in gt_counter_per_class: 422 | gt_counter_per_class[class_name] += 1 423 | else: 424 | # if class didn't exist yet 425 | gt_counter_per_class[class_name] = 1 426 | 427 | if class_name not in already_seen_classes: 428 | if class_name in counter_images_per_class: 429 | counter_images_per_class[class_name] += 1 430 | else: 431 | # if class didn't exist yet 432 | counter_images_per_class[class_name] = 1 433 | already_seen_classes.append(class_name) 434 | 435 | 436 | # dump bounding_boxes into a ".json" file 437 | new_temp_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json" 438 | gt_files.append(new_temp_file) 439 | with open(new_temp_file, 'w') as outfile: 440 | json.dump(bounding_boxes, outfile) 441 | 442 | gt_classes = list(gt_counter_per_class.keys()) 443 | # let's sort the classes alphabetically 444 | gt_classes = sorted(gt_classes) 445 | n_classes = len(gt_classes) 446 | #print(gt_classes) 447 | #print(gt_counter_per_class) 448 | 449 | """ 450 | Check format of the flag --set-class-iou (if used) 451 | e.g. check if class exists 452 | """ 453 | if specific_iou_flagged: 454 | n_args = len( set_class_iou) 455 | error_msg = \ 456 | '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]' 457 | if n_args % 2 != 0: 458 | error('Error, missing arguments. Flag usage:' + error_msg) 459 | # [class_1] [IoU_1] [class_2] [IoU_2] 460 | # specific_iou_classes = ['class_1', 'class_2'] 461 | specific_iou_classes = set_class_iou[::2] # even 462 | # iou_list = ['IoU_1', 'IoU_2'] 463 | iou_list = set_class_iou[1::2] # odd 464 | if len(specific_iou_classes) != len(iou_list): 465 | error('Error, missing arguments. Flag usage:' + error_msg) 466 | for tmp_class in specific_iou_classes: 467 | if tmp_class not in gt_classes: 468 | error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg) 469 | for num in iou_list: 470 | if not is_float_between_0_and_1(num): 471 | error('Error, IoU must be between 0.0 and 1.0. 
Flag usage:' + error_msg) 472 | 473 | """ 474 | detection-results 475 | Load each of the detection-results files into a temporary ".json" file. 476 | """ 477 | # get a list with the detection-results files 478 | dr_files_list = glob.glob(DR_PATH + '/*.txt') 479 | dr_files_list.sort() 480 | 481 | for class_index, class_name in enumerate(gt_classes): 482 | bounding_boxes = [] 483 | for txt_file in dr_files_list: 484 | #print(txt_file) 485 | # the first time it checks if all the corresponding ground-truth files exist 486 | file_id = txt_file.split(".txt",1)[0] 487 | file_id = os.path.basename(os.path.normpath(file_id)) 488 | temp_path = os.path.join(GT_PATH, (file_id + ".txt")) 489 | if class_index == 0: 490 | if not os.path.exists(temp_path): 491 | error_msg = "Error. File not found: {}\n".format(temp_path) 492 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 493 | error(error_msg) 494 | lines = file_lines_to_list(txt_file) 495 | for line in lines: 496 | try: 497 | tmp_class_name, confidence, x1, y1, x2, y2, x3, y3, x4, y4 = line.split() 498 | except ValueError: 499 | error_msg = "Error: File " + txt_file + " in the wrong format.\n" 500 | error_msg += " Expected: \n" 501 | error_msg += " Received: " + line 502 | error(error_msg) 503 | if tmp_class_name == class_name: 504 | #print("match") 505 | bbox = x1 + " " + y1 + " " + x2 + " " + y2 + " " + x3 + " " + y3 + " " + x4 + " " + y4 506 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox}) 507 | #print(bounding_boxes) 508 | # sort detection-results by decreasing confidence 509 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True) 510 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile: 511 | json.dump(bounding_boxes, outfile) 512 | 513 | """ 514 | Calculate the AP for each class 515 | """ 516 | sum_AP = 0.0 517 | ap_dictionary = {} 518 | lamr_dictionary = {} 519 | # open file to store the output 520 | with open(output_files_path + "/output.txt", 'w') as output_file: 521 | output_file.write("# AP and precision/recall per class\n") 522 | count_true_positives = {} 523 | for class_index, class_name in enumerate(gt_classes): 524 | count_true_positives[class_name] = 0 525 | """ 526 | Load detection-results of that class 527 | """ 528 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json" 529 | dr_data = json.load(open(dr_file)) 530 | 531 | """ 532 | Assign detection-results to ground-truth objects 533 | """ 534 | nd = len(dr_data) # the number of the detections 535 | tp = [0] * nd # creates an array of zeros of size nd 536 | fp = [0] * nd 537 | # print('evaluate on class: {} '.format(class_name)) 538 | # for idx, detection in enumerate(tqdm(dr_data)): 539 | for _index, detection in enumerate(dr_data): # todo: idx -> _index 540 | file_id = detection["file_id"] 541 | if show_animation: 542 | # find ground truth image 543 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*") 544 | #tifCounter = len(glob.glob1(myPath,"*.tif")) 545 | if len(ground_truth_img) == 0: 546 | error("Error. Image not found with id: " + file_id) 547 | elif len(ground_truth_img) > 1: 548 | error("Error. 
Multiple image with id: " + file_id) 549 | else: # found image 550 | #print(IMG_PATH + "/" + ground_truth_img[0]) 551 | # Load image 552 | img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0]) 553 | # load image with draws of multiple detections 554 | img_cumulative_path = output_files_path + "/images/" + ground_truth_img[0] 555 | if os.path.isfile(img_cumulative_path): 556 | img_cumulative = cv2.imread(img_cumulative_path) 557 | else: 558 | img_cumulative = img.copy() 559 | # Add bottom border to image 560 | bottom_border = 60 561 | BLACK = [0, 0, 0] 562 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK) 563 | # as3;fmtsign detection-results to ground truth object if any 564 | # open ground-truth with that file_id 565 | gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json" 566 | ground_truth_data = json.load(open(gt_file)) 567 | ovmax = -1 568 | gt_match = -1 569 | # load detected object bounding-box 570 | bb = [ float(x) for x in detection["bbox"].split() ] 571 | for idx, obj in enumerate(ground_truth_data): 572 | # look for a class_name match 573 | if obj["class_name"] == class_name: 574 | bbgt = [ float(x) for x in obj["bbox"].split() ] 575 | ### IoU calculation 576 | iou, inter = skewiou(bbgt, bb) 577 | if inter != 0: 578 | ov = iou 579 | if ov > ovmax: 580 | ovmax = ov 581 | gt_match = obj 582 | # bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])] 583 | # iw = bi[2] - bi[0] + 1 584 | # ih = bi[3] - bi[1] + 1 585 | # if iw > 0 and ih > 0: 586 | # # compute overlap (IoU) = area of intersection / area of union 587 | # ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] 588 | # + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih 589 | # ov = iw * ih / ua 590 | # if ov > ovmax: 591 | # ovmax = ov 592 | # gt_match = obj 593 | 594 | # assign detection as true positive/don't care/false positive 595 | if show_animation: 596 | status = "NO MATCH FOUND!" # status is only used in the animation 597 | # set minimum overlap 598 | min_overlap = MINOVERLAP 599 | if specific_iou_flagged: 600 | if class_name in specific_iou_classes: 601 | index = specific_iou_classes.index(class_name) 602 | min_overlap = float(iou_list[index]) 603 | if ovmax >= min_overlap: 604 | if "difficult" not in gt_match: 605 | if not bool(gt_match["used"]): 606 | # true positive 607 | tp[_index] = 1 608 | gt_match["used"] = True 609 | count_true_positives[class_name] += 1 610 | # update the ".json" file 611 | with open(gt_file, 'w') as f: 612 | f.write(json.dumps(ground_truth_data)) 613 | if show_animation: 614 | status = "MATCH!" 615 | else: 616 | # false positive (multiple detection) 617 | fp[_index] = 1 618 | if show_animation: 619 | status = "REPEATED MATCH!" 
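                    # Summary of this branch, for readability: a detection whose best skew IoU
                    # reaches min_overlap is a true positive if its matched ground truth is still
                    # unused, a false positive ("repeated match") if that ground truth was already
                    # used, and is skipped when the ground truth is flagged difficult; the
                    # else-branch below marks detections under min_overlap as false positives.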
620 | else: 621 | # false positive 622 | fp[_index] = 1 623 | if ovmax > 0: 624 | status = "INSUFFICIENT OVERLAP" 625 | 626 | """ 627 | Draw image to show animation 628 | """ 629 | if show_animation: 630 | height, widht = img.shape[:2] 631 | # colors (OpenCV works with BGR) 632 | white = (255,255,255) 633 | light_blue = (255,200,100) 634 | green = (0,255,0) 635 | light_red = (30,30,255) 636 | # 1st line 637 | margin = 10 638 | v_pos = int(height - margin - (bottom_border / 2.0)) 639 | text = "Image: " + ground_truth_img[0] + " " 640 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 641 | text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " " 642 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width) 643 | if ovmax != -1: 644 | color = light_red 645 | if status == "INSUFFICIENT OVERLAP": 646 | text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100) 647 | else: 648 | text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100) 649 | color = green 650 | img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 651 | # 2nd line 652 | v_pos += int(bottom_border / 2.0) 653 | rank_pos = str(idx+1) # rank position (idx starts at 0) 654 | text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100) 655 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 656 | color = light_red 657 | if status == "MATCH!": 658 | color = green 659 | text = "Result: " + status + " " 660 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 661 | 662 | font = cv2.FONT_HERSHEY_SIMPLEX 663 | if ovmax > 0: # if there is intersections between the bounding-boxes 664 | bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ] 665 | cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 666 | cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 667 | cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA) 668 | bb = [int(i) for i in bb] 669 | cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 670 | cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 671 | cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA) 672 | # show image 673 | cv2.imshow("Animation", img) 674 | cv2.waitKey(20) # show for 20 ms 675 | # save image to output 676 | output_img_path = output_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg" 677 | cv2.imwrite(output_img_path, img) 678 | # save the image with all the objects drawn to it 679 | cv2.imwrite(img_cumulative_path, img_cumulative) 680 | 681 | #print(tp) 682 | # compute precision/recall 683 | cumsum = 0 684 | for idx, val in enumerate(fp): 685 | fp[idx] += cumsum 686 | cumsum += val 687 | cumsum = 0 688 | for idx, val in enumerate(tp): 689 | tp[idx] += cumsum 690 | cumsum += val 691 | #print(tp) 692 | rec = tp[:] 693 | for idx, val in enumerate(tp): 694 | rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name] 695 | recall = rec[-1] 696 | prec = tp[:] 697 | for idx, val in enumerate(tp): 698 | prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx] + 1e-6) 699 | precision = prec[-1] 700 | 701 | ap, mrec, mprec = voc_ap(rec[:], prec[:],use_07_metric=use_07_metric) 702 | sum_AP += ap 703 | text = 
"{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100) 704 | """ 705 | Write to output.txt 706 | """ 707 | rounded_prec = [ '%.2f' % elem for elem in prec ] 708 | rounded_rec = [ '%.2f' % elem for elem in rec ] 709 | output_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") 710 | # if not quiet: 711 | # print(text) 712 | ap_dictionary[class_name] = ap 713 | 714 | n_images = counter_images_per_class[class_name] 715 | lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images) 716 | lamr_dictionary[class_name] = lamr 717 | 718 | """ 719 | Draw plot 720 | """ 721 | if draw_plot: 722 | plt.plot(rec, prec, '-o') 723 | # add a new penultimate point to the list (mrec[-2], 0.0) 724 | # since the last line segment (and respective area) do not affect the AP value 725 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]] 726 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]] 727 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r') 728 | # set window title 729 | fig = plt.gcf() # gcf - get current figure 730 | fig.canvas.set_window_title('AP ' + class_name) 731 | # set plot title 732 | plt.title('class: ' + text) 733 | #plt.suptitle('This is a somewhat long figure title', fontsize=16) 734 | # set axis titles 735 | plt.xlabel('Recall') 736 | plt.ylabel('Precision') 737 | # optional - set axes 738 | axes = plt.gca() # gca - get current axes 739 | axes.set_xlim([0.0,1.0]) 740 | axes.set_ylim([0.0,1.05]) # .05 to give some extra space 741 | # Alternative option -> wait for button to be pressed 742 | #while not plt.waitforbuttonpress(): pass # wait for key display 743 | # Alternative option -> normal display 744 | #plt.show() 745 | # save the plot 746 | fig.savefig(output_files_path + "/classes/" + class_name + ".png") 747 | plt.cla() # clear axes for next plot 748 | 749 | if show_animation: 750 | cv2.destroyAllWindows() 751 | 752 | output_file.write("\n# mAP of all classes\n") 753 | mAP = sum_AP / n_classes 754 | text = "mAP = {0:.2f}%".format(mAP*100) 755 | output_file.write(text + "\n") 756 | # print(text) 757 | 758 | """ 759 | Draw false negatives 760 | """ 761 | # pink = (203,192,255) 762 | # for tmp_file in gt_files: 763 | # ground_truth_data = json.load(open(tmp_file)) 764 | # #print(ground_truth_data) 765 | # # get name of corresponding image 766 | # start = TEMP_FILES_PATH + '/' 767 | # img_id = tmp_file[tmp_file.find(start)+len(start):tmp_file.rfind('_ground_truth.json')] 768 | # img_cumulative_path = output_files_path + "/images/" + img_id + ".jpg" 769 | # import cv2 770 | # img = cv2.imread(img_cumulative_path) 771 | # if img is None: 772 | # img_path = IMG_PATH + '/' + img_id + ".jpg" 773 | # img = cv2.imread(img_path) 774 | # draw false negatives 775 | # for obj in ground_truth_data: 776 | # if not obj['used']: 777 | # bbgt = [ int(round(float(x))) for x in obj["bbox"].split() ] 778 | # cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),pink,2) 779 | # cv2.imwrite(img_cumulative_path, img) 780 | 781 | # remove the temp_files directory 782 | shutil.rmtree(TEMP_FILES_PATH) 783 | 784 | """ 785 | Count total of detection-results 786 | """ 787 | # iterate through all the files 788 | det_counter_per_class = {} 789 | for txt_file in dr_files_list: 790 | # get lines to list 791 | lines_list = file_lines_to_list(txt_file) 792 | for line in lines_list: 793 | class_name = line.split()[0] 794 | # check if class is in the ignore list, if yes skip 
795 | if class_name in ignore: 796 | continue 797 | # count that object 798 | if class_name in det_counter_per_class: 799 | det_counter_per_class[class_name] += 1 800 | else: 801 | # if class didn't exist yet 802 | det_counter_per_class[class_name] = 1 803 | #print(det_counter_per_class) 804 | dr_classes = list(det_counter_per_class.keys()) 805 | 806 | 807 | """ 808 | Plot the total number of occurences of each class in the ground-truth 809 | """ 810 | if draw_plot: 811 | window_title = "ground-truth-info" 812 | plot_title = "ground-truth\n" 813 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" 814 | x_label = "Number of objects per class" 815 | output_path = output_files_path + "/ground-truth-info.png" 816 | to_show = False 817 | plot_color = 'forestgreen' 818 | draw_plot_func( 819 | gt_counter_per_class, 820 | n_classes, 821 | window_title, 822 | plot_title, 823 | x_label, 824 | output_path, 825 | to_show, 826 | plot_color, 827 | '', 828 | ) 829 | 830 | """ 831 | Write number of ground-truth objects per class to results.txt 832 | """ 833 | with open(output_files_path + "/output.txt", 'a') as output_file: 834 | output_file.write("\n# Number of ground-truth objects per class\n") 835 | for class_name in sorted(gt_counter_per_class): 836 | output_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n") 837 | 838 | """ 839 | Finish counting true positives 840 | """ 841 | for class_name in dr_classes: 842 | # if class exists in detection-result but not in ground-truth then there are no true positives in that class 843 | if class_name not in gt_classes: 844 | count_true_positives[class_name] = 0 845 | #print(count_true_positives) 846 | 847 | """ 848 | Plot the total number of occurences of each class in the "detection-results" folder 849 | """ 850 | if draw_plot: 851 | window_title = "detection-results-info" 852 | # Plot title 853 | plot_title = "detection-results\n" 854 | plot_title += "(" + str(len(dr_files_list)) + " files and " 855 | count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values())) 856 | plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)" 857 | # end Plot title 858 | x_label = "Number of objects per class" 859 | output_path = output_files_path + "/detection-results-info.png" 860 | to_show = False 861 | plot_color = 'forestgreen' 862 | true_p_bar = count_true_positives 863 | try: 864 | draw_plot_func( 865 | det_counter_per_class, 866 | len(det_counter_per_class), 867 | window_title, 868 | plot_title, 869 | x_label, 870 | output_path, 871 | to_show, 872 | plot_color, 873 | true_p_bar 874 | ) 875 | except: 876 | pass 877 | 878 | """ 879 | Write number of detected objects per class to output.txt 880 | """ 881 | with open(output_files_path + "/output.txt", 'a') as output_file: 882 | output_file.write("\n# Number of detected objects per class\n") 883 | for class_name in sorted(dr_classes): 884 | n_det = det_counter_per_class[class_name] 885 | text = class_name + ": " + str(n_det) 886 | text += " (tp:" + str(count_true_positives[class_name]) + "" 887 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n" 888 | output_file.write(text) 889 | 890 | """ 891 | Draw log-average miss rate plot (Show lamr of all classes in decreasing order) 892 | """ 893 | if draw_plot: 894 | window_title = "lamr" 895 | plot_title = "log-average miss rate" 896 | x_label = "log-average miss rate" 897 | output_path = output_files_path + "/lamr.png" 898 | 
to_show = False 899 | plot_color = 'royalblue' 900 | draw_plot_func( 901 | lamr_dictionary, 902 | n_classes, 903 | window_title, 904 | plot_title, 905 | x_label, 906 | output_path, 907 | to_show, 908 | plot_color, 909 | "" 910 | ) 911 | 912 | """ 913 | Draw mAP plot (Show AP's of all classes in decreasing order) 914 | """ 915 | if draw_plot: 916 | window_title = "mAP" 917 | plot_title = "mAP = {0:.2f}%".format(mAP*100) 918 | x_label = "Average Precision" 919 | output_path = output_files_path + "/mAP.png" 920 | to_show = False 921 | plot_color = 'royalblue' 922 | draw_plot_func( 923 | ap_dictionary, 924 | n_classes, 925 | window_title, 926 | plot_title, 927 | x_label, 928 | output_path, 929 | to_show, 930 | plot_color, 931 | "" 932 | ) 933 | return mAP, precision, recall 934 | -------------------------------------------------------------------------------- /utils/rotation_nms/.gitignore: -------------------------------------------------------------------------------- 1 | *.cpp 2 | *.so 3 | -------------------------------------------------------------------------------- /utils/rotation_nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_nms/__init__.py -------------------------------------------------------------------------------- /utils/rotation_nms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_nms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/rotation_nms/cpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------- 2 | # Soft-NMS: Improving Object Detection With One Line of Code 3 | # Copyright (c) University of Maryland, College Park 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Navaneeth Bodla and Bharat Singh 6 | # ---------------------------------------------------------- 7 | 8 | import cv2 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b): 13 | return a if a >= b else b 14 | 15 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b): 16 | return a if a <= b else b 17 | 18 | def cpu_soft_nms( 19 | np.ndarray[float, ndim=2] boxes, 20 | float thresh=0.3, 21 | unsigned int method=1, 22 | float sigma=0.5, 23 | float min_score=0.001 24 | ): 25 | cdef unsigned int N = boxes.shape[0] 26 | cdef float iw, ih, box_area 27 | cdef float ua 28 | cdef int pos = 0 29 | cdef float maxscore = 0 30 | cdef int maxpos = 0 31 | cdef float x1, x2, y1, y2, t, s, w, h, xx, yy, tx1, tx2, ty1, ty2, tt, ts, tw, th, txx, tyy, area, weight, ov, inter 32 | 33 | inds = np.arange(N) 34 | for i in range(N): 35 | maxscore = boxes[i, 5] 36 | maxpos = i 37 | 38 | tx1 = boxes[i, 0] 39 | ty1 = boxes[i, 1] 40 | tx2 = boxes[i, 2] 41 | ty2 = boxes[i, 3] 42 | tt = boxes[i, 4] 43 | ts = boxes[i, 5] 44 | ti = inds[i] 45 | 46 | pos = i + 1 47 | # get max box 48 | while pos < N: 49 | if maxscore < boxes[pos, 5]: 50 | maxscore = boxes[pos, 5] 51 | maxpos = pos 52 | pos = pos + 1 53 | 54 | # add max box as a detection 55 | boxes[i, 0] = boxes[maxpos, 0] 56 | boxes[i, 1] = boxes[maxpos, 1] 57 | 
boxes[i, 2] = boxes[maxpos, 2] 58 | boxes[i, 3] = boxes[maxpos, 3] 59 | boxes[i, 4] = boxes[maxpos, 4] 60 | boxes[i, 5] = boxes[maxpos, 5] 61 | inds[i] = inds[maxpos] 62 | 63 | # swap ith box with position of max box 64 | boxes[maxpos, 0] = tx1 65 | boxes[maxpos, 1] = ty1 66 | boxes[maxpos, 2] = tx2 67 | boxes[maxpos, 3] = ty2 68 | boxes[maxpos, 4] = tt 69 | boxes[maxpos, 5] = ts 70 | inds[maxpos] = ti 71 | 72 | tx1 = boxes[i, 0] 73 | ty1 = boxes[i, 1] 74 | tx2 = boxes[i, 2] 75 | ty2 = boxes[i, 3] 76 | tt = boxes[i, 4] 77 | ts = boxes[i, 5] 78 | 79 | tw = tx2 - tx1 80 | th = ty2 - ty1 81 | txx = tx1 + tw * 0.5 82 | tyy = ty1 + th * 0.5 83 | 84 | pos = i + 1 85 | # NMS iterations, note that N changes if detection boxes fall below threshold 86 | while pos < N: 87 | x1 = boxes[pos, 0] 88 | y1 = boxes[pos, 1] 89 | x2 = boxes[pos, 2] 90 | y2 = boxes[pos, 3] 91 | t = boxes[pos, 4] 92 | s = boxes[pos, 5] 93 | 94 | w = x2 - x1 95 | h = y2 - y1 96 | xx = x1 + w * 0.5 97 | yy = y1 + h * 0.5 98 | 99 | rtn, contours = cv2.rotatedRectangleIntersection( 100 | ((txx, tyy), (tw, th), tt), 101 | ((xx, yy), (w, h), t) 102 | ) 103 | if rtn == 1: 104 | inter = np.round(np.abs(cv2.contourArea(contours))) 105 | elif rtn == 2: 106 | inter = min(tw * th, w * h) 107 | else: 108 | inter = 0.0 109 | 110 | if inter > 0.0: 111 | # iou between max box and detection box 112 | ov = inter / (tw * th + w * h - inter) 113 | if method == 1: # linear 114 | if ov > thresh: 115 | weight = 1 - ov 116 | else: 117 | weight = 1 118 | elif method == 2: # gaussian 119 | weight = np.exp(-(ov * ov) / sigma) 120 | else: # original NMS 121 | if ov > thresh: 122 | weight = 0 123 | else: 124 | weight = 1 125 | boxes[pos, 5] = weight * boxes[pos, 5] 126 | # if box score falls below threshold, discard the box by swapping with last box, update N 127 | if boxes[pos, 5] < min_score: 128 | boxes[pos, 0] = boxes[N-1, 0] 129 | boxes[pos, 1] = boxes[N-1, 1] 130 | boxes[pos, 2] = boxes[N-1, 2] 131 | boxes[pos, 3] = boxes[N-1, 3] 132 | boxes[pos, 4] = boxes[N-1, 4] 133 | boxes[pos, 5] = boxes[N-1, 5] 134 | inds[pos] = inds[N - 1] 135 | N = N - 1 136 | pos = pos - 1 137 | pos = pos + 1 138 | 139 | return inds[:N] 140 | 141 | 142 | def cpu_nms( 143 | np.ndarray[np.float32_t, ndim=2] dets, 144 | np.float thresh 145 | ): 146 | cdef np.ndarray[np.float32_t, ndim=1] ws = dets[:, 2] 147 | cdef np.ndarray[np.float32_t, ndim=1] hs = dets[:, 3] 148 | cdef np.ndarray[np.float32_t, ndim=1] xx = dets[:, 0] 149 | cdef np.ndarray[np.float32_t, ndim=1] yy = dets[:, 1] 150 | cdef np.ndarray[np.float32_t, ndim=1] tt = dets[:, 4] 151 | cdef np.ndarray[np.float32_t, ndim=1] areas = ws * hs 152 | 153 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 5] 154 | cdef np.ndarray[np.intp_t, ndim=1] order = scores.argsort()[::-1] 155 | 156 | cdef int ndets = dets.shape[0] 157 | cdef np.ndarray[np.int_t, ndim=1] suppressed = np.zeros((ndets), dtype=np.int) 158 | 159 | cdef int _i, _j, i, j, rtn 160 | cdef np.float32_t inter, ovr 161 | 162 | keep = [] 163 | for _i in range(ndets): 164 | i = order[_i] 165 | if suppressed[i] == 1: 166 | continue 167 | keep.append(i) 168 | for _j in range(_i + 1, ndets): 169 | j = order[_j] 170 | if suppressed[j] == 1: 171 | continue 172 | rtn, contours = cv2.rotatedRectangleIntersection( 173 | ((xx[i], yy[i]), (ws[i], hs[i]), tt[i]), 174 | ((xx[j], yy[j]), (ws[j], hs[j]), tt[j]) 175 | ) 176 | if rtn == 1: 177 | inter = np.round(np.abs(cv2.contourArea(contours))) 178 | elif rtn == 2: 179 | inter = min(areas[i], areas[j]) 180 | else: 181 | 
inter = 0.0 182 | ovr = inter / (areas[i] + areas[j] - inter + 1e-6) 183 | if ovr >= thresh: 184 | suppressed[j] = 1 185 | 186 | return keep 187 | -------------------------------------------------------------------------------- /utils/rotation_overlaps/.gitignore: -------------------------------------------------------------------------------- 1 | *.cpp 2 | *.so 3 | -------------------------------------------------------------------------------- /utils/rotation_overlaps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_overlaps/__init__.py -------------------------------------------------------------------------------- /utils/rotation_overlaps/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_overlaps/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/rotation_overlaps/rbox_overlaps.pyx: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | cimport cython 4 | cimport numpy as np 5 | 6 | ctypedef np.float32_t DTYPE_t 7 | 8 | 9 | def rbox_overlaps( 10 | np.ndarray[DTYPE_t, ndim=2] boxes, 11 | np.ndarray[DTYPE_t, ndim=2] query_boxes, 12 | np.ndarray[DTYPE_t, ndim=2] indicator=None, 13 | np.float thresh=1e-4): 14 | """ 15 | Parameters: 16 | boxes: (N, 5) ndarray of float: [xc, yc, w, h, angle(radian)] 17 | query_boxes: (K, 5) ndarray of float: [xc, yc, w, h, angle(radian)] 18 | 19 | Returns: 20 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 21 | """ 22 | cdef unsigned int N = boxes.shape[0] 23 | cdef unsigned int K = query_boxes.shape[0] 24 | cdef DTYPE_t box_area 25 | cdef DTYPE_t ua, ia 26 | cdef unsigned int k, n, rtn 27 | cdef np.ndarray[DTYPE_t, ndim=3] contours 28 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=np.float32) 29 | 30 | cdef np.ndarray[DTYPE_t, ndim=1] a_tt = boxes[:, 4] * 180 / np.pi 31 | cdef np.ndarray[DTYPE_t, ndim=1] a_ws = boxes[:, 2] 32 | cdef np.ndarray[DTYPE_t, ndim=1] a_hs = boxes[:, 3] 33 | cdef np.ndarray[DTYPE_t, ndim=1] a_xx = boxes[:, 0] 34 | cdef np.ndarray[DTYPE_t, ndim=1] a_yy = boxes[:, 1] 35 | 36 | cdef np.ndarray[DTYPE_t, ndim=1] b_tt = query_boxes[:, 4] * 180 / np.pi 37 | cdef np.ndarray[DTYPE_t, ndim=1] b_ws = query_boxes[:, 2] 38 | cdef np.ndarray[DTYPE_t, ndim=1] b_hs = query_boxes[:, 3] 39 | cdef np.ndarray[DTYPE_t, ndim=1] b_xx = query_boxes[:, 0] 40 | cdef np.ndarray[DTYPE_t, ndim=1] b_yy = query_boxes[:, 1] 41 | 42 | for k in range(K): 43 | box_area = b_ws[k] * b_hs[k] 44 | for n in range(N): 45 | if indicator is not None and indicator[n, k] < thresh: 46 | continue 47 | ua = a_ws[n] * a_hs[n] + box_area 48 | rtn, contours = cv2.rotatedRectangleIntersection( 49 | ((a_xx[n], a_yy[n]), (a_ws[n], a_hs[n]), a_tt[n]), 50 | ((b_xx[k], b_yy[k]), (b_ws[k], b_hs[k]), b_tt[k]) 51 | ) 52 | if rtn == 1: 53 | ia = np.round(np.abs(cv2.contourArea(contours))) 54 | overlaps[n, k] = ia / (ua - ia) 55 | elif rtn == 2: 56 | ia = np.minimum(ua - box_area, box_area) 57 | overlaps[n, k] = ia / (ua - ia) 58 | return overlaps 59 | 60 | -------------------------------------------------------------------------------- /utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import torchvision.transforms as transforms 5 | import random 6 | import torch.nn as nn 7 | 8 | 9 | def set_random_seed(seed, deterministic=False): 10 | """Set random seed. 11 | Args: 12 | deterministic is set True if use torch.backends.cudnn.deterministic 13 | Default is False. 14 | """ 15 | print(f'[Info]: Set random seed to {seed}, deterministic: {deterministic}.') 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | if deterministic: 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | 25 | 26 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 27 | assert distribution in ['uniform', 'normal'] 28 | if hasattr(module, 'weight') and module.weight is not None: 29 | if distribution == 'uniform': 30 | nn.init.xavier_uniform_(module.weight, gain=gain) 31 | else: 32 | nn.init.xavier_normal_(module.weight, gain=gain) 33 | 34 | if hasattr(module, 'bias') and module.bias is not None: 35 | nn.init.constant_(module.bias, bias) 36 | 37 | 38 | def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0, distribution='normal'): 39 | assert distribution in ['uniform', 'normal'] 40 | if hasattr(module, 'weight') and module.weight is not None: 41 | if distribution == 'uniform': 42 | nn.init.kaiming_uniform_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 43 | else: 44 | nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 45 | if hasattr(module, 'bias') and module.bias is not None: 46 | nn.init.constant_(module.bias, bias) 47 | 48 | 49 | def constant_init(module, val, bias=0): 50 | if hasattr(module, 'weight') and module.weight is not None: 51 | nn.init.constant_(module.weight, val) 52 | if hasattr(module, 'bias') and module.bias is not None: 53 | nn.init.constant_(module.bias, bias) 54 | 55 | 56 | def normal_init(module, mean=0, std=1, bias=0): 57 | if hasattr(module, 'weight') and module.weight is not None: 58 | nn.init.normal_(module.weight, mean, std) 59 | if hasattr(module, 'bias') and module.bias is not None: 60 | nn.init.constant_(module.bias, bias) 61 | 62 | 63 | def pretty_print(num_params, units=None, precision=2): 64 | if units is None: 65 | if num_params // 10**6 > 0: 66 | print(f'[Info]: Model Params = {str(round(num_params / 10**6, precision))}' + ' M') 67 | elif num_params // 10**3: 68 | print(f'[Info]: Model Params = {str(round(num_params / 10**3, precision))}' + ' k') 69 | else: 70 | print(f'[Info]: Model Params = {str(num_params)}') 71 | 72 | 73 | def count_param(model, units=None, precision=2): 74 | """Count Params""" 75 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 76 | pretty_print(num_params) 77 | 78 | 79 | def show_args(args): 80 | print('=============== Show Args ===============') 81 | for k in list(vars(args).keys()): 82 | print('%s: %s' % (k, vars(args)[k])) 83 | 84 | 85 | def clip_boxes(boxes, ims): 86 | _, _, h, w = ims.shape 87 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0) 88 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0) 89 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=w) 90 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=h) 91 | return boxes 92 | 93 | 94 | # (num_boxes, 5) xyxya 95 | def min_area_square(rboxes): 96 | w = rboxes[:, 2] - rboxes[:, 0] 97 | h = rboxes[:, 3] - rboxes[:, 1] 98 | ctr_x = 
rboxes[:, 0] + w * 0.5 99 | ctr_y = rboxes[:, 1] + h * 0.5 100 | s = torch.max(w, h) 101 | return torch.stack(( 102 | ctr_x - s * 0.5, ctr_y - s * 0.5, 103 | ctr_x + s * 0.5, ctr_y + s * 0.5), 104 | dim=1 105 | ) 106 | 107 | 108 | def rbox_overlaps(boxes, query_boxes, indicator=None, thresh=1e-1): 109 | # rewrited by cython 110 | N = boxes.shape[0] 111 | K = query_boxes.shape[0] 112 | 113 | a_tt = boxes[:, 4] 114 | a_ws = boxes[:, 2] - boxes[:, 0] 115 | a_hs = boxes[:, 3] - boxes[:, 1] 116 | a_xx = boxes[:, 0] + a_ws * 0.5 117 | a_yy = boxes[:, 1] + a_hs * 0.5 118 | 119 | b_tt = query_boxes[:, 4] 120 | b_ws = query_boxes[:, 2] - query_boxes[:, 0] 121 | b_hs = query_boxes[:, 3] - query_boxes[:, 1] 122 | b_xx = query_boxes[:, 0] + b_ws * 0.5 123 | b_yy = query_boxes[:, 1] + b_hs * 0.5 124 | 125 | overlaps = np.zeros((N, K), dtype=np.float32) 126 | for k in range(K): 127 | box_area = b_ws[k] * b_hs[k] 128 | for n in range(N): 129 | if indicator is not None and indicator[n, k] < thresh: 130 | continue 131 | ua = a_ws[n] * a_hs[n] + box_area 132 | rtn, contours = cv2.rotatedRectangleIntersection( 133 | ((a_xx[n], a_yy[n]), (a_ws[n], a_hs[n]), a_tt[n]), 134 | ((b_xx[k], b_yy[k]), (b_ws[k], b_hs[k]), b_tt[k]) 135 | ) 136 | if rtn == 1: 137 | ia = cv2.contourArea(contours) 138 | overlaps[n, k] = ia / (ua - ia) 139 | elif rtn == 2: 140 | ia = np.minimum(ua - box_area, box_area) 141 | overlaps[n, k] = ia / (ua - ia) 142 | return overlaps 143 | 144 | 145 | def bbox_overlaps(boxes, query_boxes): 146 | """Calculate the horizontal overlaps 147 | 148 | Args: 149 | boxes: [xc, yc, w, h, angle] 150 | query_boxes: [xc, yc, w, h, pi/2] 151 | """ 152 | if not isinstance(boxes, float): # apex 153 | boxes = boxes.float() 154 | 155 | # convert the [xc, yc, w, h, angle] to [x1, y1, x2, y2, angle] 156 | query_boxes[:, 0] = query_boxes[:, 0] - query_boxes[:, 2] / 2 157 | query_boxes[:, 1] = query_boxes[:, 1] - query_boxes[:, 3] / 2 158 | query_boxes[:, 2] = query_boxes[:, 0] + query_boxes[:, 2] 159 | query_boxes[:, 3] = query_boxes[:, 1] + query_boxes[:, 3] 160 | 161 | boxes[:, 0] = boxes[:, 0] - boxes[:, 2] / 2 162 | boxes[:, 1] = boxes[:, 1] - boxes[:, 3] / 2 163 | boxes[:, 2] = boxes[:, 0] + boxes[:, 2] 164 | boxes[:, 3] = boxes[:, 1] + boxes[:, 3] 165 | 166 | area = (query_boxes[:, 2] - query_boxes[:, 0]) * \ 167 | (query_boxes[:, 3] - query_boxes[:, 1]) 168 | iw = torch.min(torch.unsqueeze(boxes[:, 2], dim=1), query_boxes[:, 2]) - \ 169 | torch.max(torch.unsqueeze(boxes[:, 0], 1), query_boxes[:, 0]) 170 | ih = torch.min(torch.unsqueeze(boxes[:, 3], dim=1), query_boxes[:, 3]) - \ 171 | torch.max(torch.unsqueeze(boxes[:, 1], 1), query_boxes[:, 1]) 172 | iw = torch.clamp(iw, min=0) 173 | ih = torch.clamp(ih, min=0) 174 | ua = torch.unsqueeze((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), dim=1) + area - iw * ih 175 | ua = torch.clamp(ua, min=1e-8) 176 | intersection = iw * ih 177 | return intersection / ua 178 | 179 | 180 | def rescale(im, target_size, max_size, keep_ratio, multiple=32): 181 | im_shape = im.shape 182 | im_size_min = np.min(im_shape[0:2]) 183 | im_size_max = np.max(im_shape[0:2]) 184 | if keep_ratio: 185 | # scale method 1: 186 | # scale the shorter side to target size by the constraint of the max size 187 | im_scale = float(target_size) / float(im_size_min) 188 | if np.round(im_scale * im_size_max) > max_size: 189 | im_scale = float(max_size) / float(im_size_max) 190 | im_scale_x = np.floor(im.shape[1] * im_scale / multiple) * multiple / im.shape[1] 191 | im_scale_y = 
np.floor(im.shape[0] * im_scale / multiple) * multiple / im.shape[0] 192 | im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR) 193 | im_scale = np.array([im_scale_x, im_scale_y, im_scale_x, im_scale_y]) 194 | 195 | # scale method 2: 196 | # scale the longer side to target size 197 | # im_scale = float(target_size) / float(im_size_max) 198 | # im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR) 199 | # im_scale = np.array([im_scale, im_scale, im_scale, im_scale]) 200 | 201 | else: 202 | target_size = int(np.floor(float(target_size) / multiple) * multiple) 203 | im_scale_x = float(target_size) / float(im_shape[1]) 204 | im_scale_y = float(target_size) / float(im_shape[0]) 205 | im = cv2.resize(im, (target_size, target_size), interpolation=cv2.INTER_LINEAR) 206 | im_scale = np.array([im_scale_x, im_scale_y, im_scale_x, im_scale_y]) 207 | return im, im_scale 208 | 209 | 210 | class Rescale(object): 211 | def __init__(self, target_size, keep_ratio): 212 | self.target_size = target_size 213 | self.keep_ratio = keep_ratio 214 | self.max_size = 2000 # for scale method 1 215 | 216 | def __call__(self, image): 217 | im, im_scale = rescale(image, target_size=self.target_size, max_size=self.max_size, 218 | keep_ratio=self.keep_ratio) 219 | return im, im_scale 220 | 221 | 222 | class Normalize(object): 223 | def __init__(self): 224 | self._transform = transforms.Compose([ 225 | transforms.ToTensor(), 226 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # mean / std 227 | 228 | def __call__(self, im): 229 | im = self._transform(im) 230 | return im 231 | 232 | 233 | class Reshape(object): 234 | def __init__(self, unsqueeze=True): 235 | self._unsqueeze = unsqueeze 236 | return 237 | 238 | def __call__(self, ims): 239 | if not torch.is_tensor(ims): 240 | ims = torch.from_numpy(ims.transpose((2, 0, 1))) 241 | if self._unsqueeze: 242 | ims = ims.unsqueeze(0) 243 | return ims 244 | 245 | -------------------------------------------------------------------------------- /warmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class WarmupLR(_LRScheduler): 6 | def __init__(self, scheduler, init_lr=1e-3, num_warmup=1, warmup_strategy='linear'): 7 | if warmup_strategy not in ['linear', 'cos', 'constant']: 8 | raise ValueError( 9 | "Expect warmup_strategy to be one of ['linear', 'cos', 'constant'] but got {}".format(warmup_strategy)) 10 | self._scheduler = scheduler 11 | self._init_lr = init_lr 12 | self._num_warmup = num_warmup 13 | self._step_count = 0 14 | 15 | # Define the strategy to warm up learning rate 16 | self._warmup_strategy = warmup_strategy 17 | if warmup_strategy == 'cos': 18 | self._warmup_func = self._warmup_cos 19 | elif warmup_strategy == 'linear': 20 | self._warmup_func = self._warmup_linear 21 | else: 22 | self._warmup_func = self._warmup_const 23 | 24 | # save initial learning rate of each param group 25 | # only useful when each param groups having different learning rate 26 | self._format_param() 27 | 28 | def __getattr__(self, name): 29 | return getattr(self._scheduler, name) 30 | 31 | def state_dict(self): 32 | """Returns the state of the scheduler as a :class:`dict`. 33 | It contains an entry for every variable in self.__dict__ which 34 | is not the optimizer. 
35 | """ 36 | wrapper_state_dict = {key: value for key, value in self.__dict__.items() if 37 | (key != 'optimizer' and key != '_scheduler')} 38 | wrapped_state_dict = {key: value for key, value in self._scheduler.__dict__.items() if key != 'optimizer'} 39 | return {'wrapped': wrapped_state_dict, 'wrapper': wrapper_state_dict} 40 | 41 | def load_state_dict(self, state_dict): 42 | """Loads the schedulers state. 43 | Arguments: 44 | state_dict (dict): scheduler state. Should be an object returned 45 | from a call to :meth:`state_dict`. 46 | """ 47 | self.__dict__.update(state_dict['wrapper']) 48 | self._scheduler.__dict__.update(state_dict['wrapped']) 49 | 50 | def _format_param(self): 51 | # learning rate of each param group will increase 52 | # from the min_lr to initial_lr 53 | for group in self._scheduler.optimizer.param_groups: 54 | group['warmup_max_lr'] = group['lr'] 55 | group['warmup_initial_lr'] = min(self._init_lr, group['lr']) 56 | 57 | def _warmup_cos(self, start, end, pct): 58 | """cosine annealing function: 59 | current = end + 0.5 * (start + end) * (1 + cos(t_current / t_total * pi)). """ 60 | cos_out = math.cos(math.pi * pct) + 1 61 | return end + (start - end) / 2.0 * cos_out 62 | 63 | def _warmup_const(self, start, end, pct): 64 | return start if pct < 0.9999 else end 65 | 66 | def _warmup_linear(self, start, end, pct): 67 | return (end - start) * pct + start 68 | 69 | def get_lr(self): 70 | lrs = [] 71 | step_num = self._step_count 72 | # warm up learning rate 73 | if step_num <= self._num_warmup: 74 | for group in self._scheduler.optimizer.param_groups: 75 | computed_lr = self._warmup_func(group['warmup_initial_lr'], 76 | group['warmup_max_lr'], 77 | step_num / self._num_warmup) 78 | lrs.append(computed_lr) 79 | else: 80 | lrs = self._scheduler.get_lr() 81 | return lrs 82 | 83 | def step(self, *args): 84 | if self._step_count <= self._num_warmup: 85 | values = self.get_lr() 86 | for param_group, lr in zip(self._scheduler.optimizer.param_groups, values): 87 | param_group['lr'] = lr 88 | self._step_count += 1 89 | else: 90 | # method 1: 91 | # self._scheduler.step(epoch=self._step_count) 92 | # self._step_count += 1 93 | 94 | # method 2: 95 | self._scheduler._step_count = self._step_count + 1 96 | self._scheduler.last_epoch = self._step_count 97 | self._scheduler.step() 98 | self._step_count += 1 --------------------------------------------------------------------------------