├── Logger.py
├── Makefile
├── README.md
├── configs
│   ├── retinanet_r50_fpn_hrsc.yml
│   └── retinanet_r50_fpn_ssdd.yml
├── datasets
│   ├── HRSC_dataset.py
│   ├── SSDD_dataset.py
│   ├── __pycache__
│   │   ├── HRSC_dataset.cpython-37.pyc
│   │   ├── SSDD_dataset.cpython-37.pyc
│   │   └── collater.cpython-37.pyc
│   ├── collater.py
│   ├── convert.py
│   ├── prepare_dataset.py
│   └── test_collater.py
├── detect.py
├── eval.py
├── models
│   ├── __pycache__
│   │   ├── anchors.cpython-37.pyc
│   │   ├── fpn.cpython-37.pyc
│   │   ├── heads.cpython-37.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── resnet.cpython-37.pyc
│   ├── anchors.py
│   ├── fpn.py
│   ├── heads.py
│   ├── losses.py
│   ├── model.py
│   └── resnet.py
├── requirements.txt
├── resnet_pretrained_pth
│   ├── .gitignore
│   └── README.md
├── resource
│   ├── HRSC_Result.png
│   └── RSSDD_Result.png
├── setup.py
├── show.py
├── show_result
│   ├── HRSC
│   │   ├── demo1.jpg
│   │   ├── demo2.jpg
│   │   └── demo3.jpg
│   └── RSSDD
│       ├── demo1.jpg
│       ├── demo2.jpg
│       └── demo3.jpg
├── train.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── augment.cpython-37.pyc
│   │   ├── bbox_transforms.cpython-37.pyc
│   │   ├── box_coder.cpython-37.pyc
│   │   ├── map.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── augment.py
│   ├── bbox_transforms.py
│   ├── box_coder.py
│   ├── map.py
│   ├── rotation_nms
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-37.pyc
│   │   └── cpu_nms.pyx
│   ├── rotation_overlaps
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-37.pyc
│   │   └── rbox_overlaps.pyx
│   └── utils.py
└── warmup.py
/Logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import colorlog
3 |
4 | """ Logger Rank: (low -> high)
5 | # 1. DEBUG
6 | # 2. INFO
7 | # 3. WARNING
8 | # 4. ERROR
9 | # 5. CRITICAL
10 | """
11 |
12 |
13 | class Logger(object):
14 | def __init__(self, log_path, logging_name):
15 | self.log_path = log_path
16 | self.logging_name = logging_name
17 | self.dash_line = '-' * 60 + '\n'
18 | self.level_color = {'DEBUG': 'cyan',
19 | 'INFO': 'bold_white',
20 | 'WARNING': 'yellow',
21 | 'ERROR': 'red',
22 | 'CRITICAL': 'red'}
23 |
24 | def logger_config(self):
25 | logger = logging.getLogger(self.logging_name)
26 | logger.setLevel(level=logging.DEBUG)
27 | handler = logging.FileHandler(self.log_path, encoding='UTF-8')
28 | handler.setLevel(logging.DEBUG)
29 | file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
30 | datefmt="%Y-%m-%d %H:%M:%S")
31 | handler.setFormatter(file_formatter)
32 |
33 | console_formatter = colorlog.ColoredFormatter(
34 | '%(log_color)s[%(asctime)s] - [%(name)s] - [%(levelname)s]:\n%(message)s', datefmt="%Y-%m-%d %H:%M:%S",
35 | log_colors=self.level_color)
36 |
37 | console = logging.StreamHandler()
38 | console.setFormatter(console_formatter)
39 | console.setLevel(logging.INFO)
40 |
41 | logger.addHandler(handler)
42 | logger.addHandler(console)
43 | return logger
44 |
--------------------------------------------------------------------------------
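A minimal usage sketch for the `Logger` class above (the log file path and logger name here are placeholders, not values taken from this repository):
```
# Sketch only: 'train.log' and the logger name are placeholders.
from Logger import Logger

logger = Logger(log_path='train.log', logging_name='Rotation-RetinaNet').logger_config()
logger.info('Training started.')     # goes to the console (INFO and above) and to train.log
logger.debug('Anchor settings ...')  # goes to train.log only, since the console handler is set to INFO
```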
/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 | python setup.py build_ext --inplace
3 | rm -rf build
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## :rocket:RetinaNet Oriented Detector Based on PyTorch
2 | This is an oriented detector, **Rotation-RetinaNet**, implemented on optical and SAR **ship datasets**.
3 | - SAR ship dataset (SSDD): [SSDD Dataset link](https://github.com/TianwenZhang0825/Official-SSDD)
4 | - Optical ship dataset (HRSC): [HRSC Dataset link](https://www.kaggle.com/guofeng/hrsc2016)
5 | - RetinaNet Detector original paper link is [here](https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf).
6 | ## :star2:Performance of the implemented Rotation-RetinaNet Detector
7 |
8 | ### Detection Performance on HRSC Dataset.
9 |
10 |
11 | ### Detection Performance on SSDD Dataset.
12 |
13 |
14 | ## :dart:Experiment
15 |
16 | | Dataset | Backbone | Input Size | bs | Trick | mAP.5 | Config |
17 | |:-------:|:--------:|:----------:|:--:|:-----:|:-----:|:------:|
18 | | SSDD | ResNet-50| 512 x 512 | 16 | N | 78.96 |[config file](/configs/retinanet_r50_fpn_ssdd.yml)|
19 | | SSDD | ResNet-50| 512 x 512 | 16 |Augment| 85.6 |[config file](/configs/retinanet_r50_fpn_ssdd.yml)|
20 | | HRSC | ResNet-50| 512 x 512 | 16 | N | 70.71 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)|
21 | | HRSC | ResNet-50| 512 x 512 | 4 | N | 74.22 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)|
22 | | HRSC | ResNet-50| 512 x 512 | 16 |Augment| 80.20 |[config file](/configs/retinanet_r50_fpn_hrsc.yml)|
23 |
24 | ## :boom:Get Started
25 | ### Installation
26 | #### A. Install requirements:
27 | ```
28 | conda create -n rotate python=3.7
29 | conda activate rotate
30 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch
31 | pip install -r requirements.txt
32 |
33 | Note: the OpenCV version must be greater than 4.5.1
34 | ```
35 | #### B. Install rotation\_nms and rotation\_overlaps module:
36 | ```
37 | Only one step is needed:
38 | make
39 | ```
40 | ## Demo
41 | ### A. Set project's data path
42 | You should set the project's data path in the `yml` config file first.
43 | ```
44 | # .yml file
45 | # Note: all paths should be absolute paths.
46 | data_path: '/$ROOT_PATH/SSDD_data/' # absolute data root path
47 | output_path: '/$ROOT_PATH/Output/' # absolute model output path
48 |
49 | # For example
50 | $ROOT_PATH
51 | -HRSC/
52 | -train/ # train set
53 | -Annotations/
54 | -*.xml
55 | -images/
56 | -*.jpg
57 | -test/ # test set
58 | -Annotations/
59 | -*.xml
60 | -images/
61 | -*.jpg
62 | -ground-truth/
63 | -*.txt # gt label in txt format (for voc evaluation method)
64 |
65 | -SSDD/
66 | -train/ # train set
67 | -Annotations/
68 | -*.xml
69 | -images/
70 | -*.jpg
71 | -test/ # test set
72 | -Annotations/
73 | -*.xml
74 | -images/
75 | -*.jpg
76 | -ground-truth/
77 | -*.txt # gt label in txt format (for voc evaluation method)
78 |
79 |
80 | -Output/
81 | -checkpoints/
82 | - the path for saving checkpoint files
83 | -tensorboard/
84 | - the path for saving tensorboard event files
85 | -evaluate/
86 | - the path for saving model detection results for evaluation (VOC method)
87 | -log.log (saves the loss and eval results)
88 | -yml file (config file)
89 | ```
90 | ### B. Run the show.py
91 | ```
92 | # for SSDD dataset
93 | python show.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --chkpt {chkpt.file} --result_path show_result/RSSDD --pic_name demo1.jpg
94 |
95 | # for HRSC dataset
96 | python show.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --chkpt {chkpt.file} --result_path show_result/HRSC --pic_name demo1.jpg
97 | ```
98 | ## Train
99 | ### A. Prepare dataset
100 | You should structure your dataset files as shown above.
101 | ### B. Manual set project's hyper parameters
102 | You should manually set the project's hyper-parameters in the `config` file.
103 | ```
104 | 1. Data file structure (must be set!)
105 | as shown above.
106 |
107 | 2. Other settings (optional)
108 | if you want to follow my experiments, don't change anything.
109 | ```
110 | ### C. Train Rotation-RetinaNet on SSDD or HRSC dataset with resnet-50 from scratch
111 | #### C.1 Download the pre-trained resnet-50 pth file
112 | You should download the pre-trained ResNet-50 pth file first and put it in the `resnet_pretrained_pth/` folder.
113 | #### C.2 Train Rotation-RetinaNet Detector on SSDD or HRSC Dataset with pre-trained pth file
114 | ```
115 | # train model on SSDD dataset from scratch
116 | python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume None
117 |
118 | # train model on HRSC dataset from scratch
119 | python train.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --resume None
120 |
121 | ```
122 | ### D. Resume training Rotation-RetinaNet detector on SSDD or HRSC dataset
123 | ```
124 | # train model on SSDD dataset from specific epoch
125 | python train.py --config_file ./configs/retinanet_r50_fpn_ssdd.yml --resume {epoch}_{step}.pth
126 |
127 | # train model on HRSC dataset from specific epoch
128 | python train.py --config_file ./configs/retinanet_r50_fpn_hrsc.yml --resume {epoch}_{step}.pth
129 |
130 | ```
131 | ## Evaluation
132 | ### A. Evaluate model performance on the SSDD or HRSC val set.
133 | ```
134 | python eval.py --Dataset SSDD --config_file ./configs/retinanet_r50_fpn_ssdd.yml --evaluate True --chkpt {epoch}_{step}.pth
135 | python eval.py --Dataset HRSC --config_file ./configs/retinanet_r50_fpn_hrsc.yml --evaluate True --chkpt {epoch}_{step}.pth
136 | ```
137 | ## :bulb:References
138 | Thanks to these great works:
139 | [https://github.com/open-mmlab/mmrotate](https://github.com/open-mmlab/mmrotate)
140 | [https://github.com/ming71/Rotated-RetinaNet](https://github.com/ming71/Rotated-RetinaNet)
141 |
142 | ## :fast_forward:Zhihu Link
143 | [zhihu article](https://zhuanlan.zhihu.com/p/490422549?)
144 |
--------------------------------------------------------------------------------
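For step C.1 of the README above (downloading the pre-trained ResNet-50 weights into `resnet_pretrained_pth/`), a possible sketch using torchvision's standard ImageNet checkpoint; the exact filename that `models/resnet.py` expects is an assumption and may need adjusting:
```
# Sketch: fetch torchvision's ImageNet ResNet-50 weights into resnet_pretrained_pth/.
# The downloaded filename (resnet50-19c8e357.pth) may need to be renamed to whatever models/resnet.py loads.
import torch.utils.model_zoo as model_zoo

url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
model_zoo.load_url(url, model_dir='./resnet_pretrained_pth')
```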
/configs/retinanet_r50_fpn_hrsc.yml:
--------------------------------------------------------------------------------
1 | backbone: {'type': 'resnet50',
2 | 'pretrained': True}
3 |
4 | neck: {'type': 'fpn',
5 | 'init_method': 'xavier_init',
6 | 'extra_conv_init_method': 'xavier_init'}
7 |
8 | head: {'type': 'retinanet',
9 | 'num_stacked': 4,
10 | 'cls_branch_init_method': 'normal_init',
11 | 'reg_branch_init_method': 'normal_init'}
12 |
13 | loss: {'cls': {'alpha': 0.25, 'gamma': 2.0},
14 | 'reg': {'type': 'smooth'}}
15 |
16 |
17 | assigner: {'pos_iou_thr': 0.3,
18 | 'neg_iou_thr': 0.2,
19 | 'min_pos_iou': 0.0,
20 | 'low_quality_match': True}
21 |
22 | # warmup settings
23 | warm_up: True
24 | warmup_epoch: 2
25 | warmup_lr: 0.00001
26 |
27 | # data settings
28 | dataset: HRSC
29 | classes: ['ship']
30 | image_size: 512
31 | keep_ratio: False
32 | batch_size: 16
33 | augment: False
34 |
35 | data_path: '{Data Root Path.}'
36 | output_path: '{Project Output Path.}'
37 | # weight the delta values.
38 |
39 | optimizer: adam
40 | lr: 0.0001
41 | epoch: 100
42 | evaluation_train_start: 101
43 | evaluation_val_start: 44
44 | save_interval: 4
45 | val_interval: 4
46 | eval_method: voc
47 | freeze_bn: True
48 | device: [3]
49 |
50 | # anchor settings
51 | base_size: 4
52 | ratios: [0.2, 0.5, 1.0, 2.0, 5.0]
53 | #ratios: [0.5, 1.0, 2.0]
54 | scales_per_octave: 3
55 | angle: 0 # opencv version > 4.5.1
56 |
57 | rotation_nms_thr: 0.5
58 | score_thr: 0.05
59 |
60 | tensorboard: 'tensorboard'
61 | checkpoint: 'checkpoints'
62 | log: 'log.log'
63 |
--------------------------------------------------------------------------------
/configs/retinanet_r50_fpn_ssdd.yml:
--------------------------------------------------------------------------------
1 | backbone: {'type': 'resnet50',
2 | 'pretrained': True}
3 |
4 | neck: {'type': 'fpn',
5 | 'init_method': 'xavier_init',
6 | 'extra_conv_init_method': 'xavier_init'}
7 |
8 | head: {'type': 'retinanet',
9 | 'num_stacked': 4,
10 | 'cls_branch_init_method': 'normal_init',
11 | 'reg_branch_init_method': 'normal_init'}
12 |
13 | loss: {'cls': {'alpha': 0.25, 'gamma': 2.0},
14 | 'reg': {'type': 'smooth'}}
15 |
16 |
17 | assigner: {'pos_iou_thr': 0.3,
18 | 'neg_iou_thr': 0.2,
19 | 'min_pos_iou': 0.0,
20 | 'low_quality_match': True}
21 |
22 | # warmup settings
23 | warm_up: True
24 | warmup_epoch: 2
25 | warmup_lr: 0.00001
26 |
27 | # data settings
28 | dataset: SSDD
29 | classes: ['ship']
30 | image_size: 512
31 | keep_ratio: False
32 | batch_size: 16
33 | augment: False
34 |
35 | data_path: '{Data Root Path.}'
36 | output_path: '{Project Root Path.}'
37 |
38 | optimizer: adam
39 | lr: 0.0001
40 | epoch: 100
41 | evaluation_train_start: 101
42 | evaluation_val_start: 101
43 | save_interval: 4
44 | val_interval: 4
45 | eval_method: voc
46 | freeze_bn: True
47 | device: [2]
48 |
49 | # anchor settings
50 | base_size: 4
51 | ratios: [0.2, 0.5, 1.0, 2.0, 5.0]
52 | #ratios: [0.5, 1.0, 2.0]
53 | scales_per_octave: 3
54 | angle: 0 # opencv version > 4.5.1
55 |
56 | rotation_nms_thr: 0.5
57 | score_thr: 0.05
58 |
59 | tensorboard: 'tensorboard'
60 | checkpoint: 'checkpoints'
61 | log: 'log.log'
62 |
--------------------------------------------------------------------------------
/datasets/HRSC_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import matplotlib.pyplot as plt
4 | from utils.bbox_transforms import *
5 | import cv2
6 | from utils.augment import *
7 |
8 |
9 | class HRSCDataset(data.Dataset):
10 | def __init__(self, root_path, set_name, augment=False, classes=None):
11 | self.root_path = root_path
12 | self.set_name = set_name
13 | self.augment = augment
14 | self.image_lists = self._load_image_names()
15 | self.classes = classes
16 | self.num_classes = len(self.classes)
17 | self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))
18 | if self.augment is True:
19 | print(f'[Info]: Using the data augmentation.')
20 | else:
21 | print(f'[Info]: Not using the data augmentation.')
22 |
23 | def __len__(self):
24 | return len(self.image_lists)
25 |
26 | def __getitem__(self, index):
27 | imagename = self.image_lists[index]
28 | img_path = os.path.join(self.root_path, self.set_name, "images", imagename)
29 | image = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
30 | roidb = self._load_annotation(imagename)
31 | gt_inds = np.where(roidb['gt_classes'] != 0)[0]
32 | num_gt = len(roidb['boxes'])
33 | gt_boxes = np.zeros((len(gt_inds), 9), dtype=np.float32) # [x1,y1,x2,y2,x3,y3,x4,y4,class_index]
34 | if num_gt:
35 | # get the bboxes and classes info from the self._load_annotation() result.
36 | bboxes = roidb['boxes'][gt_inds, :]
37 | classes = roidb['gt_classes'][gt_inds] - 1
38 |
39 | # perform the data augmentation
40 | if self.augment is True:
41 | transforms = Augment([
42 | HSV(0.5, 0.5, p=0.5),
43 | HorizontalFlip(p=0.5),
44 | VerticalFlip(p=0.5)
45 | ])
46 | image, bboxes = transforms(image, bboxes)
47 | gt_boxes[:, :-1] = bboxes
48 |
49 | for i, bbox in enumerate(bboxes):
50 | gt_boxes[i, 8] = classes[i]
51 |
52 | return {'image': image, 'boxes': gt_boxes, 'imagename': imagename}
53 |
54 | def _load_image_names(self):
55 | return os.listdir(os.path.join(self.root_path, self.set_name, 'images'))
56 |
57 | def _load_annotation(self, imagename):
58 | filename = os.path.join(self.root_path, self.set_name, "Annotations", imagename.replace('jpg', 'xml'))
59 | boxes, gt_classes = [], []
60 | with open(filename, 'r', encoding='utf-8-sig') as f:
61 | content = f.read()
62 | objects = content.split('<HRSC_Object>')
63 | info = objects.pop(0)
64 | for obj in objects:
65 | cls_id = obj[obj.find('<Class_ID>') + 10: obj.find('</Class_ID>')]
66 | cx = float(eval(obj[obj.find('<mbox_cx>') + 9: obj.find('</mbox_cx>')]))
67 | cy = float(eval(obj[obj.find('<mbox_cy>') + 9: obj.find('</mbox_cy>')]))
68 | w = float(eval(obj[obj.find('<mbox_w>') + 8: obj.find('</mbox_w>')]))
69 | h = float(eval(obj[obj.find('<mbox_h>') + 8: obj.find('</mbox_h>')]))
70 | angle = float(obj[obj.find('<mbox_ang>') + 10: obj.find('</mbox_ang>')]) # radian
71 |
72 | # add extra score parameter to use obb2poly_up
73 | bbox = np.array([[cx, cy, w, h, angle, 0]], dtype=np.float32)
74 | polygon = obb2poly_np(bbox, 'le90')[0, :-1].astype(np.float32)
75 | boxes.append(polygon)
76 | label_index = 1
77 | gt_classes.append(label_index)
78 | return {'boxes': np.array(boxes), 'gt_classes': np.array(gt_classes)}
79 |
80 |
81 | if __name__ == '__main__':
82 | hrsc = HRSCDataset(root_path='/data/fzh/HRSC/',
83 | set_name='train',
84 | augment=True,
85 | classes=['ship', ])
86 | for idx in range(len(hrsc)):
87 | a = hrsc[idx]
88 | bboxes = a['boxes'] # polygon format [x1, y1, x2, y2, x3, y3, x4, y4]
89 | img = a['image']
90 | image_name = a['imagename']
91 | for gt_bbox in bboxes:
92 | ps = gt_bbox[:-1].reshape(1, 4, 2).astype(np.int32)
93 | cv2.drawContours(img, [ps], -1, [0, 255, 0], thickness=2)
94 | plt.imshow(img)
95 | plt.title(image_name)
96 | plt.show()
97 |
--------------------------------------------------------------------------------
/datasets/SSDD_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import matplotlib.pyplot as plt
4 | from utils.bbox_transforms import *
5 | import cv2
6 | import xml.etree.ElementTree as ET
7 | from utils.augment import *
8 |
9 |
10 | class SSDDataset(data.Dataset):
11 | def __init__(self, root_path, set_name, augment=False, classes=None):
12 | self.root_path = root_path
13 | self.set_name = set_name
14 | self.augment = augment
15 | self.image_lists = self._load_image_names()
16 | self.classes = classes
17 | self.num_classes = len(self.classes)
18 | self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))
19 | if self.augment is True:
20 | print(f'[Info]: Using the data augmentation.')
21 | else:
22 | print(f'[Info]: Not using the data augmentation.')
23 |
24 | def __len__(self):
25 | return len(self.image_lists)
26 |
27 | def __getitem__(self, index):
28 | imagename = self.image_lists[index]
29 | img_path = os.path.join(self.root_path, self.set_name, "images", imagename)
30 | image = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
31 | roidb = self._load_annotation(imagename)
32 | gt_inds = np.where(roidb['gt_classes'] != 0)[0]
33 | num_gt = len(roidb['boxes'])
34 | gt_boxes = np.zeros((len(gt_inds), 9), dtype=np.float32) # [x1,y1,x2,y2,x3,y3,x4,y4,class_index]
35 | if num_gt:
36 | # get the bboxes and classes info from the self._load_annotation() result.
37 | bboxes = roidb['boxes'][gt_inds, :]
38 | classes = roidb['gt_classes'][gt_inds] - 1
39 |
40 | # perform the data augmentation
41 | if self.augment is True:
42 | transforms = Augment([
43 | # HSV(0.5, 0.5, p=0.5),
44 | HorizontalFlip(p=0.5),
45 | VerticalFlip(p=0.5)
46 | ])
47 | image, bboxes = transforms(image, bboxes)
48 |
49 | gt_boxes[:, :-1] = bboxes
50 |
51 | for i, bbox in enumerate(bboxes):
52 | gt_boxes[i, 8] = classes[i]
53 |
54 | return {'image': image, 'boxes': gt_boxes, 'imagename': imagename}
55 |
56 | def _load_image_names(self):
57 | return os.listdir(os.path.join(self.root_path, self.set_name, 'images'))
58 |
59 | def _load_annotation(self, imagename):
60 | filename = os.path.join(self.root_path, self.set_name, "Annotations", imagename.replace('jpg', 'xml'))
61 | boxes, gt_classes = [], []
62 | infile = open(os.path.join(filename))
63 | tree = ET.parse(infile)
64 | root = tree.getroot()
65 | for obj in root.iter('object'):
66 | rbox = obj.find('rotated_bndbox')
67 | x1 = float(rbox.find('x1').text)
68 | y1 = float(rbox.find('y1').text)
69 | x2 = float(rbox.find('x2').text)
70 | y2 = float(rbox.find('y2').text)
71 | x3 = float(rbox.find('x3').text)
72 | y3 = float(rbox.find('y3').text)
73 | x4 = float(rbox.find('x4').text)
74 | y4 = float(rbox.find('y4').text)
75 | polygon = np.array([x1, y1, x2, y2, x3, y3, x4, y4], dtype=np.int32)
76 | boxes.append(polygon)
77 | label_index = 1
78 | gt_classes.append(label_index)
79 | return {'boxes': np.array(boxes), 'gt_classes': np.array(gt_classes)}
80 |
81 |
82 | if __name__ == '__main__':
83 | rssdd = SSDDataset(root_path='/data/fzh/RSSDD/',
84 | set_name='train',
85 | augment=True,
86 | classes=['ship', ])
87 | for idx in range(len(rssdd)):
88 | # idx = 0  # debug override: uncomment to always visualize the first sample
89 | a = rssdd[idx]
90 | bboxes = a['boxes']
91 | img = a['image']
92 | for gt_bbox in bboxes:
93 |
94 | ps = gt_bbox[:-1].reshape(-1, 4, 2).astype(np.int32)
95 | cv2.drawContours(img, [ps], -1, [0, 255, 0], thickness=2)
96 |
97 | plt.imshow(img)
98 | plt.show()
--------------------------------------------------------------------------------
/datasets/__pycache__/HRSC_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/HRSC_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/SSDD_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/SSDD_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/collater.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/datasets/__pycache__/collater.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/collater.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.utils import Rescale, Normalize, Reshape
4 | from torchvision.transforms import Compose
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | import math
8 | from utils.bbox_transforms import poly2obb_np
9 |
10 |
11 | class Collater(object):
12 | def __init__(self, scales, keep_ratio=False, multiple=32):
13 | self.scales = scales
14 | self.keep_ratio = keep_ratio
15 | self.multiple = multiple
16 |
17 | def __call__(self, batch):
18 | scales = int(np.floor(float(self.scales) / self.multiple) * self.multiple)
19 | rescale = Rescale(target_size=scales, keep_ratio=self.keep_ratio)
20 | transform = Compose([Normalize(), Reshape(unsqueeze=False)])
21 |
22 | images = [sample['image'] for sample in batch]
23 | bboxes = [sample['boxes'] for sample in batch]
24 | image_names = [sample['imagename'] for sample in batch]
25 |
26 | max_height, max_width = -1, -1
27 |
28 | for index in range(len(batch)):
29 | im, _ = rescale(images[index])
30 | height, width = im.shape[0], im.shape[1]
31 | max_height = height if height > max_height else max_height
32 | max_width = width if width > max_width else max_width
33 |
34 | padded_ims = torch.zeros(len(batch), 3, max_height, max_width)
35 |
36 | # ready to save the openCV format info [xc, yc, w, h, theta, class_index]
37 | num_params = 6
38 | max_num_boxes = max(bbox.shape[0] for bbox in bboxes)
39 | padded_boxes = torch.ones(len(batch), max_num_boxes, num_params) * -1
40 |
41 | for i in range(len(batch)):
42 | im, bbox = images[i], bboxes[i]
43 |
44 | # rescale the image
45 | im, im_scale = rescale(im)
46 | height, width = im.shape[0], im.shape[1]
47 | padded_ims[i, :, :height, :width] = transform(im) # transform is similar to the pipeline in mmdet
48 |
49 | # rescale the bounding box
50 | oc_bboxes = []
51 | labels = []
52 | for single in bbox:
53 |
54 | # rescale the bounding box
55 | single[0::2] *= im_scale[0]
56 | single[1::2] *= im_scale[1]
57 |
58 | # polygons to the opencv format, opencv version > 4.5.1
59 | oc_bbox = poly2obb_np(single[:-1], 'oc') # oc_bbox: [xc, yc, h, w, angle(radian)]
60 | assert 0 < oc_bbox[4] <= np.pi / 2
61 | oc_bboxes.append(np.array(oc_bbox, dtype=np.float32))
62 | labels.append(single[-1])
63 |
64 | if bbox.shape[0] != 0:
65 | padded_boxes[i, :bbox.shape[0], :-1] = torch.from_numpy(np.array(oc_bboxes))
66 | padded_boxes[i, :bbox.shape[0], -1] = torch.from_numpy(np.array(labels))
67 |
68 | # # visualize rescale result
69 | # vis_im = images[i]
70 | # vis_im, _ = rescale(vis_im)
71 | # for gt_bbox in oc_bboxes:
72 | # xc, yc, h, w, ag = gt_bbox[:5]
73 | # print(f'GT Annotation: xc:{xc} yc:{yc} h:{h} w:{w} ag:{ag}')
74 | # wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag)
75 | # hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag)
76 | # p1 = (xc - wx - hx, yc - wy - hy)
77 | # p2 = (xc - wx + hx, yc - wy + hy)
78 | # p3 = (xc + wx + hx, yc + wy + hy)
79 | # p4 = (xc + wx - hx, yc + wy - hy)
80 | # ps = np.int0(np.array([p1, p2, p3, p4]))
81 | # cv2.drawContours(vis_im, [ps], -1, [0, 255, 0], thickness=2)
82 | # plt.imshow(vis_im)
83 | # plt.title(image_names[i])
84 | # plt.show()
85 |
86 | return {'image': padded_ims, 'bboxes': padded_boxes, 'image_name': image_names}
87 |
--------------------------------------------------------------------------------
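A sketch of how the `Collater` above is typically combined with one of the dataset classes in a PyTorch `DataLoader` (the root path, batch size and worker count are placeholders; `scales=512` and `keep_ratio=False` mirror the configs):
```
# Sketch only: root_path, batch_size and num_workers are placeholders.
from torch.utils.data import DataLoader
from datasets.SSDD_dataset import SSDDataset
from datasets.collater import Collater

train_set = SSDDataset(root_path='/path/to/SSDD/', set_name='train', augment=True, classes=['ship'])
collater = Collater(scales=512, keep_ratio=False, multiple=32)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4, collate_fn=collater)

for batch in train_loader:
    images = batch['image']    # (B, 3, H, W) padded image tensor
    bboxes = batch['bboxes']   # (B, max_boxes, 6) OpenCV-format rboxes + class index, padded with -1
    break
```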
/datasets/convert.py:
--------------------------------------------------------------------------------
1 | """This script is used to convert xml format to txt format for HRSC Dataset for evaluation."""
2 |
3 | import os
4 | import numpy as np
5 | import math
6 | import cv2
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | class Convert(object):
11 | def __init__(self, xml_path, txt_path, image_path):
12 | self.xml_path = xml_path
13 | self.txt_path = txt_path
14 | self.image_path = image_path
15 | self.xml_lists = os.listdir(xml_path)
16 | self._makedir()
17 |
18 | def _makedir(self):
19 | if not os.path.exists(self.txt_path):
20 | os.makedirs(self.txt_path)
21 |
22 | def _readXml(self, single_xml):
23 | with open(os.path.join(self.xml_path, single_xml), 'r', encoding='utf-8-sig') as f:
24 | content = f.read()
25 | objects = content.split('')
26 | info = objects.pop(0)
27 |
28 | results = []
29 | for obj in objects:
30 | cls_name = 'ship'
31 | cx = round(eval(obj[obj.find('') + 9: obj.find('')]))
32 | cy = round(eval(obj[obj.find('') + 9: obj.find('')]))
33 | w = round(eval(obj[obj.find('') + 8: obj.find('')]))
34 | h = round(eval(obj[obj.find('') + 8: obj.find('')]))
35 | angle = eval(obj[obj.find('') + 10: obj.find('')]) / math.pi * 180
36 | rbox = np.array([cx, cy, w, h, angle])
37 | quad_box = rbox_2_quad(rbox, 'xywha').squeeze()
38 | line = cls_name + ' ' + str(quad_box[0]) + ' ' + str(quad_box[1]) + ' ' + str(quad_box[2]) + ' ' +\
39 | str(quad_box[3]) + ' ' + str(quad_box[4]) + ' ' + str(quad_box[5]) + ' ' + str(quad_box[6]) +\
40 | ' ' + str(quad_box[7]) + '\n'
41 | results.append(line)
42 | return results
43 |
44 | def writeTxt(self):
45 | for single_xml in self.xml_lists:
46 | lines = self._readXml(single_xml)
47 | txt_file = single_xml.replace('xml', 'txt')
48 | with open(os.path.join(self.txt_path, txt_file), 'w') as f:
49 | for single_line in lines:
50 | f.write(single_line)
51 |
52 | def plotgt(self):
53 | for single_xml in self.xml_lists:
54 | single_image = single_xml.replace('xml', 'jpg')
55 | image = cv2.cvtColor(cv2.imread(os.path.join(self.image_path, single_image), cv2.IMREAD_COLOR),
56 | cv2.COLOR_BGR2RGB)
57 | lines = self._readXml(single_xml)
58 | for single_line in lines:
59 | single_line = single_line.strip().split(' ')
60 | box = np.array(list(map(float, single_line[1:])))
61 | cv2.polylines(image, [box.reshape(-1, 2).astype(np.int32)], True, (255, 0, 0), 3)
62 | plt.imshow(image)
63 | plt.show()
64 |
65 |
66 | if __name__ == '__main__':
67 | convert = Convert(xml_path='/data/fzh/HRSC/train/Annotations/',
68 | txt_path='/data/fzh/HRSC/train/train-ground-truth/',
69 | image_path='/data/fzh/HRSC/train/images/')
70 |
71 | convert.writeTxt()
72 |
--------------------------------------------------------------------------------
/datasets/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | """This script is used to convert .bmp format to .jpg format."""
2 |
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import shutil
6 | import os
7 | import cv2
8 |
9 |
10 | class Convert(object):
11 | def __init__(self, root_path):
12 | self.root_path = root_path
13 | self.convert_image_folder = 'images'
14 | self.image_folder = 'AllImages'
15 | self._mkdir()
16 |
17 | def _mkdir(self):
18 | self.train_image_path = os.path.join(self.root_path, 'train', self.convert_image_folder)
19 | self.val_image_path = os.path.join(self.root_path, 'test', self.convert_image_folder)
20 |
21 | if not os.path.exists(self.train_image_path):
22 | os.makedirs(self.train_image_path)
23 |
24 | if not os.path.exists(self.val_image_path):
25 | os.makedirs(self.val_image_path)
26 |
27 | def convert(self, set_name):
28 | image_lists = os.listdir(os.path.join(self.root_path, set_name, self.image_folder))
29 | for single_image in image_lists:
30 | image = cv2.imread(os.path.join(self.root_path, set_name, self.image_folder, single_image))
31 | converted_single_image = single_image.replace('bmp', 'jpg')
32 | cv2.imwrite(os.path.join(self.root_path, set_name, self.convert_image_folder, converted_single_image),
33 | image)
34 |
35 |
36 | if __name__ == '__main__':
37 | convert = Convert(root_path='/home/fzh/Data/HRSC/')
38 | convert.convert(set_name='test')
39 |
40 |
--------------------------------------------------------------------------------
/datasets/test_collater.py:
--------------------------------------------------------------------------------
1 | from datasets.HRSC_dataset import HRSCDataset
2 | from datasets.SSDD_dataset import SSDDataset
3 | from datasets.collater import Collater
4 |
5 | if __name__ == '__main__':
6 | training_set = SSDDataset(root_path='/data/fzh/RSSDD/',
7 | set_name='train',
8 | augment=True,
9 | classes=['ship'])
10 |
11 | """Check some outputs from custom collater.
12 | 1. Users can specify test_idxs manually.
13 | 2. Users can visualize the rescaled image result by uncommenting the visualization block (lines 68-84) in collater.py"""
14 | test_idxs = [0, 1, 2, 3, 4, 5, 6]
15 | batch = [training_set[idx] for idx in test_idxs]
16 | collater = Collater(scales=512, keep_ratio=False, multiple=32)
17 | result = collater(batch)
18 | print(result)
19 |
--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.transforms import Compose
4 | from utils.utils import Rescale, Normalize, Reshape
5 | from utils.rotation_nms.cpu_nms import cpu_nms
6 | from utils.bbox_transforms import *
7 |
8 |
9 | def im_detect(model, src, target_sizes, params, use_gpu=True, conf=None, device=None):
10 | if isinstance(target_sizes, int):
11 | target_sizes = [target_sizes]
12 | if len(target_sizes) == 1:
13 | return single_scale_detect(model, src, target_size=target_sizes[0], params=params,
14 | use_gpu=use_gpu, conf=conf, device=device)
15 |
16 |
17 | def single_scale_detect(model, src, target_size, params=None,
18 | use_gpu=True, conf=None, device=None):
19 | im, im_scales = Rescale(target_size=target_size, keep_ratio=params.keep_ratio)(src)
20 | im = Compose([Normalize(), Reshape(unsqueeze=True)])(im)
21 | if use_gpu and torch.cuda.is_available():
22 | model, im = model.cuda(device=device), im.cuda(device=device)
23 | with torch.no_grad(): # bboxes: [x, y, x, y, a, a_x, a_y, a_x, a_y, a_a]
24 | scores, classes, boxes = model(im, test_conf=conf)
25 | scores = scores.data.cpu().numpy()
26 | classes = classes.data.cpu().numpy()
27 | boxes = boxes.data.cpu().numpy()
28 |
29 | # convert oc format to polygon for rescale predict box coordinate
30 | predicted_bboxes = []
31 | for idx in range(len(boxes)):
32 | single_box = boxes[idx] # single box: [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian)]
33 | single_box = np.array([[single_box[0], single_box[1], single_box[2], single_box[3], single_box[4], 0]],
34 | dtype=np.float32) # add extra score 0
35 | predicted_polygon = obb2poly_np_oc(single_box)[0, :-1].astype(np.float32)
36 | predicted_polygon[0::2] /= im_scales[0]
37 | predicted_polygon[1::2] /= im_scales[1]
38 | predicted_bbox = poly2obb_np(predicted_polygon, 'oc') # polygon to rbbox (oc format: [xc, yc, h, w, angle(radian)])
39 | predicted_bboxes.append(predicted_bbox)
40 |
41 | if boxes.shape[1] > 5:
42 | # [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian),
43 | # anchor_xc, anchor_yc, anchor_w, anchor_h, anchor_angle(radian)]
44 | boxes[:, 5:9] = boxes[:, 5:9] / im_scales
45 | scores = np.reshape(scores, (-1, 1))
46 | classes = np.reshape(classes, (-1, 1))
47 | for id in range(len(predicted_bboxes)):
48 | boxes[id, :5] = predicted_bboxes[id]
49 | cls_dets = np.concatenate([classes, scores, boxes], axis=1)
50 | keep = np.where(classes < model.num_class)[0]
51 | return cls_dets[keep, :]
52 | # cls, score, x,y,w,h,a, a_x,a_y,a_w,a_h,a_a
53 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from detect import im_detect
3 | import shutil
4 | from tqdm import tqdm
5 | from utils.map import eval_mAP
6 | from utils.bbox_transforms import *
7 |
8 |
9 | # evaluate by rotation detection result
10 | def evaluate(model=None,
11 | target_size=None,
12 | test_path=None,
13 | conf=None,
14 | device=None,
15 | mode=None,
16 | params=None):
17 | evaluate_dir = 'voc_evaluate'
18 | _dir = mode + '_evaluate'
19 | out_dir = os.path.join(params.output_path, evaluate_dir, _dir, 'detection-results')
20 | if os.path.exists(out_dir):
21 | shutil.rmtree(out_dir)
22 | os.makedirs(out_dir)
23 |
24 | # Step1. Collect detect result for per image or get predict result
25 | for image_name in tqdm(os.listdir(os.path.join(params.data_path, mode, 'images'))):
26 | image_path = os.path.join(params.data_path, mode, 'images', image_name)
27 | image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
28 | dets = im_detect(model=model,
29 | src=image,
30 | params=params,
31 | target_sizes=target_size,
32 | use_gpu=True,
33 | conf=conf, # score threshold
34 | device=device)
35 |
36 | # Step2. Write per image detect result into per txt file
37 | # line = cls_name score x1 y1 x2 y2 x3 y3 x4 y4
38 | img_ext = image_name.split('.')[-1]
39 | with open(os.path.join(out_dir, image_name.replace(img_ext, 'txt')), 'w') as f:
40 | for det in dets:
41 | cls_ind = int(det[0])
42 | cls_score = det[1]
43 | rbox = det[2:7] # [xc, yc, h, w, angle(radian)]
44 |
45 | if np.isnan(rbox).any():
46 | line = ''
47 | else:
48 | # add extra score
49 | rbbox = np.array([[rbox[0], rbox[1], rbox[2], rbox[3], rbox[4], 0]], dtype=np.float32)
50 | polygon = obb2poly_np(rbbox, 'oc')[0, :-1].astype(np.float32)
51 | line = str(params.classes[cls_ind]) + ' ' + str(cls_score) + ' ' + str(polygon[0]) + ' ' + str(polygon[1]) +\
52 | ' ' + str(polygon[2]) + ' ' + str(polygon[3]) + ' ' + str(polygon[4]) + ' ' + str(polygon[5]) +\
53 | ' ' + str(polygon[6]) + ' ' + str(polygon[7]) + '\n'
54 | f.write(line)
55 |
56 | # Step3. Calculate Precision, Recall, mAP, plot PR Curve
57 | mAP, Precision, Recall = eval_mAP(gt_root_dir=params.data_path,
58 | test_path=test_path, # test_path = ground-truth
59 | eval_root_dir=os.path.join(params.output_path, evaluate_dir, _dir),
60 | use_07_metric=False,
61 | thres=0.5) # IoU threshold for the VOC mAP calculation
62 | print(f'mAP: {mAP}\tPrecision: {Precision}\tRecall: {Recall}')
63 | return mAP, Precision, Recall
64 |
65 |
66 | if __name__ == '__main__':
67 | import argparse
68 | import torch
69 | from train import Params
70 | import time
71 | from models.model import RetinaNet
72 |
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('--device', type=int, default=0)
75 | parser.add_argument('--Dataset', type=str, default='SSDD')
76 | parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_ssdd.yml')
77 | parser.add_argument('--target_size', type=int, default=512)
78 | parser.add_argument('--chkpt', type=str, default='best/best.pth', help='the checkpoint file of the trained model.')
79 | parser.add_argument('--score_thr', type=float, default=0.05)
80 |
81 | parser.add_argument('--evaluate', type=bool, default=True)
82 | parser.add_argument('--FPS', type=bool, default=False, help='Check the FPS of the Model.') # todo: Ready to Support
83 | args = parser.parse_args()
84 | params = Params(args.config_file)
85 | params.backbone['pretrained'] = False
86 | model = RetinaNet(params)
87 |
88 | checkpoint = os.path.join(params.output_path, 'checkpoints', args.chkpt)
89 |
90 | # from checkpoint load model weight file
91 | # model weight
92 | chkpt = torch.load(checkpoint, map_location='cpu')
93 | pth = chkpt['model']
94 | model.load_state_dict(pth)
95 | model.cuda(device=args.device)
96 |
97 | """The following codes is used to Debug eval() function."""
98 | if args.evaluate:
99 | model.eval()
100 | mAP, Precision, Recall = evaluate(
101 | model=model,
102 | target_size=[args.target_size],
103 | test_path='ground-truth',
104 | conf=args.score_thr, # score threshold
105 | device=args.device,
106 | mode='test',
107 | params=params)
108 | print(f'mAP: {mAP}\nPrecision: {Precision}\nRecall: {Recall}\n')
109 |
110 | """The following codes are used to calculate FPS of model."""
111 | if args.FPS:
112 | times = 50 # 50 is enough to balance some additional times for IO
113 | image_path = os.path.join(params.data_path, args.single_image)
114 | image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
115 | model.eval()
116 | t1 = time.time()
117 | for _ in range(times):
118 | dets = im_detect(model=model,
119 | src=image,
120 | target_sizes=[args.target_size],
121 | use_gpu=True,
122 | conf=0.25,
123 | device=args.device,
124 | params=params)
125 | t2 = time.time()
126 | tact_time = (t2 - t1) / times
127 | print(f'{tact_time} seconds, {1 / tact_time} FPS, Batch_size = 1')
128 |
--------------------------------------------------------------------------------
/models/__pycache__/anchors.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/anchors.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fpn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/fpn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/heads.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/heads.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/anchors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Anchors(nn.Module):
7 | def __init__(self,
8 | params=None,
9 | pyramid_levels=None,
10 | strides=None,
11 | rotations=None):
12 | super(Anchors, self).__init__()
13 | self.pyramid_levels = pyramid_levels
14 | self.strides = strides
15 | self.base_size = params.base_size
16 | self.ratios = params.ratios
17 | self.scales = params.scales
18 | self.rotations = rotations
19 |
20 | if pyramid_levels is None:
21 | self.pyramid_levels = [3, 4, 5, 6, 7]
22 |
23 | if strides is None:
24 | self.strides = [2 ** x for x in self.pyramid_levels]
25 |
26 | self.base_size = params.base_size
27 | self.ratios = params.ratios
28 | self.scales = np.array([2**(i / 3) for i in range(params.scales_per_octave)])
29 | self.rotations = np.array([params.angle / 180 * np.pi])
30 |
31 | self.num_anchors = len(self.scales) * len(self.ratios) * len(self.rotations)
32 |
33 | print(f'[Info]: anchor ratios: {self.ratios}\tanchor scales: {self.scales}\tbase_size: {self.base_size}\t'
34 | f'angle: {self.rotations}')
35 | print(f'[Info]: number of anchors: {self.num_anchors}')
36 |
37 | @staticmethod
38 | def generate_anchors(base_size, ratios, scales, rotations):
39 | """
40 | Generate anchor (reference) windows by enumerating aspect ratios X
41 | scales w.r.t. a reference window.
42 |
43 | anchors: [xc, yc, w, h, angle(radian)]
44 | """
45 | num_anchors = len(ratios) * len(scales) * len(rotations)
46 | # initialize output anchors
47 | anchors = np.zeros((num_anchors, 5))
48 | # scale base_size
49 | anchors[:, 2:4] = base_size * np.tile(scales, (2, len(ratios) * len(rotations))).T
50 | # compute areas of anchors
51 | areas = anchors[:, 2] * anchors[:, 3]
52 | # correct for ratios
53 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales) * len(rotations)))
54 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales) * len(rotations))
55 | # add rotations
56 | anchors[:, 4] = np.tile(np.repeat(rotations, len(scales)), (1, len(ratios))).T[:, 0]
57 | # # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
58 | # anchors[:, 0:3:2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
59 | # anchors[:, 1:4:2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T
60 | return anchors # [x_ctr, y_ctr, w, h, angle(radian)]
61 |
62 | @staticmethod
63 | def shift(shape, stride, anchors):
64 | shift_x = np.arange(0, shape[1]) * stride
65 | shift_y = np.arange(0, shape[0]) * stride
66 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
67 | shifts = np.vstack((
68 | shift_x.ravel(), shift_y.ravel(),
69 | np.zeros(shift_x.ravel().shape), np.zeros(shift_y.ravel().shape),
70 | np.zeros(shift_x.ravel().shape)
71 | )).transpose()
72 | # add A anchors (1, A, 5) to
73 | # cell K shifts (K, 1, 5) to get
74 | # shift anchors (K, A, 5)
75 | # reshape to (K*A, 5) shifted anchors
76 | A = anchors.shape[0]
77 | K = shifts.shape[0]
78 | all_anchors = (anchors.reshape((1, A, 5)) + shifts.reshape((1, K, 5)).transpose((1, 0, 2)))
79 | all_anchors = all_anchors.reshape((K * A, 5))
80 | return all_anchors
81 |
82 | def forward(self, images):
83 | image_shape = np.array(images.shape[2:])
84 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
85 |
86 | # compute anchors over all pyramid levels
87 | all_anchors = np.zeros((0, 5)).astype(np.float32)
88 | num_level_anchors = []
89 | for idx, p in enumerate(self.pyramid_levels):
90 | base_anchors = self.generate_anchors(
91 | base_size=self.base_size * self.strides[idx],
92 | ratios=self.ratios,
93 | scales=self.scales,
94 | rotations=self.rotations)
95 | shifted_anchors = self.shift(image_shapes[idx], self.strides[idx], base_anchors)
96 | num_level_anchors.append(shifted_anchors.shape[0])
97 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
98 | all_anchors = np.expand_dims(all_anchors, axis=0)
99 | all_anchors = np.tile(all_anchors, (images.size(0), 1, 1))
100 | all_anchors = torch.from_numpy(all_anchors.astype(np.float32))
101 | if torch.is_tensor(images) and images.is_cuda:
102 | device = images.device
103 | all_anchors = all_anchors.cuda(device=device)
104 | return all_anchors, torch.from_numpy(np.array(num_level_anchors)).to(all_anchors.device)
105 |
106 |
107 | if __name__ == '__main__':
108 | from train import Params
109 | params = Params('/home/fzh/Pictures/Rotation-RetinaNet-PyTorch/configs/retinanet_r50_fpn_hrsc.yml')
110 | anchors = Anchors(params)
111 | feature_map_sizes = [(128, 128), (64, 64), (32, 32), (16, 16), (8, 8)]
112 | for level_idx in range(5):
113 | # print(f'# ============================base_anchor{level_idx}========================================= #')
114 | base_anchor = anchors.generate_anchors(
115 | base_size=anchors.base_size * anchors.strides[level_idx],
116 | ratios=anchors.ratios,
117 | scales=anchors.scales,
118 | rotations=anchors.rotations
119 | )
120 | # print(base_anchor)
121 | print(f'# ============================shift_anchor{level_idx}========================================= #')
122 | shift_anchor = anchors.shift(feature_map_sizes[level_idx], anchors.strides[level_idx], base_anchor)
123 | print(shift_anchor)
124 |
--------------------------------------------------------------------------------
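A quick self-contained check of `Anchors.generate_anchors` above, using the anchor settings from the configs (5 ratios, 3 scales per octave, a single 0-degree rotation); the stride value 8 corresponds to pyramid level P3:
```
# Sketch: per-location anchors for P3 with the config's settings.
# 5 ratios x 3 scales x 1 rotation = 15 anchors per feature-map cell.
import numpy as np
from models.anchors import Anchors

ratios = np.array([0.2, 0.5, 1.0, 2.0, 5.0])
scales = np.array([2 ** (i / 3) for i in range(3)])   # scales_per_octave = 3
rotations = np.array([0.0])                           # angle = 0

base = Anchors.generate_anchors(base_size=4 * 8, ratios=ratios, scales=scales, rotations=rotations)
print(base.shape)  # (15, 5): each row is [xc, yc, w, h, angle(radian)]
```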
/models/fpn.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 | from utils.utils import kaiming_init, xavier_init
4 |
5 | init_method_list = ['random_init', 'kaiming_init', 'xavier_init', 'normal_init']
6 |
7 |
8 | class FPN(nn.Module):
9 | def __init__(self,
10 | in_channel_list,
11 | out_channels,
12 | top_blocks,
13 | init_method=None):
14 | """
15 | Args:
16 | out_channels(int): number of channels of the FPN feature.
17 | top_blocks(nn.Module or None): if provided, an extra op will be
18 | performed on the FPN output, and the result will extend the result list.
19 | init_method: which method to init lateral_conv and fpn_conv.
20 | kaiming_init: kaiming_init()
21 | xavier_init: xavier_init()
22 | random_init: PyTorch_init()
23 | """
24 | super(FPN, self).__init__()
25 | self.inner_blocks = []
26 | self.layer_blocks = []
27 | self.init_method = init_method
28 | print('[Info]: ===== Neck Using FPN =====')
29 |
30 | assert init_method is not None, f'init_method in class FPN needs to be set.'
31 | assert init_method in init_method_list, f'init_method in class FPN is wrong.'
32 | if init_method == 'kaiming_init':
33 | print('[Info]: Using kaiming_init() to init lateral_conv and fpn_conv.')
34 | if init_method == 'xavier_init':
35 | print('[Info]: Using xavier_init() to init lateral_conv and fpn_conv.')
36 | if init_method == 'random_init':
37 | print('[Info]: Using PyTorch_init() to init lateral_conv and fpn_conv.')
38 |
39 | for idx, in_channels in enumerate(in_channel_list, 1):
40 | inner_block = "fpn_inner{}".format(idx)
41 | layer_block = "fpn_layer{}".format(idx)
42 |
43 | if in_channels == 0:
44 | continue
45 |
46 | # lateral conv 1x1
47 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1) # with bias, without BN Layer
48 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, 1, 1) # with bias, without BN Layer
49 |
50 | if self.init_method == 'kaiming_init':
51 | kaiming_init(inner_block_module, a=0, nonlinearity='relu')
52 | kaiming_init(layer_block_module, a=0, nonlinearity='relu')
53 |
54 | if self.init_method == 'xavier_init':
55 | xavier_init(inner_block_module, gain=1, bias=0, distribution='uniform')
56 | xavier_init(layer_block_module, gain=1, bias=0, distribution='uniform')
57 |
58 | # if self.init_method is 'random_init':
59 | # Don't do anything
60 |
61 | self.add_module(inner_block, inner_block_module)
62 | self.add_module(layer_block, layer_block_module)
63 |
64 | self.inner_blocks.append(inner_block)
65 | self.layer_blocks.append(layer_block)
66 | self.top_blocks = top_blocks
67 |
68 | def forward(self, x):
69 | """
70 | Arguments:
71 | x : feature maps for each feature level.
72 | Returns:
73 | results (tuple[Tensor]): feature maps after FPN layers.
74 | They are ordered from highest resolution first.
75 | """
76 | last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
77 | results = []
78 | results.append(getattr(self, self.layer_blocks[-1])(last_inner))
79 | for feature, inner_block, layer_block in zip(
80 | x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
81 | ):
82 | if not inner_block:
83 | continue
84 | inner_lateral = getattr(self, inner_block)(feature)
85 | inner_top_down = F.interpolate(
86 | last_inner, size=
87 | (int(inner_lateral.shape[-2]), int(inner_lateral.shape[-1])),
88 | mode='nearest')
89 | last_inner = inner_lateral + inner_top_down
90 | results.insert(0, getattr(self, layer_block)(last_inner))
91 |
92 | if isinstance(self.top_blocks, LastLevelP6_P7):
93 | last_results = self.top_blocks(x[-1], results[-1])
94 | results.extend(last_results)
95 | else:
96 | raise NotImplementedError
97 |
98 | return tuple(results)
99 |
100 |
101 | class LastLevelP6_P7(nn.Module):
102 | """This module is used in RetinaNet to generate extra layers, P6 and P7.
103 | Args:
104 | init_method: which method to init P6_conv and P7_conv,
105 | support methods: kaiming_init:kaiming_init,
106 | xavier_init: xavier_init,
107 | random_init: PyTorch_init
108 | """
109 | def __init__(self, in_channels,
110 | out_channels,
111 | init_method=None):
112 | super(LastLevelP6_P7, self).__init__()
113 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) # with bias without BN Layer
114 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) # with bias without BN Layer
115 |
116 | assert init_method is not None, f'init_method in class LastLevelP6_P7 needs to be set.'
117 | assert init_method in init_method_list, f'init_method in class LastLevelP6_P7 is wrong.'
118 |
119 | if init_method == 'kaiming_init':
120 | print('[Info]: Using kaiming_init() to init P6_conv and P7_conv')
121 | for layer in [self.p6, self.p7]:
122 | kaiming_init(layer, a=0, nonlinearity='relu')
123 |
124 | if init_method == 'xavier_init':
125 | print('[Info]: Using xavier_init() to init P6_conv and P7_conv')
126 | for layer in [self.p6, self.p7]:
127 | xavier_init(layer, gain=1, bias=0, distribution='uniform')
128 |
129 | if init_method == 'random_init':
130 | print('[Info]: Using PyTorch_init() to init P6_conv and P7_conv')
131 | # Don't do anything
132 |
133 | self.use_p5 = in_channels == out_channels
134 |
135 | def forward(self, c5, p5):
136 | x = p5 if self.use_p5 else c5
137 | p6 = self.p6(x)
138 | p7 = self.p7(p6)
139 | return [p6, p7]
140 |
--------------------------------------------------------------------------------
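A sketch of wiring the `FPN` above behind a ResNet-50-style backbone; the C3-C5 channel counts (512/1024/2048) are the usual torchvision values and are an assumption here, as are the dummy feature-map sizes:
```
# Sketch: build P3-P7 from ResNet-50-style C3-C5 features.
import torch
from models.fpn import FPN, LastLevelP6_P7

top_blocks = LastLevelP6_P7(in_channels=2048, out_channels=256, init_method='xavier_init')
fpn = FPN(in_channel_list=[512, 1024, 2048], out_channels=256,
          top_blocks=top_blocks, init_method='xavier_init')

c3 = torch.randn(1, 512, 64, 64)
c4 = torch.randn(1, 1024, 32, 32)
c5 = torch.randn(1, 2048, 16, 16)
p3, p4, p5, p6, p7 = fpn([c3, c4, c5])
print(p3.shape, p7.shape)  # torch.Size([1, 256, 64, 64]) torch.Size([1, 256, 4, 4])
```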
/models/heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils.utils import kaiming_init, constant_init, normal_init
4 | import math
5 |
6 | init_method_list = ['random_init', 'kaiming_init', 'xavier_init', 'normal_init']
7 |
8 |
9 | class CLSBranch(nn.Module):
10 | def __init__(self,
11 | in_channels,
12 | feat_channels,
13 | num_stacked,
14 | init_method=None):
15 | super(CLSBranch, self).__init__()
16 |
17 | assert init_method is not None, f'init_method in class CLSBranch needs to be set.'
18 | assert init_method in init_method_list, f'init_method in class CLSBranch is wrong.'
19 |
20 | self.convs = nn.ModuleList()
21 | for i in range(num_stacked):
22 | chns = in_channels if i == 0 else feat_channels
23 | # : Conv(wo bias) + BN + Relu()
24 | # self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=False)) # conv_weight -> bn -> relu()
25 | # self.convs.append(nn.BatchNorm2d(feat_channels, affine=True)) # add BN layer
26 | # self.convs.append(nn.ReLU(inplace=True))
27 | # self.init_weights()
28 |
29 | # : Conv(bias) + Relu() and using kaiming_init_weight / mmdet_init_weight
30 | self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=True)) # conv with bias -> relu()
31 | self.convs.append(nn.ReLU(inplace=True))
32 |
33 | if init_method == 'kaiming_init':
34 | self.kaiming_init_weights()
35 | if init_method == 'normal_init':
36 | self.mmdet_init_weights()
37 |
38 | def mmdet_init_weights(self):
39 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Cls Branch.')
40 | for m in self.modules():
41 | if isinstance(m, nn.Conv2d):
42 | normal_init(m, mean=0, std=0.01, bias=0)
43 | elif isinstance(m, nn.BatchNorm2d):
44 | constant_init(m, 1, bias=0)
45 |
46 | def kaiming_init_weights(self):
47 | print('[Info]: Using kaiming_init_weights() to init Cls Branch.')
48 | for m in self.modules():
49 | if isinstance(m, nn.Conv2d):
50 | kaiming_init(m, a=0, nonlinearity='relu')
51 | elif isinstance(m, nn.BatchNorm2d):
52 | constant_init(m, 1, bias=0)
53 |
54 | def forward(self, x):
55 | for conv in self.convs:
56 | x = conv(x)
57 | return x
58 |
59 |
60 | class CLSHead(nn.Module):
61 | def __init__(self,
62 | feat_channels,
63 | num_anchors,
64 | num_classes):
65 | super(CLSHead, self).__init__()
66 | self.num_anchors = num_anchors
67 | self.num_classes = num_classes
68 | self.feat_channels = feat_channels
69 | self.head = nn.Conv2d(self.feat_channels, self.num_anchors * self.num_classes, 3, 1, 1) # with bias
70 | self.head_init_weights()
71 |
72 | def head_init_weights(self):
73 | print('[Info]: Using RetinaNet Paper Init Method to init Cls Head.')
74 | prior = 0.01
75 | self.head.weight.data.fill_(0)
76 | self.head.bias.data.fill_(-math.log((1.0 - prior) / prior))
77 |
78 | def forward(self, x):
79 | x = torch.sigmoid(self.head(x))
80 | x = x.permute(0, 2, 3, 1)
81 | n, h, w, c = x.shape
82 | x = x.reshape(n, h, w, self.num_anchors, self.num_classes)
83 | return x.reshape(x.shape[0], -1, self.num_classes)
84 |
85 |
86 | class REGBranch(nn.Module):
87 | def __init__(self,
88 | in_channels,
89 | feat_channels,
90 | num_stacked,
91 | init_method=None):
92 | super(REGBranch, self).__init__()
93 |
94 | assert init_method is not None, f'init_method in class RegBranch needs to be set.'
95 | assert init_method in init_method_list, f'init_method in class RegBranch is wrong.'
96 |
97 | self.convs = nn.ModuleList()
98 |
99 | for i in range(num_stacked):
100 | chns = in_channels if i == 0 else feat_channels
101 |
102 | # : Conv(wo bias) + BN + Relu()
103 | # self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=False)) # conv_weight -> bn -> relu()
104 | # self.convs.append(nn.BatchNorm2d(feat_channels, affine=True))
105 | # self.convs.append(nn.ReLU(inplace=True))
106 | # self.init_weights()
107 |
108 | # : Conv(bias) + Relu() and using kaiming_init_weight / mmdet_init_weight
109 | self.convs.append(nn.Conv2d(chns, feat_channels, 3, 1, 1, bias=True)) # conv with bias -> relu()
110 | self.convs.append(nn.ReLU(inplace=True))
111 | if init_method == 'kaiming_init':
112 | self.kaiming_init_weights()
113 | if init_method == 'normal_init':
114 | self.mmdet_init_weights()
115 |
116 | def mmdet_init_weights(self):
117 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Reg Branch.')
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | normal_init(m, mean=0, std=0.01, bias=0)
121 | elif isinstance(m, nn.BatchNorm2d):
122 | constant_init(m, 1, bias=0)
123 |
124 | def kaiming_init_weights(self):
125 | print('[Info]: Using kaiming_init_weights() to init Reg Branch.')
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | kaiming_init(m, a=0, nonlinearity='relu')
129 | elif isinstance(m, nn.BatchNorm2d):
130 | constant_init(m, 1, bias=0)
131 |
132 | def forward(self, x):
133 | for conv in self.convs:
134 | x = conv(x)
135 | return x
136 |
137 |
138 | class REGHead(nn.Module):
139 | def __init__(self,
140 | feat_channels,
141 | num_anchors,
142 | num_regress):
143 | super(REGHead, self).__init__()
144 | self.num_anchors = num_anchors
145 | self.num_regress = num_regress
146 | self.feat_channels = feat_channels
147 | self.head = nn.Conv2d(self.feat_channels, self.num_anchors * self.num_regress, 3, 1, 1) # with bias
148 | self.mmdet_init_weights()
149 |
150 | def mmdet_init_weights(self):
151 | print('[Info]: Using mmdet_init_weights() {normal_init} to init Reg Head.')
152 | normal_init(self.head, mean=0, std=0.01, bias=0)
153 |
154 | def forward(self, x, with_deform=False):
155 | x = self.head(x)
156 | if with_deform is False:
157 | x = x.permute(0, 2, 3, 1)
158 | return x.reshape(x.shape[0], -1, self.num_regress)
159 | else:
160 | return x
161 |
--------------------------------------------------------------------------------
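A sketch of the classification tower above for a single FPN level; `feat_channels=256` and `num_anchors=15` (5 ratios x 3 scales) mirror the configs, while the batch size and feature-map size are placeholders:
```
# Sketch: classification branch + head for one pyramid level.
import torch
from models.heads import CLSBranch, CLSHead

branch = CLSBranch(in_channels=256, feat_channels=256, num_stacked=4, init_method='normal_init')
head = CLSHead(feat_channels=256, num_anchors=15, num_classes=1)

p3 = torch.randn(2, 256, 64, 64)   # one pyramid level, batch of 2
scores = head(branch(p3))
print(scores.shape)                # torch.Size([2, 61440, 1]) = 64*64*15 anchors per image
```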
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from utils.utils import bbox_overlaps
3 | from utils.bbox_transforms import *
4 | from utils.box_coder import BoxCoder
5 | from utils.rotation_overlaps.rbox_overlaps import rbox_overlaps
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | class IntegratedLoss(nn.Module):
10 | def __init__(self, params):
11 | super(IntegratedLoss, self).__init__()
12 | loss_dict = params.loss
13 | self.alpha = loss_dict['cls']['alpha']
14 | self.gamma = loss_dict['cls']['gamma']
15 | func = loss_dict['reg']['type']
16 |
17 | assign_dict = params.assigner
18 | self.pos_iou_thr = assign_dict['pos_iou_thr']
19 | self.neg_iou_thr = assign_dict['neg_iou_thr']
20 | self.min_pos_iou = assign_dict['min_pos_iou']
21 | self.low_quality_match = assign_dict['low_quality_match']
22 |
23 | self.box_coder = BoxCoder()
24 |
25 | if func == 'smooth':
26 |             self.criterion = smooth_l1_loss
27 | print(f'[Info]: Using {func} Loss.')
28 |
29 | def forward(self, classifications, regressions, anchors, annotations, image_names):
30 | cls_losses = []
31 | reg_losses = []
32 | batch_size = classifications.shape[0]
33 | device = classifications[0].device
34 | for j in range(batch_size):
35 | image_name = image_names[j]
36 | anchor = anchors[j] # [xc, yc, w, h, angle(radian)]
37 | classification = classifications[j, :, :]
38 | regression = regressions[j, :, :] # [xc_offset, yc_offset, h_offset, w_offset, angle_offset]
39 | bbox_annotation = annotations[j, :, :] # [xc, yc, h, w, angle(radian)]
40 | bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
41 | num_gt = len(bbox_annotation)
42 | if bbox_annotation.shape[0] == 0:
43 | cls_losses.append(torch.tensor(0).float().cuda(device=device))
44 | reg_losses.append(torch.tensor(0).float().cuda(device=device))
45 | continue
46 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
47 |
48 | # get minimum circumscribed rectangle of the rotated ground-truth box and
49 | # calculate the horizontal overlaps between minimum circumscribed rectangles and anchor boxes
50 |
51 | horizontal_overlaps = bbox_overlaps(
52 | anchor.clone(), # generate anchor data copy
53 | obb2hbb_oc(bbox_annotation[:, :-1]))
54 |
55 | # obb_rect = [xc, yc, h, w, angle(radian)]
56 | ious = rbox_overlaps(
57 | swap_axis(anchor[:, :]).cpu().numpy(),
58 | bbox_annotation[:, :-1].cpu().numpy(),
59 | horizontal_overlaps.cpu().numpy(),
60 | thresh=1e-1
61 | )
62 |
63 | if not torch.is_tensor(ious):
64 | ious = torch.from_numpy(ious).cuda(device=device)
65 |
66 | iou_max, iou_argmax = torch.max(ious, dim=1)
67 |
68 | positive_indices = torch.ge(iou_max, self.pos_iou_thr)
69 |
70 | if self.low_quality_match is True:
71 | max_gt, argmax_gt = ious.max(dim=0)
72 | for idx in range(num_gt):
73 | if max_gt[idx] >= self.min_pos_iou:
74 | positive_indices[argmax_gt[idx]] = 1
75 |
76 | # calculate classification loss
77 | cls_targets = (torch.ones(classification.shape) * -1).cuda(device=device)
78 | cls_targets[torch.lt(iou_max, self.neg_iou_thr), :] = 0
79 | num_positive_anchors = positive_indices.sum()
80 | assigned_annotations = bbox_annotation[iou_argmax, :]
81 | cls_targets[positive_indices, :] = 0
82 | cls_targets[positive_indices, assigned_annotations[positive_indices, 5].long()] = 1
83 | alpha_factor = torch.ones(cls_targets.shape).cuda(device=device) * self.alpha
84 | alpha_factor = torch.where(torch.eq(cls_targets, 1.), alpha_factor, 1. - alpha_factor)
85 | focal_weight = torch.where(torch.eq(cls_targets, 1.), 1. - classification, classification)
86 | focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
87 | # bin_cross_entropy = -(cls_targets * torch.log(classification + 1e-6) + (1.0 - cls_targets) * torch.log(
88 | # 1.0 - classification + 1e-6))
89 | bin_cross_entropy = -(cls_targets * torch.log(classification) + (1.0 - cls_targets) * torch.log(
90 | 1.0 - classification))
91 | cls_loss = focal_weight * bin_cross_entropy
92 | cls_loss = torch.where(torch.ne(cls_targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda(device=device))
93 | cls_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))
94 |
95 | # calculate regression loss
96 | if positive_indices.sum() > 0:
97 | all_rois = anchor[positive_indices, :]
98 | gt_boxes = assigned_annotations[positive_indices, :]
99 | reg_targets = self.box_coder.encode(all_rois, gt_boxes)
100 |                 reg_loss = self.criterion(regression[positive_indices, :], reg_targets)
101 | reg_losses.append(reg_loss)
102 | else:
103 | reg_losses.append(torch.tensor(0).float().cuda(device=device))
104 | loss_cls = torch.stack(cls_losses).mean(dim=0, keepdim=True)
105 | loss_reg = torch.stack(reg_losses).mean(dim=0, keepdim=True)
106 | return loss_cls, loss_reg
107 |
108 |
109 | def smooth_l1_loss(inputs,
110 | targets,
111 | beta=1. / 9,
112 | size_average=True,
113 | weight=None):
114 | """https://github.com/facebookresearch/maskrcnn-benchmark"""
115 | diff = torch.abs(inputs - targets)
116 | if weight is None:
117 | loss = torch.where(
118 | diff < beta,
119 | 0.5 * diff ** 2 / beta,
120 | diff - 0.5 * beta
121 | )
122 | else:
123 | loss = torch.where(
124 | diff < beta,
125 | 0.5 * diff ** 2 / beta,
126 | diff - 0.5 * beta
127 | ) * weight.max(1)[0].unsqueeze(1).repeat(1,5)
128 | if size_average:
129 | return loss.mean()
130 | return loss.sum()
131 |
--------------------------------------------------------------------------------
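The classification term above is a focal-weighted binary cross-entropy and the regression term is smooth L1. The sketch below reproduces both on hand-picked toy tensors; `alpha = 0.25` and `gamma = 2.0` are illustrative values and not necessarily the config defaults.

```python
import torch


def smooth_l1(inputs, targets, beta=1. / 9):
    """Same piecewise form as smooth_l1_loss() above, mean-reduced."""
    diff = torch.abs(inputs - targets)
    return torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta).mean()


# toy scores for 3 anchors x 1 class and their 0/1 targets
cls_pred = torch.tensor([[0.9], [0.2], [0.6]]).clamp(1e-4, 1 - 1e-4)
cls_tgt = torch.tensor([[1.], [0.], [0.]])
alpha, gamma = 0.25, 2.0

alpha_factor = torch.where(cls_tgt == 1.,
                           torch.full_like(cls_tgt, alpha),
                           torch.full_like(cls_tgt, 1. - alpha))
focal_weight = alpha_factor * torch.where(cls_tgt == 1., 1. - cls_pred, cls_pred) ** gamma
bce = -(cls_tgt * torch.log(cls_pred) + (1. - cls_tgt) * torch.log(1. - cls_pred))
print('focal loss per anchor:', (focal_weight * bce).squeeze(1))

# toy offsets for 2 positive anchors x (dx, dy, dw, dh, dtheta)
reg_pred = torch.zeros(2, 5)
reg_tgt = torch.tensor([[0.10, -0.20, 0.00, 0.30, 0.05],
                        [0.05, 0.10, -0.10, 0.00, 0.00]])
print('smooth L1 loss:', smooth_l1(reg_pred, reg_tgt))
```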
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.anchors import Anchors
4 | from models.fpn import FPN, LastLevelP6_P7
5 | from models import resnet
6 | from models.heads import CLSBranch, REGBranch, CLSHead, REGHead
7 | from models.losses import IntegratedLoss
8 | from utils.utils import clip_boxes
9 | from utils.box_coder import BoxCoder
10 | from utils.rotation_nms.cpu_nms import cpu_nms
11 | import math
12 | import cv2
13 | import numpy as np
14 |
15 |
16 | class RetinaNet(nn.Module):
17 | def __init__(self, params):
18 | super(RetinaNet, self).__init__()
19 | self.num_class = len(params.classes)
20 | self.num_regress = 5
21 | self.anchor_generator = Anchors(params)
22 | self.num_anchors = self.anchor_generator.num_anchors
23 | self.pretrained = params.backbone['pretrained']
24 | self.init_backbone(params.backbone['type'])
25 | self.cls_branch_num_stacked = params.head['num_stacked']
26 | self.rotation_nms_thr = params.rotation_nms_thr
27 | self.score_thr = params.score_thr
28 |
29 | self.fpn = FPN(
30 | in_channel_list=self.fpn_in_channels,
31 | out_channels=256,
32 | top_blocks=LastLevelP6_P7(in_channels=256,
33 | out_channels=256,
34 | init_method=params.neck['extra_conv_init_method']), # in_channels: 1) 2048 on C5, 2) 256 on P5
35 | init_method=params.neck['init_method'])
36 |
37 | self.cls_branch = CLSBranch(
38 | in_channels=256,
39 | feat_channels=256,
40 | num_stacked=self.cls_branch_num_stacked,
41 | init_method=params.head['cls_branch_init_method']
42 | )
43 |
44 | self.cls_head = CLSHead(
45 | feat_channels=256,
46 | num_anchors=self.num_anchors,
47 | num_classes=self.num_class
48 | )
49 |
50 | self.reg_branch = REGBranch(
51 | in_channels=256,
52 | feat_channels=256,
53 | num_stacked=self.cls_branch_num_stacked,
54 | init_method=params.head['reg_branch_init_method']
55 | )
56 |
57 | self.reg_head = REGHead(
58 | feat_channels=256,
59 | num_anchors=self.num_anchors,
60 | num_regress=self.num_regress # x, y, w, h, angle
61 | )
62 |
63 | self.loss = IntegratedLoss(params)
64 |
65 | self.box_coder = BoxCoder()
66 |
67 | def init_backbone(self, backbone):
68 | if backbone == 'resnet34':
69 | print(f'[Info]: Use Backbone is {backbone}.')
70 | self.backbone = resnet.resnet34(pretrained=self.pretrained)
71 | self.fpn_in_channels = [128, 256, 512]
72 |
73 | elif backbone == 'resnet50':
74 | print(f'[Info]: Use Backbone is {backbone}.')
75 | self.backbone = resnet.resnet50(pretrained=self.pretrained)
76 | self.fpn_in_channels = [512, 1024, 2048]
77 |
78 | elif backbone == 'resnet101':
79 | print(f'[Info]: Use Backbone is {backbone}.')
80 | self.backbone = resnet.resnet101(pretrained=self.pretrained)
81 | self.fpn_in_channels = [512, 1024, 2048]
82 |
83 | elif backbone == 'resnet152':
84 | print(f'[Info]: Use Backbone is {backbone}.')
85 |             self.backbone = resnet.resnet152(pretrained=self.pretrained)
86 | self.fpn_in_channels = [512, 1024, 2048]
87 | else:
88 | raise NotImplementedError
89 |
90 | del self.backbone.avgpool
91 | del self.backbone.fc
92 |
93 | def backbone_output(self, imgs):
94 | feature = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(imgs)))
95 | c2 = self.backbone.layer1(self.backbone.maxpool(feature))
96 | c3 = self.backbone.layer2(c2)
97 | c4 = self.backbone.layer3(c3)
98 | c5 = self.backbone.layer4(c4)
99 | return [c3, c4, c5]
100 |
101 | def forward(self, images, annots=None, image_names=None, test_conf=None):
102 | anchors_list, offsets_list = [], []
103 | original_anchors, num_level_anchors = self.anchor_generator(images)
104 | anchors_list.append(original_anchors)
105 |
106 | features = self.fpn(self.backbone_output(images))
107 |
108 | cls_score = torch.cat([self.cls_head(self.cls_branch(feature)) for feature in features], dim=1)
109 | bbox_pred = torch.cat([self.reg_head(self.reg_branch(feature), with_deform=False)
110 | for feature in features], dim=1)
111 |
112 | # get the predicted bboxes
113 | # predicted_boxes = torch.cat(
114 | # [self.box_coder.decode(anchors_list[-1][index], bbox_pred[index]).unsqueeze(0)
115 | # for index in range(len(bbox_pred))], dim=0).detach()
116 |
117 | if self.training:
118 | # Max IoU Assigner with Focal Loss and Smooth L1 loss
119 | loss_cls, loss_reg = self.loss(cls_score, # cls_score with all levels
120 | bbox_pred, # bbox_pred with all levels
121 | anchors_list[-1],
122 | annots,
123 | image_names)
124 |
125 | return loss_cls, loss_reg
126 |
127 | else: # for model eval()
128 | return self.decoder(images, anchors_list[-1], cls_score, bbox_pred,
129 | thresh=self.score_thr, nms_thresh=self.rotation_nms_thr, test_conf=test_conf)
130 |
131 | def decoder(self, ims, anchors, cls_score, bbox_pred,
132 | thresh=0.6, nms_thresh=0.1, test_conf=None):
133 | """
134 | Args:
135 | thresh: equal to score_thr.
136 | nms_thresh: nms_thr.
137 | test_conf: equal to thresh.
138 | """
139 | if test_conf is not None:
140 | thresh = test_conf
141 | bboxes = self.box_coder.decode(anchors, bbox_pred) # bboxes: [pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian)]
142 | # bboxes = clip_boxes(bboxes, ims)
143 | scores = torch.max(cls_score, dim=2, keepdim=True)[0]
144 | keep = (scores >= thresh)[0, :, 0]
145 | if keep.sum() == 0:
146 | return [torch.zeros(1), torch.zeros(1), torch.zeros(1, 5)]
147 | scores = scores[:, keep, :]
148 | anchors = anchors[:, keep, :]
149 | cls_score = cls_score[:, keep, :]
150 | bboxes = bboxes[:, keep, :]
151 |
152 | # NMS
153 | anchors_nms_idx = cpu_nms(torch.cat([bboxes, scores], dim=2)[0, :, :].cpu().detach().numpy(), nms_thresh)
154 | nms_scores, nms_class = cls_score[0, anchors_nms_idx, :].max(dim=1)
155 | output_boxes = torch.cat([
156 | bboxes[0, anchors_nms_idx, :],
157 | anchors[0, anchors_nms_idx, :]],
158 | dim=1
159 | )
160 | return [nms_scores, nms_class, output_boxes]
161 |
162 | def freeze_bn(self):
163 | """Set BN.eval(), BN is in the model's Backbone. """
164 | for layer in self.backbone.modules():
165 | if isinstance(layer, nn.BatchNorm2d):
166 |                 # eval() keeps bn.running_mean and bn.running_var fixed during the training phase.
167 | layer.eval()
168 |
169 | # freeze the bn.weight and bn.bias which are two learnable params in BN Layer.
170 | # layer.weight.requires_grad = False
171 | # layer.bias.requires_grad = False
172 |
--------------------------------------------------------------------------------
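The `decoder()` method thresholds the per-anchor class maxima before handing the survivors to the compiled rotated NMS. The simplified stand-in below mirrors only the thresholding and gathering steps with plain `torch` ops; the tensor sizes and the score threshold of 0.5 are assumptions, and the rotated NMS itself lives in `utils/rotation_nms/cpu_nms.pyx` and is not reproduced here.

```python
import torch

num_anchors, num_classes = 8, 2                        # toy sizes
cls_score = torch.rand(1, num_anchors, num_classes)    # sigmoid-style class scores
bboxes = torch.rand(1, num_anchors, 5)                 # decoded [xc, yc, h, w, angle]

scores = torch.max(cls_score, dim=2, keepdim=True)[0]  # best class score per anchor
keep = (scores >= 0.5)[0, :, 0]                        # assumed score_thr = 0.5

if keep.sum() == 0:
    print('no anchors above the threshold')            # decoder() returns zero tensors here
else:
    nms_scores, nms_class = cls_score[0, keep, :].max(dim=1)
    kept_boxes = bboxes[0, keep, :]
    print(nms_scores.shape, nms_class.shape, kept_boxes.shape)
```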
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | import torch
5 | import os
6 |
7 | model_urls = {
8 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
9 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
10 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
11 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
12 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
13 | }
14 |
15 |
16 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=dilation, groups=groups, bias=False, dilation=dilation)
19 |
20 |
21 | def conv1x1(in_planes, out_planes, stride=1):
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
29 | base_width=64, dilation=1, norm_layer=None):
30 | super(BasicBlock, self).__init__()
31 | if norm_layer is None:
32 | norm_layer = nn.BatchNorm2d
33 | if groups != 1 or base_width != 64:
34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35 | if dilation > 1:
36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = norm_layer(planes)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = norm_layer(planes)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | identity = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | identity = self.downsample(x)
57 |
58 | out += identity
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
68 | base_width=64, dilation=1, norm_layer=None):
69 | super(Bottleneck, self).__init__()
70 | if norm_layer is None:
71 | norm_layer = nn.BatchNorm2d
72 | width = int(planes * (base_width / 64.)) * groups
73 |
74 | self.conv1 = conv1x1(inplanes, width)
75 | self.bn1 = norm_layer(width)
76 |
77 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
78 | self.bn2 = norm_layer(width)
79 |
80 | self.conv3 = conv1x1(width, planes * self.expansion)
81 | self.bn3 = norm_layer(planes * self.expansion)
82 |
83 | self.relu = nn.ReLU(inplace=True)
84 | self.downsample = downsample
85 | self.stride = stride
86 |
87 | def forward(self, x):
88 | identity = x
89 |
90 | out = self.conv1(x)
91 | out = self.bn1(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv2(out)
95 | out = self.bn2(out)
96 | out = self.relu(out)
97 |
98 | out = self.conv3(out)
99 | out = self.bn3(out)
100 |
101 | if self.downsample is not None:
102 | identity = self.downsample(x)
103 |
104 | out += identity
105 | out = self.relu(out)
106 |
107 | return out
108 |
109 |
110 | class ResNet(nn.Module):
111 | def __init__(self, block, layers, num_classes=1000):
112 |
113 | self.inplanes = 64
114 | super(ResNet, self).__init__()
115 |
116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
117 | self.bn1 = nn.BatchNorm2d(64)
118 | self.relu = nn.ReLU(inplace=True)
119 |
120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
121 |
122 | self.layer1 = self._make_layer(block, 64, layers[0])
123 |
124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
125 |
126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
127 |
128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
129 |
130 | self.avgpool = nn.AvgPool2d(7)
131 | self.fc = nn.Linear(512 * block.expansion, num_classes)
132 |
133 | for m in self.modules():
134 | if isinstance(m, nn.Conv2d):
135 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
136 | m.weight.data.normal_(0, math.sqrt(2. / n))
137 | elif isinstance(m, nn.BatchNorm2d):
138 | m.weight.data.fill_(1)
139 | m.bias.data.zero_()
140 |
141 | def _make_layer(self, block, planes, blocks, stride=1):
142 | downsample = None
143 | if stride != 1 or self.inplanes != planes * block.expansion:
144 | downsample = nn.Sequential(
145 | nn.Conv2d(self.inplanes, planes * block.expansion,
146 | kernel_size=1, stride=stride, bias=False),
147 | nn.BatchNorm2d(planes * block.expansion),
148 | )
149 |
150 | layers = []
151 | layers.append(block(self.inplanes, planes, stride, downsample))
152 | self.inplanes = planes * block.expansion
153 | for i in range(1, blocks):
154 | layers.append(block(self.inplanes, planes))
155 |
156 | return nn.Sequential(*layers)
157 |
158 | def forward(self, x):
159 | x = self.conv1(x)
160 | x = self.bn1(x)
161 | x = self.relu(x)
162 | x = self.maxpool(x)
163 |
164 | x = self.layer1(x)
165 | x = self.layer2(x)
166 | x = self.layer3(x)
167 | x = self.layer4(x)
168 |
169 | x = self.avgpool(x)
170 | x = x.view(x.size(0), -1)
171 | x = self.fc(x)
172 |
173 | return x
174 |
175 |
176 | def resnet18(pretrained=False, **kwargs):
177 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
178 | if pretrained:
179 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir='model_data'), strict=False)
180 | return model
181 |
182 |
183 | def resnet34(pretrained=False, **kwargs):
184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
185 | if pretrained:
186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir='model_data'), strict=False)
187 | return model
188 |
189 |
190 | def resnet50(pretrained=False, **kwargs):
191 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
192 | if pretrained:
193 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
194 | weight_path = dir_path + '/resnet_pretrained_pth/resnet50-0676ba61.pth'
195 | if os.path.exists(weight_path):
196 | model.load_state_dict(torch.load(weight_path), strict=False)
197 | else:
198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='model_data'), strict=False)
199 | return model
200 |
201 |
202 | def resnet101(pretrained=False, **kwargs):
203 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
204 | if pretrained:
205 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
206 | weight_path = dir_path + '/resnet_pretrained_pth/resnet101-5d3b4d8f.pth'
207 | if os.path.exists(weight_path):
208 | model.load_state_dict(torch.load(weight_path), strict=False)
209 | else:
210 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir='model_data'), strict=False)
211 | return model
212 |
213 |
214 | def resnet152(pretrained=False, **kwargs):
215 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
216 | if pretrained:
217 | dir_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
218 | weight_path = dir_path + '/resnet_pretrained_pth/resnet152-b121ed2d.pth'
219 | if os.path.exists(weight_path):
220 | model.load_state_dict(torch.load(weight_path), strict=False)
221 | else:
222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir='model_data'), strict=False)
223 | return model
224 |
--------------------------------------------------------------------------------
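For orientation, the sketch below builds the plain ResNet-50 from this file and traces the same layer taps that `RetinaNet.backbone_output()` uses (C3 from `layer2`, C4 from `layer3`, C5 from `layer4`). The 512x512 input size is an assumption for illustration.

```python
import torch
from models.resnet import resnet50

net = resnet50(pretrained=False)
x = torch.randn(1, 3, 512, 512)                        # assumed input size
x = net.relu(net.bn1(net.conv1(x)))
c2 = net.layer1(net.maxpool(x))
c3 = net.layer2(c2)                                    # stride 8
c4 = net.layer3(c3)                                    # stride 16
c5 = net.layer4(c4)                                    # stride 32
print(c3.shape, c4.shape, c5.shape)
# expected: (1, 512, 64, 64) (1, 1024, 32, 32) (1, 2048, 16, 16)
```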
/requirements.txt:
--------------------------------------------------------------------------------
1 | colorlog==6.6.0
2 | Cython==0.29.28
3 | matplotlib==3.5.1
4 | numpy==1.21.2
5 | opencv_python==4.5.5.64
6 | Pillow==9.0.1
7 | PyYAML==6.0
8 | setuptools==58.0.4
9 | Shapely==1.8.1.post1
10 | tensorboardX==2.5
11 | torch==1.7.0
12 | torchvision==0.8.0
13 | tqdm==4.63.0
14 |
--------------------------------------------------------------------------------
/resnet_pretrained_pth/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 |
--------------------------------------------------------------------------------
/resnet_pretrained_pth/README.md:
--------------------------------------------------------------------------------
1 | ### Put the pretrained resnet-50/101/152 weight file here.
2 |
--------------------------------------------------------------------------------
/resource/HRSC_Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/resource/HRSC_Result.png
--------------------------------------------------------------------------------
/resource/RSSDD_Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/resource/RSSDD_Result.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import platform
2 | from setuptools import Extension, setup
3 | import os
4 | import numpy as np
5 | from Cython.Build import cythonize
6 | from torch.utils.cpp_extension import BuildExtension
7 |
8 |
9 | def make_cython_ext(name, module, sources):
10 | extra_compile_args = None
11 | if platform.system() != 'Windows':
12 | extra_compile_args = {
13 | 'cxx': ['-Wno-unused-function', '-Wno-write-strings']
14 | }
15 |
16 | extension = Extension(
17 | '{}.{}'.format(module, name),
18 | [os.path.join(*module.split('.'), p) for p in sources],
19 | include_dirs=[np.get_include()],
20 | language='c++',
21 | extra_compile_args=extra_compile_args)
22 | extension, = cythonize(extension)
23 | return extension
24 |
25 |
26 | if __name__ == '__main__':
27 | setup(
28 | name='extension',
29 | ext_modules=[
30 | make_cython_ext(
31 | name='rbox_overlaps',
32 | module='utils.rotation_overlaps',
33 | sources=['rbox_overlaps.pyx']),
34 |
35 | make_cython_ext(
36 | name='cpu_nms',
37 | module='utils.rotation_nms',
38 | sources=['cpu_nms.pyx']),
39 | ],
40 | cmdclass={'build_ext': BuildExtension},
41 | )
--------------------------------------------------------------------------------
/show.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from models.model import RetinaNet
3 | import os
4 | import cv2
5 | import torch
6 | from detect import im_detect
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import math
10 |
11 |
12 | def get_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--backbone', type=str, default='resnet50')
15 | parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_ssdd.yml')
16 | parser.add_argument('--target_sizes', type=list, default=[512], help='the size of the input image.')
17 | parser.add_argument('--chkpt', type=str, default='best/best.pth', help='the chkpt file name')
18 |     parser.add_argument('--result_path', type=str, default='show_result', help='the relative path for saving '
19 |                                                                                 'the original and predicted images')
20 | parser.add_argument('--score_thresh', type=float, default=0.05, help='score threshold')
21 | parser.add_argument('--pic_name', type=str, default='demo6.jpg', help='relative path')
22 | parser.add_argument('--device', type=int, default=1)
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def plot_box(image, coord, label_index=None, score=None, color=None, line_thickness=None):
28 | bbox_color = [226, 43, 138] if color is None else color
29 | text_color = [255, 255, 255]
30 | line_thickness = 1 if line_thickness is None else line_thickness
31 | xc, yc, h, w, ag = coord[:5]
32 | wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag)
33 | hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag)
34 | p1 = (xc - wx - hx, yc - wy - hy)
35 | p2 = (xc - wx + hx, yc - wy + hy)
36 | p3 = (xc + wx + hx, yc + wy + hy)
37 | p4 = (xc + wx - hx, yc + wy - hy)
38 | ps = np.int0(np.array([p1, p2, p3, p4]))
39 | cv2.drawContours(image, [ps], -1, bbox_color, thickness=3)
40 | if label_index is not None:
41 | label_text = params.classes[label_index]
42 | label_text += '|{:.02f}'.format(score)
43 | font = cv2.FONT_HERSHEY_COMPLEX
44 | text_size = cv2.getTextSize(label_text, font, fontScale=0.25, thickness=line_thickness)
45 | text_width = text_size[0][0]
46 | text_height = text_size[0][1]
47 | try:
48 | cv2.rectangle(image, (int(xc), int(yc) - text_height -2),
49 | (int(xc) + text_width, int(yc) + 3), (0, 128, 0), -1)
50 | cv2.putText(image, label_text, (int(xc), int(yc)), font, 0.25, text_color, thickness=1)
51 | except:
52 | print(f'{coord} is wrong!')
53 |
54 |
55 | def show_pred_box(args, params):
56 | # create folder
57 | if not os.path.exists(args.result_path):
58 | os.makedirs(args.result_path)
59 |
60 | model = RetinaNet(params)
61 | chkpt_path = os.path.join(params.output_path, 'checkpoints', args.chkpt)
62 | chkpt = torch.load(chkpt_path, map_location='cpu')
63 | print(f"The current model training {chkpt['epoch']} epoch(s)")
64 | print(f"The current model mAP: {chkpt['best_fitness']} based on test_conf={params.score_thr} & nms_thr={params.nms_thr}")
65 |
66 | model.load_state_dict(chkpt['model'])
67 | model.cuda(device=args.device)
68 | model.eval()
69 |
70 | image = cv2.cvtColor(cv2.imread(os.path.join(args.result_path, args.pic_name), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
71 |
72 | dets = im_detect(model,
73 | image,
74 | target_sizes=args.target_sizes,
75 | params=params,
76 | use_gpu=True,
77 | conf=args.score_thresh,
78 | device=args.device)
79 |
80 | # dets: list[class_index, 0
81 | # score, 1
82 |     #        pred_xc, pred_yc, pred_h, pred_w, pred_angle(radian),  2 - 6
83 | # anchor_xc, anchor_yc, anchor_w, anchor_h, anchor_angle(radian)] 7 - 11
84 | for det in dets:
85 | cls_index = int(det[0])
86 | score = float(det[1])
87 | pred_box = det[2:7]
88 | anchor = det[7:12]
89 |
90 | # plot predict box
91 | plot_box(image, coord=pred_box, label_index=cls_index, score=score, color=None,
92 | line_thickness=4)
93 |
94 | # plot which anchor to create predict box
95 | # plot_box(image, coord=anchor, color=[0, 0, 255])
96 |
97 | plt.imsave(os.path.join(args.result_path, f"{args.pic_name.split('.')[0]}_predict.png"), image)
98 | plt.imshow(image)
99 | plt.show()
100 |
101 |
102 | if __name__ == '__main__':
103 | from train import Params
104 |
105 | args = get_args()
106 | params = Params(args.config_file)
107 | if args.score_thresh != params.score_thr:
108 | print('[Info]: score_thresh is not equal to cfg.score_thr')
109 | params.backbone['pretrained'] = False
110 | show_pred_box(args, params)
111 |
--------------------------------------------------------------------------------
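`plot_box()` converts a `[xc, yc, h, w, angle]` box into its four corners before drawing. The small numeric check below uses the same construction on toy values: a box with h=40, w=20 rotated by pi/2 should come out axis-aligned, spanning 20 px in x and 40 px in y.

```python
import math
import numpy as np


def obb_corners(xc, yc, h, w, ag):
    """Same corner construction as plot_box(), returned as a (4, 2) array."""
    wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag)
    hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag)
    return np.array([(xc - wx - hx, yc - wy - hy),
                     (xc - wx + hx, yc - wy + hy),
                     (xc + wx + hx, yc + wy + hy),
                     (xc + wx - hx, yc + wy - hy)])


# corners span x in [90, 110] and y in [80, 120]
print(obb_corners(100, 100, 40, 20, math.pi / 2))
```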
/show_result/HRSC/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo1.jpg
--------------------------------------------------------------------------------
/show_result/HRSC/demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo2.jpg
--------------------------------------------------------------------------------
/show_result/HRSC/demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/HRSC/demo3.jpg
--------------------------------------------------------------------------------
/show_result/RSSDD/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo1.jpg
--------------------------------------------------------------------------------
/show_result/RSSDD/demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo2.jpg
--------------------------------------------------------------------------------
/show_result/RSSDD/demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/show_result/RSSDD/demo3.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | # from datasets.HRSC_dataset import HRSCDataset
4 | from datasets.SSDD_dataset import SSDDataset
5 | from datasets.collater import Collater
6 | import torch.utils.data as data
7 | from utils.utils import set_random_seed, count_param
8 | from models.model import RetinaNet
9 | import torch.optim as optim
10 | from tqdm import tqdm
11 | import os
12 | from tensorboardX import SummaryWriter
13 | import datetime
14 | import torch.nn as nn
15 | from warmup import WarmupLR
16 | import yaml
17 | from pprint import pprint
18 | from eval import evaluate
19 | from Logger import Logger
20 |
21 |
22 | class Params:
23 | def __init__(self, project_file):
24 | self.filename = os.path.basename(project_file)
25 | self.params = yaml.safe_load(open(project_file).read())
26 |
27 | def __getattr__(self, item):
28 | return self.params.get(item, None)
29 |
30 | def info(self):
31 | return '\n'.join([(f'{key}: {value}') for key, value in self.params.items()])
32 |
33 | def save(self):
34 | with open(os.path.join(self.params.get('output_path'), f'{self.filename}'), 'w') as f:
35 | yaml.dump(self.params, f, sort_keys=False)
36 |
37 | def show(self):
38 | print('=================== Show Params =====================')
39 | pprint(self.params)
40 |
41 |
42 | def get_args():
43 | parser = argparse.ArgumentParser('A Rotation Detector based on RetinaNet by PyTorch.')
44 | parser.add_argument('--config_file', type=str, default='./configs/retinanet_r50_fpn_{Dataset Name}.yml')
45 | parser.add_argument('--resume', type=str,
46 | # default='{epoch}_{step}.pth',
47 | default=None, # train from scratch
48 | help='the last checkpoint file.')
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def train(args, params):
54 | epochs = params.epoch
55 | if torch.cuda.is_available():
56 | if len(params.device) == 1:
57 | device = params.device[0]
58 | else:
59 |             print(f'[Info]: Training with {params.device} GPUs')
60 |
61 | weight = ''
62 | if args.resume:
63 | weight = params.output_path + os.sep + params.checkpoint + os.sep + args.resume
64 |
65 | start_epoch = 0
66 | best_fitness = 0
67 | fitness = 0
68 | last_step = 0
69 |
70 | # create folder
71 | tensorboard_path = os.path.join(params.output_path, params.tensorboard)
72 | if not os.path.exists(tensorboard_path):
73 | os.makedirs(tensorboard_path)
74 |
75 | checkpoint_path = os.path.join(params.output_path, params.checkpoint)
76 | if not os.path.exists(checkpoint_path):
77 | os.makedirs(checkpoint_path)
78 |
79 | best_checkpoint_path = os.path.join(checkpoint_path, 'best')
80 | if not os.path.exists(best_checkpoint_path):
81 | os.makedirs(best_checkpoint_path)
82 |
83 | log_file_path = os.path.join(params.output_path, params.log)
84 | if os.path.isfile(log_file_path):
85 | os.remove(log_file_path)
86 |
87 | log = Logger(log_path=os.path.join(params.output_path, params.log), logging_name='R-RetinaNet')
88 | logger = log.logger_config()
89 | env_info = params.info()
90 | logger.info('Config info:\n' + log.dash_line + env_info + '\n' + log.dash_line)
91 |
92 | # save config yaml file
93 | params.save()
94 |
95 | train_dataset = SSDDataset(root_path=params.data_path, set_name='train', augment=params.augment,
96 | classes=params.classes)
97 | collater = Collater(scales=params.image_size, keep_ratio=params.keep_ratio, multiple=32)
98 | train_generator = data.DataLoader(
99 | dataset=train_dataset,
100 | batch_size=params.batch_size,
101 | num_workers=8, # 4 * number of the GPU
102 | collate_fn=collater,
103 | shuffle=True,
104 | pin_memory=True,
105 | drop_last=True)
106 |
107 | # Initialize model & set random seed
108 | set_random_seed(seed=42, deterministic=False)
109 | model = RetinaNet(params)
110 | count_param(model)
111 |
112 | # init tensorboardX
113 | writer = SummaryWriter(tensorboard_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
114 |
115 | # Optimizer Option
116 | optimizer = optim.Adam(model.parameters(), lr=params.lr)
117 |
118 | # Scheduler Option
119 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
120 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[round(epochs * x) for x in [0.6, 0.8]], gamma=0.1)
121 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.94)
122 |
123 | # Warm-up
124 | is_warmup = False
125 | if params.warm_up and args.resume is None:
126 | print('[Info]: Launching Warmup.')
127 | scheduler = WarmupLR(scheduler, init_lr=params.warmup_lr, num_warmup=params.warmup_epoch, warmup_strategy='cos')
128 | is_warmup = True
129 | if is_warmup is False:
130 | print('[Info]: Not Launching Warmup.')
131 |
132 | if torch.cuda.is_available() and len(params.device) == 1:
133 | model = model.cuda(device=device)
134 | else:
135 | model = nn.DataParallel(model, device_ids=[0, 1], output_device=0)
136 |         model.cuda()  # place the model on the main GPU when using multiple GPUs
137 |
138 | if args.resume:
139 | if weight.endswith('.pth'):
140 | chkpt = torch.load(weight)
141 | last_step = chkpt['step']
142 |
143 | # Load model
144 | if 'model' in chkpt.keys():
145 | model.load_state_dict(chkpt['model'])
146 | else:
147 | model.load_state_dict(chkpt)
148 |
149 | # Load optimizer
150 | if 'optimizer' in chkpt.keys() and chkpt['optimizer'] is not None:
151 | optimizer.load_state_dict(chkpt['optimizer'])
152 | best_fitness = chkpt['best_fitness']
153 | for state in optimizer.state.values():
154 | for k, v in state.items():
155 | if isinstance(v, torch.Tensor):
156 | state[k] = v.cuda(device=device)
157 |
158 | # Load scheduler
159 | if 'scheduler' in chkpt.keys() and chkpt['scheduler'] is not None:
160 | scheduler_state = chkpt['scheduler']
161 | scheduler._step_count = scheduler_state['step_count']
162 | scheduler.last_epoch = scheduler_state['last_epoch']
163 |
164 | start_epoch = chkpt['epoch'] + 1
165 |
166 | del chkpt
167 |
168 | # start training
169 | step = max(0, last_step)
170 | num_iter_per_epoch = len(train_generator)
171 |
172 |     head_line = ('%10s' * 8) % ('Epoch', 'Steps', 'gpu_mem', 'cls', 'reg', 'total', 'targets', 'img_size')
173 |     print('\n' + head_line)
174 | logger.debug(head_line)
175 |
176 | if is_warmup:
177 | scheduler.step()
178 | for epoch in range(start_epoch, epochs):
179 | last_epoch = step // num_iter_per_epoch
180 | if epoch < last_epoch:
181 | continue
182 | pbar = tqdm(enumerate(train_generator), total=len(train_generator)) # progress bar
183 |
184 |         # at the start of each epoch, switch the model back to train() mode
185 |         # and freeze the backbone BN layer statistics
186 | model.train()
187 |
188 |         if params.freeze_bn:
189 |             # handle both plain and nn.DataParallel-wrapped models
190 |             bn_model = model if len(params.device) == 1 else model.module
191 |             bn_model.freeze_bn()
192 |
193 | for iter, (ni, batch) in enumerate(pbar):
194 |
195 | if iter < step - last_epoch * num_iter_per_epoch:
196 | pbar.update()
197 | continue
198 |
199 | optimizer.zero_grad()
200 | images, annots, image_names = batch['image'], batch['bboxes'], batch['image_name']
201 | if torch.cuda.is_available():
202 | if len(params.device) == 1:
203 | images, annots = images.cuda(device=device), annots.cuda(device=device)
204 | else:
205 | images, annots = images.cuda(), annots.cuda()
206 | loss_cls, loss_reg = model(images, annots, image_names)
207 |
208 |                 # Following the Ming71 and Zylo117 repos, reduce the losses with .mean()
209 | loss_cls = loss_cls.mean()
210 | loss_reg = loss_reg.mean()
211 |
212 | total_loss = loss_cls + loss_reg
213 |
214 | if not torch.isfinite(total_loss):
215 | print('[Warning]: loss is nan')
216 | break
217 |
218 | if bool(total_loss == 0):
219 | continue
220 |
221 | total_loss.backward()
222 |
223 | # Update parameters
224 |
225 | # if loss is not nan not using grad clip
226 | # nn.utils.clip_grad_norm_(model.parameters(), 0.1)
227 |
228 | optimizer.step()
229 |
230 | # print batch result
231 | if len(params.device) == 1:
232 | mem = torch.cuda.memory_reserved(device=device) / 1E9 if torch.cuda.is_available() else 0
233 | else:
234 | mem = sum(torch.cuda.memory_reserved(device=idx) for idx in range(len(params.device))) / 1E9
235 |
236 | s = ('%10s' * 3 + '%10.3g' * 4 + '%10s' * 1) % (
237 | '%g/%g' % (epoch, epochs - 1),
238 | '%g' % iter,
239 | '%.3gG' % mem, loss_cls.item(), loss_reg.item(), total_loss.item(), annots.shape[1],
240 | '%gx%g' % (int(images.shape[2]), int(images.shape[3])))
241 |
242 | pbar.set_description(s)
243 |
244 | # write loss info into tensorboard
245 | writer.add_scalars('Loss', {'train': total_loss}, step)
246 | writer.add_scalars('Regression_loss', {'train': loss_reg}, step)
247 |                 writer.add_scalars('Classification_loss', {'train': loss_cls}, step)
248 |
249 | # write lr info into tensorboard
250 | current_lr = optimizer.param_groups[0]['lr']
251 | writer.add_scalar('lr_per_step', current_lr, step)
252 | step = step + 1
253 |
254 | # Update scheduler / learning rate
255 | scheduler.step()
256 | logger.debug(s)
257 |
258 | final_epoch = epoch + 1 == epochs
259 |
260 | # # check the mAP on training set begin ------------------------------------------------
261 | # if epoch >= params.evaluate_train_start and epoch % params.val_interval == 0:
262 | # test_path = 'train-ground-truth'
263 | # train_results = evaluate(
264 | # target_size=[params.image_size],
265 | # test_path=test_path,
266 | # eval_method=args.eval_method,
267 | # model=model,
268 | # conf=params.score_thr,
269 | # device=args.device,
270 | # mode='train')
271 | #
272 | # train_fitness = train_results[0] # Update best mAP
273 | # writer.add_scalar('train_mAP', train_fitness, epoch)
274 | # --------------------------end
275 |
276 | # save model
277 | # create checkpoint
278 | chkpt = {'epoch': epoch,
279 | 'step': step,
280 | 'best_fitness': best_fitness,
281 |                  'model': model.module.state_dict() if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
282 |                  else model.state_dict(),
283 | 'optimizer': None if final_epoch else optimizer.state_dict(),
284 | 'scheduler': {'step_count': scheduler._step_count,
285 | 'last_epoch': scheduler.last_epoch}
286 | }
287 |
288 | # save interval checkpoint
289 | if epoch % params.save_interval == 0 and epoch >= 30:
290 | torch.save(chkpt, os.path.join(checkpoint_path, f'{epoch}_{step}.pth'))
291 |
292 | if epoch >= params.evaluation_val_start and epoch % params.val_interval == 0:
293 | test_path = 'ground-truth'
294 | model.eval()
295 | val_mAP, val_Precision, val_Recall = evaluate(model=model,
296 | target_size=params.image_size,
297 | test_path=test_path,
298 | conf=params.score_thr,
299 | device=device,
300 | mode='test',
301 | params=params)
302 |
303 | eval_line = ('%10s' * 7) % ('[%g/%g]' % (epoch, epochs - 1), 'Val mAP:', '%10.3f' % val_mAP,
304 | 'Precision:', '%10.3f' % val_Precision,
305 | 'Recall:', '%10.3f' % val_Recall)
306 | logger.debug(eval_line)
307 |
308 | fitness = val_mAP # Update best mAP
309 |
310 | if fitness > best_fitness:
311 | best_fitness = fitness
312 |
313 | # write mAP info into tensorboard
314 | writer.add_scalar('val_mAP', fitness, epoch)
315 |
316 | # save best checkpoint
317 | if best_fitness == fitness:
318 | torch.save(chkpt, os.path.join(best_checkpoint_path, 'best.pth'))
319 |
320 | # TensorboardX writer close
321 | writer.close()
322 |
323 |
324 | if __name__ == '__main__':
325 | # os.environ["CUDA_VISIBLE_DEVICES"] = '3, 2' # for multi-GPU
326 | from utils.utils import show_args
327 | args = get_args()
328 | params = Params(args.config_file)
329 | show_args(args)
330 | train(args, params)
331 |
--------------------------------------------------------------------------------
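The `Params` helper above is a thin attribute wrapper around a YAML config, with missing keys resolving to `None`. A minimal usage sketch follows; note that importing `train` pulls in the full model stack, so the Cython extensions built by `setup.py` must be compiled first, and the config content here is a toy stand-in rather than a real project config.

```python
import os
import tempfile
import yaml
from train import Params

cfg = {'lr': 1e-4, 'epoch': 50, 'device': [0], 'classes': {0: 'ship'}}  # toy config
with tempfile.NamedTemporaryFile('w', suffix='.yml', delete=False) as f:
    yaml.dump(cfg, f)
    cfg_path = f.name

params = Params(cfg_path)
print(params.lr, params.epoch, params.device)   # 0.0001 50 [0]
print(params.score_thr)                         # None -- absent keys fall back to None
os.remove(cfg_path)
```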
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/augment.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/bbox_transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/bbox_transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/box_coder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/box_coder.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/map.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/map.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | class HorizontalFlip(object):
7 | """
8 | Args:
9 | p: the probability of the horizontal flip
10 | """
11 | def __init__(self, p=0.5):
12 | self.p = p
13 |
14 | def __call__(self, image, bboxes):
15 | """
16 | Args:
17 | image: array([C, H, W])
18 | bboxes: array (N, 8) :[[x1, y1, x2, y2, x3, y3, x4, y4] ... ]
19 | """
20 | if random.random() < self.p:
21 | h, w, _ = image.shape
22 | image = np.array(np.fliplr(image))
23 | for idx, single_box in enumerate(bboxes):
24 | bboxes[idx, 0::2] = w - single_box[0::2]
25 | return image, bboxes
26 |
27 |
28 | class VerticalFlip(object):
29 | """
30 | Args:
31 | p: the probability of the vertical flip
32 | """
33 | def __init__(self, p=0.5):
34 | self.p = p
35 |
36 | def __call__(self, image, bboxes):
37 | """
38 | Args:
39 | image: array([C, H, W])
40 | bboxes: list (N, 9) :[[x1, y1, x2, y2, x3, y3, x4, y4, class_index] ... ]
41 | """
42 | if random.random() < self.p:
43 | h, w, _ = image.shape
44 | image = np.array(np.flipud(image))
45 | for idx, single_box in enumerate(bboxes):
46 | bboxes[idx, 1::2] = h - single_box[1::2]
47 | return image, bboxes
48 |
49 |
50 | class HSV(object):
51 | def __init__(self, saturation=0, brightness=0, p=0.):
52 | self.saturation = saturation
53 | self.brightness = brightness
54 | self.p = p
55 |
56 | def __call__(self, image, bboxes, mode=None):
57 | if random.random() < self.p:
58 | img_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) # hue, sat, val
59 | S = img_hsv[:, :, 1].astype(np.float32) # saturation
60 | V = img_hsv[:, :, 2].astype(np.float32) # value
61 | a = random.uniform(-1, 1) * self.saturation + 1
62 | b = random.uniform(-1, 1) * self.brightness + 1
63 | S *= a
64 | V *= b
65 | img_hsv[:, :, 1] = S if a < 1 else S.clip(None, 255)
66 | img_hsv[:, :, 2] = V if b < 1 else V.clip(None, 255)
67 | cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=image)
68 | return image, bboxes
69 |
70 |
71 | class Augment(object):
72 | def __init__(self, transforms):
73 | self.transforms = transforms
74 |
75 | def __call__(self, image, bboxes):
76 | for transform in self.transforms:
77 | image, bboxes = transform(image, bboxes)
78 | return image, bboxes
79 |
--------------------------------------------------------------------------------
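The transforms above operate on an image plus an (N, 8) array of polygon corners and are composed through `Augment`. A small usage sketch with the flip probabilities forced to 1.0 so the effect is deterministic (the image and box values are toy data):

```python
import numpy as np
from utils.augment import Augment, HorizontalFlip, VerticalFlip

image = np.zeros((100, 200, 3), dtype=np.uint8)                 # H=100, W=200
bboxes = np.array([[10., 20., 60., 20., 60., 50., 10., 50.]])   # one 4-corner polygon

aug = Augment([HorizontalFlip(p=1.0), VerticalFlip(p=1.0)])
image, bboxes = aug(image, bboxes)
print(bboxes)   # every x becomes W - x and every y becomes H - y
```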
/utils/bbox_transforms.py:
--------------------------------------------------------------------------------
1 | """ Original code is from:
2 | `https://github.com/open-mmlab/mmrotate/blob/main/mmrotate/core/bbox/transforms.py`"""
3 |
4 | import numpy as np
5 | import math
6 | import torch
7 | import cv2
8 |
9 |
10 | def swap_axis(tensor):
11 | if torch.is_tensor(tensor):
12 | swap_bbox = torch.zeros_like(tensor)
13 | swap_bbox[:, 0] = tensor[:, 0]
14 | swap_bbox[:, 1] = tensor[:, 1]
15 | swap_bbox[:, 2] = tensor[:, 3]
16 | swap_bbox[:, 3] = tensor[:, 2]
17 | swap_bbox[:, 4] = tensor[:, 4]
18 | else:
19 | swap_bbox = np.zeros_like(tensor)
20 | swap_bbox[:, 0] = tensor[:, 0]
21 | swap_bbox[:, 1] = tensor[:, 1]
22 | swap_bbox[:, 2] = tensor[:, 3]
23 | swap_bbox[:, 3] = tensor[:, 2]
24 | swap_bbox[:, 4] = tensor[:, 4]
25 |
26 | return swap_bbox
27 |
28 |
29 | def norm_angle(angle, angle_range):
30 | """Limit the range of angles.
31 |
32 | Args:
33 | angle (ndarray): shape(n, ).
34 | angle_range (Str): angle representations.
35 | Returns:
36 | angle (ndarray): shape(n, ).
37 | """
38 | if angle_range == 'oc':
39 | return angle
40 | elif angle_range == 'le135':
41 | return (angle + np.pi / 4) % np.pi - np.pi / 4
42 | elif angle_range == 'le90':
43 | return (angle + np.pi / 2) % np.pi - np.pi / 2
44 | else:
45 |         raise NotImplementedError
46 |
47 |
48 | def obb2poly_oc(rboxes):
49 | """Convert oriented bounding boxes to polygons.
50 | Args:
51 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
52 | Returns:
53 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
54 | """
55 | x = rboxes[:, 0]
56 | y = rboxes[:, 1]
57 | w = rboxes[:, 2]
58 | h = rboxes[:, 3]
59 | a = rboxes[:, 4]
60 | cosa = torch.cos(a)
61 | sina = torch.sin(a)
62 | wx, wy = w / 2 * cosa, w / 2 * sina
63 | hx, hy = -h / 2 * sina, h / 2 * cosa
64 | p1x, p1y = x - wx - hx, y - wy - hy
65 | p2x, p2y = x + wx - hx, y + wy - hy
66 | p3x, p3y = x + wx + hx, y + wy + hy
67 | p4x, p4y = x - wx + hx, y - wy + hy
68 | return torch.stack([p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y], dim=-1)
69 |
70 |
71 | def obb2poly_le135(rboxes):
72 | """Convert oriented bounding boxes to polygons.
73 | Args:
74 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
75 | Returns:
76 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
77 | """
78 | N = rboxes.shape[0]
79 | if N == 0:
80 | return rboxes.new_zeros((rboxes.size(0), 8))
81 | x_ctr, y_ctr, width, height, angle = rboxes.select(1, 0), rboxes.select(
82 | 1, 1), rboxes.select(1, 2), rboxes.select(1, 3), rboxes.select(1, 4)
83 | tl_x, tl_y, br_x, br_y = \
84 | -width * 0.5, -height * 0.5, \
85 | width * 0.5, height * 0.5
86 | rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y],
87 | dim=0).reshape(2, 4, N).permute(2, 0, 1)
88 | sin, cos = torch.sin(angle), torch.cos(angle)
89 | M = torch.stack([cos, -sin, sin, cos], dim=0).reshape(2, 2,
90 | N).permute(2, 0, 1)
91 | polys = M.matmul(rects).permute(2, 1, 0).reshape(-1, N).transpose(1, 0)
92 | polys[:, ::2] += x_ctr.unsqueeze(1)
93 | polys[:, 1::2] += y_ctr.unsqueeze(1)
94 | return polys.contiguous()
95 |
96 |
97 | def obb2poly_le90(rboxes):
98 | """Convert oriented bounding boxes to polygons.
99 | Args:
100 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
101 | Returns:
102 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
103 | """
104 | N = rboxes.shape[0]
105 | if N == 0:
106 | return rboxes.new_zeros((rboxes.size(0), 8))
107 | x_ctr, y_ctr, width, height, angle = rboxes.select(1, 0), rboxes.select(
108 | 1, 1), rboxes.select(1, 2), rboxes.select(1, 3), rboxes.select(1, 4)
109 | tl_x, tl_y, br_x, br_y = \
110 | -width * 0.5, -height * 0.5, \
111 | width * 0.5, height * 0.5
112 | rects = torch.stack([tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y],
113 | dim=0).reshape(2, 4, N).permute(2, 0, 1)
114 | sin, cos = torch.sin(angle), torch.cos(angle)
115 | M = torch.stack([cos, -sin, sin, cos], dim=0).reshape(2, 2,
116 | N).permute(2, 0, 1)
117 | polys = M.matmul(rects).permute(2, 1, 0).reshape(-1, N).transpose(1, 0)
118 | polys[:, ::2] += x_ctr.unsqueeze(1)
119 | polys[:, 1::2] += y_ctr.unsqueeze(1)
120 | return polys.contiguous()
121 |
122 |
123 | def cal_line_length(point1, point2):
124 | """Calculate the length of line.
125 | Args:
126 | point1 (List): [x,y]
127 | point2 (List): [x,y]
128 | Returns:
129 | length (float)
130 | """
131 | return math.sqrt(
132 | math.pow(point1[0] - point2[0], 2) +
133 | math.pow(point1[1] - point2[1], 2))
134 |
135 |
136 | def get_best_begin_point_single(coordinate):
137 | """Get the best begin point of the single polygon.
138 | Args:
139 | coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score]
140 | Returns:
141 | reorder coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score]
142 | """
143 | x1, y1, x2, y2, x3, y3, x4, y4, score = coordinate
144 | xmin = min(x1, x2, x3, x4)
145 | ymin = min(y1, y2, y3, y4)
146 | xmax = max(x1, x2, x3, x4)
147 | ymax = max(y1, y2, y3, y4)
148 | combine = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
149 | [[x2, y2], [x3, y3], [x4, y4], [x1, y1]],
150 | [[x3, y3], [x4, y4], [x1, y1], [x2, y2]],
151 | [[x4, y4], [x1, y1], [x2, y2], [x3, y3]]]
152 | dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
153 | force = 100000000.0
154 | force_flag = 0
155 | for i in range(4):
156 | temp_force = cal_line_length(combine[i][0], dst_coordinate[0]) \
157 | + cal_line_length(combine[i][1], dst_coordinate[1]) \
158 | + cal_line_length(combine[i][2], dst_coordinate[2]) \
159 | + cal_line_length(combine[i][3], dst_coordinate[3])
160 | if temp_force < force:
161 | force = temp_force
162 | force_flag = i
163 | if force_flag != 0:
164 | pass
165 | return np.hstack(
166 | (np.array(combine[force_flag]).reshape(8), np.array(score)))
167 |
168 |
169 | def get_best_begin_point(coordinates):
170 | """Get the best begin points of polygons.
171 | Args:
172 | coordinate (ndarray): shape(n, 9).
173 | Returns:
174 | reorder coordinate (ndarray): shape(n, 9).
175 | """
176 | coordinates = list(map(get_best_begin_point_single, coordinates.tolist()))
177 | coordinates = np.array(coordinates)
178 | return coordinates
179 |
180 |
181 | def obb2poly_np_oc(rbboxes):
182 | """ Modified !
183 | Convert oriented bounding boxes to polygons.
184 | Args:
185 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score] modify-> [x_ctr, y_ctr, h, w, angle(radian), score]
186 | Returns:
187 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score]
188 | """
189 | x = rbboxes[:, 0]
190 | y = rbboxes[:, 1]
191 | h = rbboxes[:, 2]
192 | w = rbboxes[:, 3]
193 | a = rbboxes[:, 4]
194 | score = rbboxes[:, 5]
195 |
196 | cosa = np.cos(a)
197 | sina = np.sin(a)
198 | wx, wy = -w / 2 * sina, w / 2 * cosa
199 | hx, hy = h / 2 * cosa, h / 2 * sina
200 | p1x, p1y = x - wx - hx, y - wy - hy
201 | p2x, p2y = x - wx + hx, y - wy + hy
202 | p3x, p3y = x + wx + hx, y + wy + hy
203 | p4x, p4y = x + wx - hx, y + wy - hy
204 | polys = np.stack([p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y, score], axis=-1)
205 | polys = get_best_begin_point(polys)
206 | return polys
207 |
208 |
209 | def obb2poly_np_le135(rrects):
210 | """Convert oriented bounding boxes to polygons.
211 | Args:
212 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score]
213 | Returns:
214 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score]
215 | """
216 | polys = []
217 | for rrect in rrects:
218 | x_ctr, y_ctr, width, height, angle, score = rrect[:6]
219 | tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
220 | rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
221 | R = np.array([[np.cos(angle), -np.sin(angle)],
222 | [np.sin(angle), np.cos(angle)]])
223 | poly = R.dot(rect)
224 | x0, x1, x2, x3 = poly[0, :4] + x_ctr
225 | y0, y1, y2, y3 = poly[1, :4] + y_ctr
226 | poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3, score],
227 | dtype=np.float32)
228 | polys.append(poly)
229 | polys = np.array(polys)
230 | polys = get_best_begin_point(polys)
231 | return polys
232 |
233 |
234 | def obb2poly_np_le90(obboxes):
235 | """Convert oriented bounding boxes to polygons.
236 | Args:
237 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle,score]
238 | Returns:
239 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3,score]
240 | """
241 | try:
242 | center, w, h, theta, score = np.split(obboxes, (2, 3, 4, 5), axis=-1)
243 | except: # noqa: E722
244 | results = np.stack([0., 0., 0., 0., 0., 0., 0., 0., 0.], axis=-1)
245 | return results.reshape(1, -1)
246 | Cos, Sin = np.cos(theta), np.sin(theta)
247 | vector1 = np.concatenate([w / 2 * Cos, w / 2 * Sin], axis=-1)
248 | vector2 = np.concatenate([-h / 2 * Sin, h / 2 * Cos], axis=-1)
249 | point1 = center - vector1 - vector2
250 | point2 = center + vector1 - vector2
251 | point3 = center + vector1 + vector2
252 | point4 = center - vector1 + vector2
253 | polys = np.concatenate([point1, point2, point3, point4, score], axis=-1)
254 | polys = get_best_begin_point(polys)
255 | return polys
256 |
257 |
258 | def obb2poly(rbboxes, version='oc'):
259 | """Convert oriented bounding boxes to polygons.
260 | Args:
261 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
262 | version (Str): angle representations.
263 | Returns:
264 | polys (torch.Tensor): [x0,y0,x1,y1,x2,y2,x3,y3]
265 | """
266 | if version == 'oc':
267 | results = obb2poly_oc(rbboxes)
268 | elif version == 'le135':
269 | results = obb2poly_le135(rbboxes)
270 | elif version == 'le90':
271 | results = obb2poly_le90(rbboxes)
272 | else:
273 | raise NotImplementedError
274 | return results
275 |
276 |
277 | def obb2poly_np(rbboxes, version='oc'):
278 | """Convert oriented bounding boxes to polygons.
279 | Args:
280 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
281 | version (Str): angle representations.
282 | Returns:
283 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
284 | """
285 | if version == 'oc':
286 | results = obb2poly_np_oc(rbboxes)
287 | elif version == 'le135':
288 | results = obb2poly_np_le135(rbboxes)
289 | elif version == 'le90':
290 | results = obb2poly_np_le90(rbboxes)
291 | else:
292 | raise NotImplementedError
293 | return results
294 |
295 |
296 | def poly2obb_np(polys, version='oc'):
297 | """Convert polygons to oriented bounding boxes.
298 | Args:
299 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
300 | version (Str): angle representations.
301 | Returns:
302 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
303 | """
304 | if version == 'oc':
305 | results = poly2obb_np_oc(polys)
306 | elif version == 'le135':
307 | results = poly2obb_np_le135(polys)
308 | elif version == 'le90':
309 | results = poly2obb_np_le90(polys)
310 | else:
311 | raise NotImplementedError
312 | return results
313 |
314 |
315 | def poly2obb_np_oc(poly):
316 | """ Modified !!
317 | Convert polygons to oriented bounding boxes.
318 | Args:
319 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
320 | Returns:
321 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle] modified -> [x_ctr, y_ctr, h, w, angle(radian)]
322 | """
323 | bboxps = np.array(poly).reshape((4, 2))
324 | rbbox = cv2.minAreaRect(bboxps)
325 | x, y, h, w, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2]
326 | # assert 0 < a <= 90, f'error from poly2obb_np_oc function.'
327 | if w < 2 or h < 2:
328 | return
329 | while not 0 < a <= 90:
330 | if a == -90:
331 | a += 180
332 | else:
333 | a += 90
334 | w, h = h, w
335 | a = a / 180 * np.pi
336 | assert 0 < a <= np.pi / 2
337 | return x, y, h, w, a
338 |
339 |
340 | def poly2obb_np_le135(poly):
341 | """Convert polygons to oriented bounding boxes.
342 | Args:
343 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
344 | Returns:
345 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
346 | """
347 | poly = np.array(poly[:8], dtype=np.float32)
348 | pt1 = (poly[0], poly[1])
349 | pt2 = (poly[2], poly[3])
350 | pt3 = (poly[4], poly[5])
351 | pt4 = (poly[6], poly[7])
352 | edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[1]) *
353 | (pt1[1] - pt2[1]))
354 | edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[1]) *
355 | (pt2[1] - pt3[1]))
356 | if edge1 < 2 or edge2 < 2:
357 | return
358 | width = max(edge1, edge2)
359 | height = min(edge1, edge2)
360 | angle = 0
361 | if edge1 > edge2:
362 | angle = np.arctan2(float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0]))
363 | elif edge2 >= edge1:
364 | angle = np.arctan2(float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0]))
365 | angle = norm_angle(angle, 'le135')
366 | x_ctr = float(pt1[0] + pt3[0]) / 2
367 | y_ctr = float(pt1[1] + pt3[1]) / 2
368 | return x_ctr, y_ctr, width, height, angle
369 |
370 |
371 | def poly2obb_np_le90(poly):
372 | """Convert polygons to oriented bounding boxes.
373 | Args:
374 | polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
375 | Returns:
376 | obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
377 | """
378 | bboxps = np.array(poly).reshape((4, 2))
379 | rbbox = cv2.minAreaRect(bboxps)
380 | x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[
381 | 2]
382 | if w < 2 or h < 2:
383 | return
384 | a = a / 180 * np.pi
385 | if w < h:
386 | w, h = h, w
387 | a += np.pi / 2
388 | while not np.pi / 2 > a >= -np.pi / 2:
389 | if a >= np.pi / 2:
390 | a -= np.pi
391 | else:
392 | a += np.pi
393 | assert np.pi / 2 > a >= -np.pi / 2
394 | return x, y, w, h, a
395 |
396 |
397 | def obb2hbb_oc(rbboxes):
398 | """ Modified !
399 | Convert oriented bounding boxes to horizontal bounding boxes.
400 |
401 | Args:
402 | obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle] modify -> [x_ctr, y_ctr, h, w, angle(radian)]
403 | Returns:
404 | hbbs (torch.Tensor): [x_ctr,y_ctr,w,h,pi/2]
405 | """
406 | h = rbboxes[:, 2::5]
407 | w = rbboxes[:, 3::5]
408 | a = rbboxes[:, 4::5]
409 | cosa = torch.cos(a)
410 | sina = torch.sin(a)
411 | hbbox_h = cosa * w + sina * h
412 | hbbox_w = sina * w + cosa * h
413 | hbboxes = rbboxes.clone().detach()
414 | hbboxes[:, 2::5] = hbbox_w
415 | hbboxes[:, 3::5] = hbbox_h
416 | hbboxes[:, 4::5] = np.pi / 2
417 | return hbboxes
418 |
--------------------------------------------------------------------------------
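A quick way to sanity-check the modified 'oc' convention above is a small round trip through `poly2obb_np`. The snippet below is a minimal sketch, assuming the repository root is on `PYTHONPATH` so the module imports as `utils.bbox_transforms`; the exact width/height/angle split in the output depends on the `cv2.minAreaRect` convention of the installed OpenCV version.

```python
# Minimal sketch: convert a 4-point polygon to the modified 'oc' obb format.
# Assumes the repo root is on PYTHONPATH; exact values depend on the cv2 version.
import numpy as np
from utils.bbox_transforms import poly2obb_np

# An axis-aligned 40 x 20 rectangle given as four corner points.
poly = np.array([10., 10., 50., 10., 50., 30., 10., 30.], dtype=np.float32)

obb = poly2obb_np(poly, version='oc')
# Returns (x_ctr, y_ctr, h, w, angle) with the angle in radians and 0 < angle <= pi/2,
# per the modified poly2obb_np_oc above; note that h comes before w.
print(obb)
```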
/utils/box_coder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class BoxCoder(object):
6 | """
7 | This class encodes and decodes a set of bounding boxes into
8 | the representation used for training the regressors.
9 |
10 | Args:
11 | encode() function:
12 | ex_rois: positive anchors: [xc, yc, w, h, angle(radian)]
13 | gt_rois: positive anchor ground-truth box: [xc, yc, h, w, angle(radian)]
14 |
15 | decode() function:
16 | boxes: anchors: [xc, yc, w, h, angle(radian)]
17 | deltas: offset: [xc_offset, yc_offset, h_offset, w_offset, angle_offset(radian)]
18 | """
19 | def __init__(self, means=(0., 0., 0., 0., 0.), stds=(0.1, 0.1, 0.1, 0.1, 0.05)):
20 | self.means = means
21 | self.stds = stds
22 |
23 | def encode(self, ex_rois, gt_rois):
24 | ex_widths = ex_rois[:, 2]
25 | ex_heights = ex_rois[:, 3]
26 | ex_widths = torch.clamp(ex_widths, min=1)
27 | ex_heights = torch.clamp(ex_heights, min=1)
28 | ex_ctr_x = ex_rois[:, 0]
29 | ex_ctr_y = ex_rois[:, 1]
30 | ex_thetas = ex_rois[:, 4]
31 |
32 | gt_widths = gt_rois[:, 3]
33 | gt_heights = gt_rois[:, 2]
34 | gt_widths = torch.clamp(gt_widths, min=1)
35 | gt_heights = torch.clamp(gt_heights, min=1)
36 | gt_ctr_x = gt_rois[:, 0]
37 | gt_ctr_y = gt_rois[:, 1]
38 | gt_thetas = gt_rois[:, 4]
39 |
40 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths # t_x = (x - x_a) / w_a
41 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights # t_y = (y - y_a) / h_a
42 | targets_dw = torch.log(gt_widths / ex_widths) # t_w = log(w / w_a)
43 | targets_dh = torch.log(gt_heights / ex_heights) # t_h = log(h / h_a)
44 | targets_dt = gt_thetas - ex_thetas
45 |
46 | targets = torch.stack(
47 | (targets_dx, targets_dy, targets_dh, targets_dw, targets_dt), dim=1)
48 |
49 | means = targets.new_tensor(self.means).unsqueeze(0)
50 | stds = targets.new_tensor(self.stds).unsqueeze(0)
51 | targets = targets.sub_(means).div_(stds)
52 | return targets
53 |
54 | def decode(self, boxes, deltas):
55 | means = deltas.new_tensor(self.means).view(1, 1, -1).repeat(1, deltas.size(1), 1)
56 | stds = deltas.new_tensor(self.stds).view(1, 1, -1).repeat(1, deltas.size(1), 1)
57 | denorm_deltas = deltas * stds + means
58 |
59 | dx = denorm_deltas[:, :, 0]
60 | dy = denorm_deltas[:, :, 1]
61 | dh = denorm_deltas[:, :, 2]
62 | dw = denorm_deltas[:, :, 3]
63 | dt = denorm_deltas[:, :, 4]
64 |
65 | widths = boxes[:, :, 2]
66 | heights = boxes[:, :, 3]
67 | widths = torch.clamp(widths, min=1)
68 | heights = torch.clamp(heights, min=1)
69 | ctr_x = boxes[:, :, 0]
70 | ctr_y = boxes[:, :, 1]
71 | thetas = boxes[:, :, 4]
72 |
73 | pred_ctr_x = ctr_x + dx * widths
74 | pred_ctr_y = ctr_y + dy * heights
75 | pred_w = torch.exp(dw) * widths
76 | pred_h = torch.exp(dh) * heights
77 | pred_t = thetas + dt
78 |
79 | pred_boxes = torch.stack([
80 | pred_ctr_x,
81 | pred_ctr_y,
82 | pred_h,
83 | pred_w,
84 | pred_t], dim=2)
85 | return pred_boxes
86 |
--------------------------------------------------------------------------------
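The encode/decode pair above is easiest to read as a round trip: `encode` turns an (anchor, ground truth) pair into normalized offsets, and `decode` applies those offsets back to the anchor. A minimal sketch with hypothetical tensors (repo root assumed on `PYTHONPATH`):

```python
# Minimal sketch of the BoxCoder round trip; the tensors below are hypothetical.
import torch
from utils.box_coder import BoxCoder

coder = BoxCoder()
anchors = torch.tensor([[50., 50., 40., 20., 0.0]])  # [xc, yc, w, h, angle(radian)]
gts = torch.tensor([[55., 52., 22., 44., 0.1]])      # [xc, yc, h, w, angle(radian)]

deltas = coder.encode(anchors, gts)                  # (1, 5) normalized regression targets
decoded = coder.decode(anchors.unsqueeze(0), deltas.unsqueeze(0))
print(decoded)  # ~[[[55., 52., 22., 44., 0.1]]], i.e. [xc, yc, h, w, angle] again
```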
/utils/map.py:
--------------------------------------------------------------------------------
1 | # from shapely.geometry import Polygon
2 | import glob
3 | import json
4 | import os
5 | import shutil
6 | import operator
7 | import sys
8 | import argparse
9 | import math
10 | # import shapely
11 | import cv2
12 | from shapely.geometry import Polygon, MultiPoint
13 |
14 | import numpy as np
15 | # from tqdm import tqdm
16 |
17 |
18 | def skewiou(box1, box2):
19 | box1=np.asarray(box1).reshape(4,2)
20 | box2=np.asarray(box2).reshape(4,2)
21 | # ---------------- original code ----------------------------
22 | poly1 = Polygon(box1).convex_hull
23 | poly2 = Polygon(box2).convex_hull
24 |     if not poly1.is_valid or not poly2.is_valid:
25 |         print('formatting errors for boxes!')
26 |         return 0, 0
27 | if poly1.area == 0 or poly2.area == 0 :
28 | return 0, 0
29 | inter = Polygon(poly1).intersection(Polygon(poly2)).area
30 | union = poly1.area + poly2.area - inter
31 | if union == 0:
32 | return 0, 0
33 | else:
34 | return inter/union, inter
35 |
36 | # ------------------ cv2 implementation -----------------------
37 |
38 |
39 |
40 | def log_average_miss_rate(precision, fp_cumsum, num_images):
41 | """
42 | log-average miss rate:
43 | Calculated by averaging miss rates at 9 evenly spaced FPPI points
44 | between 10e-2 and 10e0, in log-space.
45 |
46 | output:
47 | lamr | log-average miss rate
48 | mr | miss rate
49 | fppi | false positives per image
50 |
51 | references:
52 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
53 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE
54 | Transactions on 34.4 (2012): 743 - 761.
55 | """
56 |
57 | # if there were no detections of that class
58 | if precision.size == 0:
59 | lamr = 0
60 | mr = 1
61 | fppi = 0
62 | return lamr, mr, fppi
63 |
64 | fppi = fp_cumsum / float(num_images)
65 | mr = (1 - precision)
66 |
67 | fppi_tmp = np.insert(fppi, 0, -1.0)
68 | mr_tmp = np.insert(mr, 0, 1.0)
69 |
70 | # Use 9 evenly spaced reference points in log-space
71 | ref = np.logspace(-2.0, 0.0, num = 9)
72 | for i, ref_i in enumerate(ref):
73 | # np.where() will always find at least 1 index, since min(ref) = 0.01 and min(fppi_tmp) = -1.0
74 | j = np.where(fppi_tmp <= ref_i)[-1][-1]
75 | ref[i] = mr_tmp[j]
76 |
77 | # log(0) is undefined, so we use the np.maximum(1e-10, ref)
78 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
79 |
80 | return lamr, mr, fppi
81 |
82 | """
83 | throw error and exit
84 | """
85 | def error(msg):
86 | print(msg)
87 | sys.exit(0)
88 |
89 | """
90 | check if the number is a float between 0.0 and 1.0
91 | """
92 | def is_float_between_0_and_1(value):
93 | try:
94 | val = float(value)
95 | if val > 0.0 and val < 1.0:
96 | return True
97 | else:
98 | return False
99 | except ValueError:
100 | return False
101 |
102 | """
103 | Calculate the AP given the recall and precision array
104 | 1st) We compute a version of the measured precision/recall curve with
105 | precision monotonically decreasing
106 | 2nd) We compute the AP as the area under this curve by numerical integration.
107 | """
108 | def voc_ap(rec, prec, use_07_metric=False):
109 | """ ap = voc_ap(rec, prec, [use_07_metric])
110 | Compute VOC AP given precision and recall.
111 | If use_07_metric is true, uses the
112 | VOC 07 11 point method (default:False).
113 | """
114 | if use_07_metric:
115 | mrec = np.concatenate(([0.], rec, [1.]))
116 | mpre = np.concatenate(([0.], prec, [0.]))
117 | # 11 point metric
118 | ap = 0.
119 | for t in np.arange(0., 1.1, 0.1):
120 | if np.sum(rec >= t) == 0:
121 | p = 0
122 | else:
123 | p = np.max(np.array(prec)[rec >= t])
124 | ap = ap + p / 11.
125 | else:
126 | # correct AP calculation
127 | # first append sentinel values at the end
128 | mrec = np.concatenate(([0.], rec, [1.]))
129 | mpre = np.concatenate(([0.], prec, [0.]))
130 |
131 | # compute the precision envelope
132 | for i in range(mpre.size - 1, 0, -1):
133 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
134 |
135 | # to calculate area under PR curve, look for points
136 | # where X axis (recall) changes value
137 | i = np.where(mrec[1:] != mrec[:-1])[0]
138 |
139 | # and sum (\Delta recall) * prec
140 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
141 | return ap, mrec, mpre
142 |
143 |
144 | """
145 | Convert the lines of a file to a list
146 | """
147 | def file_lines_to_list(path):
148 | # open txt file lines to a list
149 | with open(path) as f:
150 | content = f.readlines()
151 | # remove whitespace characters like `\n` at the end of each line
152 | content = [x.strip() for x in content]
153 | return content
154 |
155 | """
156 | Draws text in image
157 | """
158 | def draw_text_in_image(img, text, pos, color, line_width):
159 | font = cv2.FONT_HERSHEY_PLAIN
160 | fontScale = 1
161 | lineType = 1
162 | bottomLeftCornerOfText = pos
163 | cv2.putText(img, text,
164 | bottomLeftCornerOfText,
165 | font,
166 | fontScale,
167 | color,
168 | lineType)
169 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
170 | return img, (line_width + text_width)
171 |
172 | """
173 | Plot - adjust axes
174 | """
175 | def adjust_axes(r, t, fig, axes):
176 | # get text width for re-scaling
177 | bb = t.get_window_extent(renderer=r)
178 | text_width_inches = bb.width / fig.dpi
179 | # get axis width in inches
180 | current_fig_width = fig.get_figwidth()
181 | new_fig_width = current_fig_width + text_width_inches
182 |     proportion = new_fig_width / current_fig_width
183 |     # get axis limit
184 |     x_lim = axes.get_xlim()
185 |     axes.set_xlim([x_lim[0], x_lim[1]*proportion])
186 |
187 | """
188 | Draw plot using Matplotlib
189 | """
190 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
191 | # sort the dictionary by decreasing value, into a list of tuples
192 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
193 | # unpacking the list of tuples into two lists
194 | sorted_keys, sorted_values = zip(*sorted_dic_by_value)
195 | #
196 | import matplotlib.pyplot as plt
197 | if true_p_bar != "":
198 | """
199 | Special case to draw in:
200 | - green -> TP: True Positives (object detected and matches ground-truth)
201 | - red -> FP: False Positives (object detected but does not match ground-truth)
202 | - pink -> FN: False Negatives (object not detected but present in the ground-truth)
203 | """
204 | fp_sorted = []
205 | tp_sorted = []
206 | for key in sorted_keys:
207 | fp_sorted.append(dictionary[key] - true_p_bar[key])
208 | tp_sorted.append(true_p_bar[key])
209 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
210 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
211 | # add legend
212 | plt.legend(loc='lower right')
213 | """
214 | Write number on side of bar
215 | """
216 | fig = plt.gcf() # gcf - get current figure
217 | axes = plt.gca()
218 | r = fig.canvas.get_renderer()
219 | for i, val in enumerate(sorted_values):
220 | fp_val = fp_sorted[i]
221 | tp_val = tp_sorted[i]
222 | fp_str_val = " " + str(fp_val)
223 | tp_str_val = fp_str_val + " " + str(tp_val)
224 | # trick to paint multicolor with offset:
225 | # first paint everything and then repaint the first number
226 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
227 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
228 | if i == (len(sorted_values)-1): # largest bar
229 | adjust_axes(r, t, fig, axes)
230 | else:
231 | plt.barh(range(n_classes), sorted_values, color=plot_color)
232 | """
233 | Write number on side of bar
234 | """
235 | fig = plt.gcf() # gcf - get current figure
236 | axes = plt.gca()
237 | r = fig.canvas.get_renderer()
238 | for i, val in enumerate(sorted_values):
239 | str_val = " " + str(val) # add a space before
240 | if val < 1.0:
241 | str_val = " {0:.2f}".format(val)
242 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
243 | # re-set axes to show number inside the figure
244 | if i == (len(sorted_values)-1): # largest bar
245 | adjust_axes(r, t, fig, axes)
246 | # set window title
247 | fig.canvas.set_window_title(window_title)
248 | # write classes in y axis
249 | tick_font_size = 12
250 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
251 | """
252 | Re-scale height accordingly
253 | """
254 | init_height = fig.get_figheight()
255 |     # compute the matrix height in points and inches
256 | dpi = fig.dpi
257 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
258 | height_in = height_pt / dpi
259 | # compute the required figure height
260 | top_margin = 0.15 # in percentage of the figure height
261 | bottom_margin = 0.05 # in percentage of the figure height
262 | figure_height = height_in / (1 - top_margin - bottom_margin)
263 | # set new height
264 | if figure_height > init_height:
265 | fig.set_figheight(figure_height)
266 |
267 | # set plot title
268 | plt.title(plot_title, fontsize=14)
269 | # set axis titles
270 | # plt.xlabel('classes')
271 | plt.xlabel(x_label, fontsize='large')
272 | # adjust size of window
273 | fig.tight_layout()
274 | # save the plot
275 | fig.savefig(output_path)
276 | # show image
277 | if to_show:
278 | plt.show()
279 | # close the plot
280 | plt.close()
281 |
282 |
283 | def eval_mAP(gt_root_dir=None, test_path=None, eval_root_dir=None, use_07_metric=False, thres=0.5):
284 | """
285 | Args:
286 | thres: rotation nms threshold
287 | """
288 | MINOVERLAP = thres # default value (defined in the PASCAL VOC2012 challenge)
289 |
290 | # parser = argparse.ArgumentParser()
291 | # parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true")
292 | # parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true")
293 | # parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true")
294 | # # argparse receiving list of classes to be ignored (e.g., python map.py --ignore person book)
295 | # parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.")
296 | # # argparse receiving list of classes with specific IoU (e.g., python map.py --set-class-iou person 0.7)
297 | # parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.")
298 | # args = parser.parse_args()
299 |
300 | no_animation = False
301 | no_plot = False
302 | quiet = False
303 | ignore = None
304 | set_class_iou = None
305 |
306 | # if there are no classes to ignore then replace None by empty list
307 | if ignore is None:
308 | ignore = []
309 |
310 | specific_iou_flagged = False
311 | if set_class_iou is not None:
312 | specific_iou_flagged = True
313 |
314 | # make sure that the cwd() is the location of the python script (so that every path makes sense)
315 | # os.chdir(os.path.dirname(os.path.abspath(__file__)))
316 |
317 | GT_PATH = os.path.join(gt_root_dir, test_path)
318 | DR_PATH = os.path.join(eval_root_dir, 'detection-results')
319 | # if there are no images then no animation can be shown
320 | IMG_PATH = os.path.join(gt_root_dir, 'images-optional')
321 | if os.path.exists(IMG_PATH):
322 | for dirpath, dirnames, files in os.walk(IMG_PATH):
323 | if not files:
324 | # no image files found
325 | no_animation = True
326 | else:
327 | no_animation = True
328 |
329 | # try to import OpenCV if the user didn't choose the option --no-animation
330 | show_animation = False
331 | if not no_animation:
332 | try:
333 | import cv2
334 | show_animation = True
335 | except ImportError:
336 | print("\"opencv-python\" not found, please install to visualize the results.")
337 | no_animation = True
338 |
339 | # try to import Matplotlib if the user didn't choose the option --no-plot
340 | draw_plot = False
341 | if not no_plot:
342 | try:
343 | import matplotlib.pyplot as plt
344 | draw_plot = True
345 | except ImportError:
346 | print("\"matplotlib\" not found, please install it to get the resulting plots.")
347 | no_plot = True
348 |
349 |
350 | """
351 | Create a ".temp_files/" and "output/" directory
352 | """
353 | TEMP_FILES_PATH = os.path.join(eval_root_dir, ".temp_files")
354 | if not os.path.exists(TEMP_FILES_PATH): # if it doesn't exist already
355 | os.makedirs(TEMP_FILES_PATH)
356 | output_files_path = os.path.join(eval_root_dir, "output")
357 | if os.path.exists(output_files_path): # if it exist already
358 | # reset the output directory
359 | shutil.rmtree(output_files_path)
360 |
361 | os.makedirs(output_files_path)
362 | if draw_plot: # plot some curves
363 | os.makedirs(os.path.join(output_files_path, "classes"))
364 | if show_animation:
365 | os.makedirs(os.path.join(output_files_path, "images", "detections_one_by_one"))
366 |
367 | """
368 | ground-truth
369 | Load each of the ground-truth files into a temporary ".json" file.
370 | Create a list of all the class names present in the ground-truth (gt_classes).
371 | """
372 | # get a list with the ground-truth files
373 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
374 | if len(ground_truth_files_list) == 0:
375 | error("Error: No ground-truth files found!")
376 | ground_truth_files_list.sort()
377 | # dictionary with counter per class
378 | gt_counter_per_class = {} # save the number of per ground-truth
379 | counter_images_per_class = {}
380 |
381 | gt_files = []
382 | for txt_file in ground_truth_files_list:
383 | #print(txt_file)
384 | file_id = txt_file.split(".txt", 1)[0]
385 | file_id = os.path.basename(os.path.normpath(file_id))
386 | # check if there is a correspondent detection-results file
387 | temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
388 | if not os.path.exists(temp_path):
389 | error_msg = "Error. File not found: {}\n".format(temp_path)
390 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
391 | error(error_msg)
392 | lines_list = file_lines_to_list(txt_file)
393 | # create ground-truth dictionary
394 | bounding_boxes = []
395 | is_difficult = False
396 | already_seen_classes = []
397 | for line in lines_list:
398 | try:
399 | if "difficult" in line:
400 | class_name, x1, y1, x2, y2, x3, y3, x4, y4, _difficult = line.split()
401 | is_difficult = True
402 | else:
403 | class_name, x1, y1, x2, y2, x3, y3, x4, y4 = line.split()
404 | except ValueError:
405 | error_msg = "Error: File " + txt_file + " in the wrong format.\n"
406 |                 error_msg = " Expected: <class_name> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4> ['difficult']\n" if False else error_msg + " Expected: <class_name> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4> ['difficult']\n"
407 | error_msg += " Received: " + line
408 |                 error_msg += "\n\nIf you have a class_name with spaces between words you should remove them\n"
409 | error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
410 | error(error_msg)
411 | # check if class is in the ignore list, if yes skip
412 | if class_name in ignore:
413 | continue
414 | bbox = x1 + " " + y1 + " " + x2 + " " + y2 + " " + x3 + " " + y3 + " " + x4 + " " + y4
415 | if is_difficult:
416 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
417 | is_difficult = False
418 | else:
419 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
420 | # count that object
421 | if class_name in gt_counter_per_class:
422 | gt_counter_per_class[class_name] += 1
423 | else:
424 | # if class didn't exist yet
425 | gt_counter_per_class[class_name] = 1
426 |
427 | if class_name not in already_seen_classes:
428 | if class_name in counter_images_per_class:
429 | counter_images_per_class[class_name] += 1
430 | else:
431 | # if class didn't exist yet
432 | counter_images_per_class[class_name] = 1
433 | already_seen_classes.append(class_name)
434 |
435 |
436 | # dump bounding_boxes into a ".json" file
437 | new_temp_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
438 | gt_files.append(new_temp_file)
439 | with open(new_temp_file, 'w') as outfile:
440 | json.dump(bounding_boxes, outfile)
441 |
442 | gt_classes = list(gt_counter_per_class.keys())
443 | # let's sort the classes alphabetically
444 | gt_classes = sorted(gt_classes)
445 | n_classes = len(gt_classes)
446 | #print(gt_classes)
447 | #print(gt_counter_per_class)
448 |
449 | """
450 | Check format of the flag --set-class-iou (if used)
451 | e.g. check if class exists
452 | """
453 | if specific_iou_flagged:
454 | n_args = len( set_class_iou)
455 | error_msg = \
456 | '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]'
457 | if n_args % 2 != 0:
458 | error('Error, missing arguments. Flag usage:' + error_msg)
459 | # [class_1] [IoU_1] [class_2] [IoU_2]
460 | # specific_iou_classes = ['class_1', 'class_2']
461 | specific_iou_classes = set_class_iou[::2] # even
462 | # iou_list = ['IoU_1', 'IoU_2']
463 | iou_list = set_class_iou[1::2] # odd
464 | if len(specific_iou_classes) != len(iou_list):
465 | error('Error, missing arguments. Flag usage:' + error_msg)
466 | for tmp_class in specific_iou_classes:
467 | if tmp_class not in gt_classes:
468 | error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
469 | for num in iou_list:
470 | if not is_float_between_0_and_1(num):
471 | error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg)
472 |
473 | """
474 | detection-results
475 | Load each of the detection-results files into a temporary ".json" file.
476 | """
477 | # get a list with the detection-results files
478 | dr_files_list = glob.glob(DR_PATH + '/*.txt')
479 | dr_files_list.sort()
480 |
481 | for class_index, class_name in enumerate(gt_classes):
482 | bounding_boxes = []
483 | for txt_file in dr_files_list:
484 | #print(txt_file)
485 | # the first time it checks if all the corresponding ground-truth files exist
486 | file_id = txt_file.split(".txt",1)[0]
487 | file_id = os.path.basename(os.path.normpath(file_id))
488 | temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
489 | if class_index == 0:
490 | if not os.path.exists(temp_path):
491 | error_msg = "Error. File not found: {}\n".format(temp_path)
492 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
493 | error(error_msg)
494 | lines = file_lines_to_list(txt_file)
495 | for line in lines:
496 | try:
497 | tmp_class_name, confidence, x1, y1, x2, y2, x3, y3, x4, y4 = line.split()
498 | except ValueError:
499 | error_msg = "Error: File " + txt_file + " in the wrong format.\n"
500 |                     error_msg += " Expected: <class_name> <confidence> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4>\n"
501 | error_msg += " Received: " + line
502 | error(error_msg)
503 | if tmp_class_name == class_name:
504 | #print("match")
505 | bbox = x1 + " " + y1 + " " + x2 + " " + y2 + " " + x3 + " " + y3 + " " + x4 + " " + y4
506 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
507 | #print(bounding_boxes)
508 | # sort detection-results by decreasing confidence
509 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
510 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
511 | json.dump(bounding_boxes, outfile)
512 |
513 | """
514 | Calculate the AP for each class
515 | """
516 | sum_AP = 0.0
517 | ap_dictionary = {}
518 | lamr_dictionary = {}
519 | # open file to store the output
520 | with open(output_files_path + "/output.txt", 'w') as output_file:
521 | output_file.write("# AP and precision/recall per class\n")
522 | count_true_positives = {}
523 | for class_index, class_name in enumerate(gt_classes):
524 | count_true_positives[class_name] = 0
525 | """
526 | Load detection-results of that class
527 | """
528 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
529 | dr_data = json.load(open(dr_file))
530 |
531 | """
532 | Assign detection-results to ground-truth objects
533 | """
534 | nd = len(dr_data) # the number of the detections
535 | tp = [0] * nd # creates an array of zeros of size nd
536 | fp = [0] * nd
537 | # print('evaluate on class: {} '.format(class_name))
538 | # for idx, detection in enumerate(tqdm(dr_data)):
539 |         for _index, detection in enumerate(dr_data):  # _index is the detection rank; idx is reused by the gt loop below
540 | file_id = detection["file_id"]
541 | if show_animation:
542 | # find ground truth image
543 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
544 | #tifCounter = len(glob.glob1(myPath,"*.tif"))
545 | if len(ground_truth_img) == 0:
546 | error("Error. Image not found with id: " + file_id)
547 | elif len(ground_truth_img) > 1:
548 | error("Error. Multiple image with id: " + file_id)
549 | else: # found image
550 | #print(IMG_PATH + "/" + ground_truth_img[0])
551 | # Load image
552 | img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
553 | # load image with draws of multiple detections
554 | img_cumulative_path = output_files_path + "/images/" + ground_truth_img[0]
555 | if os.path.isfile(img_cumulative_path):
556 | img_cumulative = cv2.imread(img_cumulative_path)
557 | else:
558 | img_cumulative = img.copy()
559 | # Add bottom border to image
560 | bottom_border = 60
561 | BLACK = [0, 0, 0]
562 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
563 |             # assign detection-results to ground truth object if any
564 | # open ground-truth with that file_id
565 | gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
566 | ground_truth_data = json.load(open(gt_file))
567 | ovmax = -1
568 | gt_match = -1
569 | # load detected object bounding-box
570 | bb = [ float(x) for x in detection["bbox"].split() ]
571 | for idx, obj in enumerate(ground_truth_data):
572 | # look for a class_name match
573 | if obj["class_name"] == class_name:
574 | bbgt = [ float(x) for x in obj["bbox"].split() ]
575 | ### IoU calculation
576 | iou, inter = skewiou(bbgt, bb)
577 | if inter != 0:
578 | ov = iou
579 | if ov > ovmax:
580 | ovmax = ov
581 | gt_match = obj
582 | # bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
583 | # iw = bi[2] - bi[0] + 1
584 | # ih = bi[3] - bi[1] + 1
585 | # if iw > 0 and ih > 0:
586 | # # compute overlap (IoU) = area of intersection / area of union
587 | # ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
588 | # + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
589 | # ov = iw * ih / ua
590 | # if ov > ovmax:
591 | # ovmax = ov
592 | # gt_match = obj
593 |
594 | # assign detection as true positive/don't care/false positive
595 | if show_animation:
596 | status = "NO MATCH FOUND!" # status is only used in the animation
597 | # set minimum overlap
598 | min_overlap = MINOVERLAP
599 | if specific_iou_flagged:
600 | if class_name in specific_iou_classes:
601 | index = specific_iou_classes.index(class_name)
602 | min_overlap = float(iou_list[index])
603 | if ovmax >= min_overlap:
604 | if "difficult" not in gt_match:
605 | if not bool(gt_match["used"]):
606 | # true positive
607 | tp[_index] = 1
608 | gt_match["used"] = True
609 | count_true_positives[class_name] += 1
610 | # update the ".json" file
611 | with open(gt_file, 'w') as f:
612 | f.write(json.dumps(ground_truth_data))
613 | if show_animation:
614 | status = "MATCH!"
615 | else:
616 | # false positive (multiple detection)
617 | fp[_index] = 1
618 | if show_animation:
619 | status = "REPEATED MATCH!"
620 | else:
621 | # false positive
622 | fp[_index] = 1
623 | if ovmax > 0:
624 | status = "INSUFFICIENT OVERLAP"
625 |
626 | """
627 | Draw image to show animation
628 | """
629 | if show_animation:
630 |                 height, width = img.shape[:2]
631 | # colors (OpenCV works with BGR)
632 | white = (255,255,255)
633 | light_blue = (255,200,100)
634 | green = (0,255,0)
635 | light_red = (30,30,255)
636 | # 1st line
637 | margin = 10
638 | v_pos = int(height - margin - (bottom_border / 2.0))
639 | text = "Image: " + ground_truth_img[0] + " "
640 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
641 | text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
642 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
643 | if ovmax != -1:
644 | color = light_red
645 | if status == "INSUFFICIENT OVERLAP":
646 | text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
647 | else:
648 | text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
649 | color = green
650 | img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
651 | # 2nd line
652 | v_pos += int(bottom_border / 2.0)
653 |                 rank_pos = str(_index+1) # rank position (_index starts at 0)
654 | text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
655 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
656 | color = light_red
657 | if status == "MATCH!":
658 | color = green
659 | text = "Result: " + status + " "
660 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
661 |
662 | font = cv2.FONT_HERSHEY_SIMPLEX
663 | if ovmax > 0: # if there is intersections between the bounding-boxes
664 | bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
665 | cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
666 | cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
667 | cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
668 | bb = [int(i) for i in bb]
669 | cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
670 | cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
671 | cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
672 | # show image
673 | cv2.imshow("Animation", img)
674 | cv2.waitKey(20) # show for 20 ms
675 | # save image to output
676 |                 output_img_path = output_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(_index) + ".jpg"
677 | cv2.imwrite(output_img_path, img)
678 | # save the image with all the objects drawn to it
679 | cv2.imwrite(img_cumulative_path, img_cumulative)
680 |
681 | #print(tp)
682 | # compute precision/recall
683 | cumsum = 0
684 | for idx, val in enumerate(fp):
685 | fp[idx] += cumsum
686 | cumsum += val
687 | cumsum = 0
688 | for idx, val in enumerate(tp):
689 | tp[idx] += cumsum
690 | cumsum += val
691 | #print(tp)
692 | rec = tp[:]
693 | for idx, val in enumerate(tp):
694 | rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
695 | recall = rec[-1]
696 | prec = tp[:]
697 | for idx, val in enumerate(tp):
698 | prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx] + 1e-6)
699 | precision = prec[-1]
700 |
701 | ap, mrec, mprec = voc_ap(rec[:], prec[:],use_07_metric=use_07_metric)
702 | sum_AP += ap
703 | text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
704 | """
705 | Write to output.txt
706 | """
707 | rounded_prec = [ '%.2f' % elem for elem in prec ]
708 | rounded_rec = [ '%.2f' % elem for elem in rec ]
709 | output_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
710 | # if not quiet:
711 | # print(text)
712 | ap_dictionary[class_name] = ap
713 |
714 | n_images = counter_images_per_class[class_name]
715 | lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
716 | lamr_dictionary[class_name] = lamr
717 |
718 | """
719 | Draw plot
720 | """
721 | if draw_plot:
722 | plt.plot(rec, prec, '-o')
723 | # add a new penultimate point to the list (mrec[-2], 0.0)
724 | # since the last line segment (and respective area) do not affect the AP value
725 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
726 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
727 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
728 | # set window title
729 | fig = plt.gcf() # gcf - get current figure
730 | fig.canvas.set_window_title('AP ' + class_name)
731 | # set plot title
732 | plt.title('class: ' + text)
733 | #plt.suptitle('This is a somewhat long figure title', fontsize=16)
734 | # set axis titles
735 | plt.xlabel('Recall')
736 | plt.ylabel('Precision')
737 | # optional - set axes
738 | axes = plt.gca() # gca - get current axes
739 | axes.set_xlim([0.0,1.0])
740 | axes.set_ylim([0.0,1.05]) # .05 to give some extra space
741 | # Alternative option -> wait for button to be pressed
742 | #while not plt.waitforbuttonpress(): pass # wait for key display
743 | # Alternative option -> normal display
744 | #plt.show()
745 | # save the plot
746 | fig.savefig(output_files_path + "/classes/" + class_name + ".png")
747 | plt.cla() # clear axes for next plot
748 |
749 | if show_animation:
750 | cv2.destroyAllWindows()
751 |
752 | output_file.write("\n# mAP of all classes\n")
753 | mAP = sum_AP / n_classes
754 | text = "mAP = {0:.2f}%".format(mAP*100)
755 | output_file.write(text + "\n")
756 | # print(text)
757 |
758 | """
759 | Draw false negatives
760 | """
761 | # pink = (203,192,255)
762 | # for tmp_file in gt_files:
763 | # ground_truth_data = json.load(open(tmp_file))
764 | # #print(ground_truth_data)
765 | # # get name of corresponding image
766 | # start = TEMP_FILES_PATH + '/'
767 | # img_id = tmp_file[tmp_file.find(start)+len(start):tmp_file.rfind('_ground_truth.json')]
768 | # img_cumulative_path = output_files_path + "/images/" + img_id + ".jpg"
769 | # import cv2
770 | # img = cv2.imread(img_cumulative_path)
771 | # if img is None:
772 | # img_path = IMG_PATH + '/' + img_id + ".jpg"
773 | # img = cv2.imread(img_path)
774 | # draw false negatives
775 | # for obj in ground_truth_data:
776 | # if not obj['used']:
777 | # bbgt = [ int(round(float(x))) for x in obj["bbox"].split() ]
778 | # cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),pink,2)
779 | # cv2.imwrite(img_cumulative_path, img)
780 |
781 | # remove the temp_files directory
782 | shutil.rmtree(TEMP_FILES_PATH)
783 |
784 | """
785 | Count total of detection-results
786 | """
787 | # iterate through all the files
788 | det_counter_per_class = {}
789 | for txt_file in dr_files_list:
790 | # get lines to list
791 | lines_list = file_lines_to_list(txt_file)
792 | for line in lines_list:
793 | class_name = line.split()[0]
794 | # check if class is in the ignore list, if yes skip
795 | if class_name in ignore:
796 | continue
797 | # count that object
798 | if class_name in det_counter_per_class:
799 | det_counter_per_class[class_name] += 1
800 | else:
801 | # if class didn't exist yet
802 | det_counter_per_class[class_name] = 1
803 | #print(det_counter_per_class)
804 | dr_classes = list(det_counter_per_class.keys())
805 |
806 |
807 | """
808 | Plot the total number of occurences of each class in the ground-truth
809 | """
810 | if draw_plot:
811 | window_title = "ground-truth-info"
812 | plot_title = "ground-truth\n"
813 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
814 | x_label = "Number of objects per class"
815 | output_path = output_files_path + "/ground-truth-info.png"
816 | to_show = False
817 | plot_color = 'forestgreen'
818 | draw_plot_func(
819 | gt_counter_per_class,
820 | n_classes,
821 | window_title,
822 | plot_title,
823 | x_label,
824 | output_path,
825 | to_show,
826 | plot_color,
827 | '',
828 | )
829 |
830 | """
831 | Write number of ground-truth objects per class to results.txt
832 | """
833 | with open(output_files_path + "/output.txt", 'a') as output_file:
834 | output_file.write("\n# Number of ground-truth objects per class\n")
835 | for class_name in sorted(gt_counter_per_class):
836 | output_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
837 |
838 | """
839 | Finish counting true positives
840 | """
841 | for class_name in dr_classes:
842 | # if class exists in detection-result but not in ground-truth then there are no true positives in that class
843 | if class_name not in gt_classes:
844 | count_true_positives[class_name] = 0
845 | #print(count_true_positives)
846 |
847 | """
848 | Plot the total number of occurences of each class in the "detection-results" folder
849 | """
850 | if draw_plot:
851 | window_title = "detection-results-info"
852 | # Plot title
853 | plot_title = "detection-results\n"
854 | plot_title += "(" + str(len(dr_files_list)) + " files and "
855 | count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
856 | plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
857 | # end Plot title
858 | x_label = "Number of objects per class"
859 | output_path = output_files_path + "/detection-results-info.png"
860 | to_show = False
861 | plot_color = 'forestgreen'
862 | true_p_bar = count_true_positives
863 | try:
864 | draw_plot_func(
865 | det_counter_per_class,
866 | len(det_counter_per_class),
867 | window_title,
868 | plot_title,
869 | x_label,
870 | output_path,
871 | to_show,
872 | plot_color,
873 | true_p_bar
874 | )
875 | except:
876 | pass
877 |
878 | """
879 | Write number of detected objects per class to output.txt
880 | """
881 | with open(output_files_path + "/output.txt", 'a') as output_file:
882 | output_file.write("\n# Number of detected objects per class\n")
883 | for class_name in sorted(dr_classes):
884 | n_det = det_counter_per_class[class_name]
885 | text = class_name + ": " + str(n_det)
886 | text += " (tp:" + str(count_true_positives[class_name]) + ""
887 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
888 | output_file.write(text)
889 |
890 | """
891 | Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
892 | """
893 | if draw_plot:
894 | window_title = "lamr"
895 | plot_title = "log-average miss rate"
896 | x_label = "log-average miss rate"
897 | output_path = output_files_path + "/lamr.png"
898 | to_show = False
899 | plot_color = 'royalblue'
900 | draw_plot_func(
901 | lamr_dictionary,
902 | n_classes,
903 | window_title,
904 | plot_title,
905 | x_label,
906 | output_path,
907 | to_show,
908 | plot_color,
909 | ""
910 | )
911 |
912 | """
913 | Draw mAP plot (Show AP's of all classes in decreasing order)
914 | """
915 | if draw_plot:
916 | window_title = "mAP"
917 | plot_title = "mAP = {0:.2f}%".format(mAP*100)
918 | x_label = "Average Precision"
919 | output_path = output_files_path + "/mAP.png"
920 | to_show = False
921 | plot_color = 'royalblue'
922 | draw_plot_func(
923 | ap_dictionary,
924 | n_classes,
925 | window_title,
926 | plot_title,
927 | x_label,
928 | output_path,
929 | to_show,
930 | plot_color,
931 | ""
932 | )
933 | return mAP, precision, recall
934 |
--------------------------------------------------------------------------------
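For reference, the directory layout and text formats that `eval_mAP()` reads can be inferred from the parsing code above; the call below is a minimal sketch with hypothetical paths.

```python
# Minimal sketch of calling eval_mAP(); the paths are hypothetical.
# Expected layout (inferred from the parsers above):
#   <gt_root_dir>/<test_path>/<image_id>.txt          lines: class x1 y1 x2 y2 x3 y3 x4 y4 [difficult]
#   <eval_root_dir>/detection-results/<image_id>.txt  lines: class confidence x1 y1 x2 y2 x3 y3 x4 y4
from utils.map import eval_mAP

mAP, precision, recall = eval_mAP(gt_root_dir='data/SSDD',
                                  test_path='ground-truth',
                                  eval_root_dir='eval_tmp',
                                  use_07_metric=False,
                                  thres=0.5)
print('mAP = {:.4f}, P = {:.4f}, R = {:.4f}'.format(mAP, precision, recall))
```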
/utils/rotation_nms/.gitignore:
--------------------------------------------------------------------------------
1 | *.cpp
2 | *.so
3 |
--------------------------------------------------------------------------------
/utils/rotation_nms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_nms/__init__.py
--------------------------------------------------------------------------------
/utils/rotation_nms/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_nms/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/rotation_nms/cpu_nms.pyx:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------
2 | # Soft-NMS: Improving Object Detection With One Line of Code
3 | # Copyright (c) University of Maryland, College Park
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Navaneeth Bodla and Bharat Singh
6 | # ----------------------------------------------------------
7 |
8 | import cv2
9 | import numpy as np
10 | cimport numpy as np
11 |
12 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
13 | return a if a >= b else b
14 |
15 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
16 | return a if a <= b else b
17 |
18 | def cpu_soft_nms(
19 | np.ndarray[float, ndim=2] boxes,
20 | float thresh=0.3,
21 | unsigned int method=1,
22 | float sigma=0.5,
23 | float min_score=0.001
24 | ):
25 | cdef unsigned int N = boxes.shape[0]
26 | cdef float iw, ih, box_area
27 | cdef float ua
28 | cdef int pos = 0
29 | cdef float maxscore = 0
30 | cdef int maxpos = 0
31 | cdef float x1, x2, y1, y2, t, s, w, h, xx, yy, tx1, tx2, ty1, ty2, tt, ts, tw, th, txx, tyy, area, weight, ov, inter
32 |
33 | inds = np.arange(N)
34 | for i in range(N):
35 | maxscore = boxes[i, 5]
36 | maxpos = i
37 |
38 | tx1 = boxes[i, 0]
39 | ty1 = boxes[i, 1]
40 | tx2 = boxes[i, 2]
41 | ty2 = boxes[i, 3]
42 | tt = boxes[i, 4]
43 | ts = boxes[i, 5]
44 | ti = inds[i]
45 |
46 | pos = i + 1
47 | # get max box
48 | while pos < N:
49 | if maxscore < boxes[pos, 5]:
50 | maxscore = boxes[pos, 5]
51 | maxpos = pos
52 | pos = pos + 1
53 |
54 | # add max box as a detection
55 | boxes[i, 0] = boxes[maxpos, 0]
56 | boxes[i, 1] = boxes[maxpos, 1]
57 | boxes[i, 2] = boxes[maxpos, 2]
58 | boxes[i, 3] = boxes[maxpos, 3]
59 | boxes[i, 4] = boxes[maxpos, 4]
60 | boxes[i, 5] = boxes[maxpos, 5]
61 | inds[i] = inds[maxpos]
62 |
63 | # swap ith box with position of max box
64 | boxes[maxpos, 0] = tx1
65 | boxes[maxpos, 1] = ty1
66 | boxes[maxpos, 2] = tx2
67 | boxes[maxpos, 3] = ty2
68 | boxes[maxpos, 4] = tt
69 | boxes[maxpos, 5] = ts
70 | inds[maxpos] = ti
71 |
72 | tx1 = boxes[i, 0]
73 | ty1 = boxes[i, 1]
74 | tx2 = boxes[i, 2]
75 | ty2 = boxes[i, 3]
76 | tt = boxes[i, 4]
77 | ts = boxes[i, 5]
78 |
79 | tw = tx2 - tx1
80 | th = ty2 - ty1
81 | txx = tx1 + tw * 0.5
82 | tyy = ty1 + th * 0.5
83 |
84 | pos = i + 1
85 | # NMS iterations, note that N changes if detection boxes fall below threshold
86 | while pos < N:
87 | x1 = boxes[pos, 0]
88 | y1 = boxes[pos, 1]
89 | x2 = boxes[pos, 2]
90 | y2 = boxes[pos, 3]
91 | t = boxes[pos, 4]
92 | s = boxes[pos, 5]
93 |
94 | w = x2 - x1
95 | h = y2 - y1
96 | xx = x1 + w * 0.5
97 | yy = y1 + h * 0.5
98 |
99 | rtn, contours = cv2.rotatedRectangleIntersection(
100 | ((txx, tyy), (tw, th), tt),
101 | ((xx, yy), (w, h), t)
102 | )
103 | if rtn == 1:
104 | inter = np.round(np.abs(cv2.contourArea(contours)))
105 | elif rtn == 2:
106 | inter = min(tw * th, w * h)
107 | else:
108 | inter = 0.0
109 |
110 | if inter > 0.0:
111 | # iou between max box and detection box
112 | ov = inter / (tw * th + w * h - inter)
113 | if method == 1: # linear
114 | if ov > thresh:
115 | weight = 1 - ov
116 | else:
117 | weight = 1
118 | elif method == 2: # gaussian
119 | weight = np.exp(-(ov * ov) / sigma)
120 | else: # original NMS
121 | if ov > thresh:
122 | weight = 0
123 | else:
124 | weight = 1
125 | boxes[pos, 5] = weight * boxes[pos, 5]
126 | # if box score falls below threshold, discard the box by swapping with last box, update N
127 | if boxes[pos, 5] < min_score:
128 | boxes[pos, 0] = boxes[N-1, 0]
129 | boxes[pos, 1] = boxes[N-1, 1]
130 | boxes[pos, 2] = boxes[N-1, 2]
131 | boxes[pos, 3] = boxes[N-1, 3]
132 | boxes[pos, 4] = boxes[N-1, 4]
133 | boxes[pos, 5] = boxes[N-1, 5]
134 | inds[pos] = inds[N - 1]
135 | N = N - 1
136 | pos = pos - 1
137 | pos = pos + 1
138 |
139 | return inds[:N]
140 |
141 |
142 | def cpu_nms(
143 | np.ndarray[np.float32_t, ndim=2] dets,
144 |     float thresh
145 | ):
146 | cdef np.ndarray[np.float32_t, ndim=1] ws = dets[:, 2]
147 | cdef np.ndarray[np.float32_t, ndim=1] hs = dets[:, 3]
148 | cdef np.ndarray[np.float32_t, ndim=1] xx = dets[:, 0]
149 | cdef np.ndarray[np.float32_t, ndim=1] yy = dets[:, 1]
150 | cdef np.ndarray[np.float32_t, ndim=1] tt = dets[:, 4]
151 | cdef np.ndarray[np.float32_t, ndim=1] areas = ws * hs
152 |
153 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 5]
154 | cdef np.ndarray[np.intp_t, ndim=1] order = scores.argsort()[::-1]
155 |
156 | cdef int ndets = dets.shape[0]
157 |     cdef np.ndarray[np.int_t, ndim=1] suppressed = np.zeros((ndets), dtype=np.int_)
158 |
159 | cdef int _i, _j, i, j, rtn
160 | cdef np.float32_t inter, ovr
161 |
162 | keep = []
163 | for _i in range(ndets):
164 | i = order[_i]
165 | if suppressed[i] == 1:
166 | continue
167 | keep.append(i)
168 | for _j in range(_i + 1, ndets):
169 | j = order[_j]
170 | if suppressed[j] == 1:
171 | continue
172 | rtn, contours = cv2.rotatedRectangleIntersection(
173 | ((xx[i], yy[i]), (ws[i], hs[i]), tt[i]),
174 | ((xx[j], yy[j]), (ws[j], hs[j]), tt[j])
175 | )
176 | if rtn == 1:
177 | inter = np.round(np.abs(cv2.contourArea(contours)))
178 | elif rtn == 2:
179 | inter = min(areas[i], areas[j])
180 | else:
181 | inter = 0.0
182 | ovr = inter / (areas[i] + areas[j] - inter + 1e-6)
183 | if ovr >= thresh:
184 | suppressed[j] = 1
185 |
186 | return keep
187 |
--------------------------------------------------------------------------------
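Once the extension is compiled (for instance with `python setup.py build_ext --inplace`), `cpu_nms` takes an `(N, 6)` float32 array of `[xc, yc, w, h, angle, score]` rows; the angle column is handed straight to `cv2.rotatedRectangleIntersection`, which interprets it in degrees. A minimal sketch with a hypothetical input:

```python
# Minimal sketch of the compiled rotated NMS; the input array is hypothetical.
import numpy as np
from utils.rotation_nms.cpu_nms import cpu_nms  # available after build_ext --inplace

dets = np.array([
    [100., 100., 60., 20., 30., 0.90],
    [102., 101., 58., 22., 28., 0.85],   # heavily overlaps the first box
    [300., 200., 50., 15., 80., 0.70],
], dtype=np.float32)

keep = cpu_nms(dets, 0.1)  # indices kept at a rotated-IoU threshold of 0.1
print(keep)                # expected: the high-score box and the distant box
```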
/utils/rotation_overlaps/.gitignore:
--------------------------------------------------------------------------------
1 | *.cpp
2 | *.so
3 |
--------------------------------------------------------------------------------
/utils/rotation_overlaps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_overlaps/__init__.py
--------------------------------------------------------------------------------
/utils/rotation_overlaps/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HsLOL/Rotation-RetinaNet-PyTorch/386114f4892356e44857e924adf6faf963625c9d/utils/rotation_overlaps/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/rotation_overlaps/rbox_overlaps.pyx:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | cimport cython
4 | cimport numpy as np
5 |
6 | ctypedef np.float32_t DTYPE_t
7 |
8 |
9 | def rbox_overlaps(
10 | np.ndarray[DTYPE_t, ndim=2] boxes,
11 | np.ndarray[DTYPE_t, ndim=2] query_boxes,
12 | np.ndarray[DTYPE_t, ndim=2] indicator=None,
13 |         float thresh=1e-4):
14 | """
15 | Parameters:
16 | boxes: (N, 5) ndarray of float: [xc, yc, w, h, angle(radian)]
17 | query_boxes: (K, 5) ndarray of float: [xc, yc, w, h, angle(radian)]
18 |
19 | Returns:
20 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes
21 | """
22 | cdef unsigned int N = boxes.shape[0]
23 | cdef unsigned int K = query_boxes.shape[0]
24 | cdef DTYPE_t box_area
25 | cdef DTYPE_t ua, ia
26 | cdef unsigned int k, n, rtn
27 | cdef np.ndarray[DTYPE_t, ndim=3] contours
28 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=np.float32)
29 |
30 | cdef np.ndarray[DTYPE_t, ndim=1] a_tt = boxes[:, 4] * 180 / np.pi
31 | cdef np.ndarray[DTYPE_t, ndim=1] a_ws = boxes[:, 2]
32 | cdef np.ndarray[DTYPE_t, ndim=1] a_hs = boxes[:, 3]
33 | cdef np.ndarray[DTYPE_t, ndim=1] a_xx = boxes[:, 0]
34 | cdef np.ndarray[DTYPE_t, ndim=1] a_yy = boxes[:, 1]
35 |
36 | cdef np.ndarray[DTYPE_t, ndim=1] b_tt = query_boxes[:, 4] * 180 / np.pi
37 | cdef np.ndarray[DTYPE_t, ndim=1] b_ws = query_boxes[:, 2]
38 | cdef np.ndarray[DTYPE_t, ndim=1] b_hs = query_boxes[:, 3]
39 | cdef np.ndarray[DTYPE_t, ndim=1] b_xx = query_boxes[:, 0]
40 | cdef np.ndarray[DTYPE_t, ndim=1] b_yy = query_boxes[:, 1]
41 |
42 | for k in range(K):
43 | box_area = b_ws[k] * b_hs[k]
44 | for n in range(N):
45 | if indicator is not None and indicator[n, k] < thresh:
46 | continue
47 | ua = a_ws[n] * a_hs[n] + box_area
48 | rtn, contours = cv2.rotatedRectangleIntersection(
49 | ((a_xx[n], a_yy[n]), (a_ws[n], a_hs[n]), a_tt[n]),
50 | ((b_xx[k], b_yy[k]), (b_ws[k], b_hs[k]), b_tt[k])
51 | )
52 | if rtn == 1:
53 | ia = np.round(np.abs(cv2.contourArea(contours)))
54 | overlaps[n, k] = ia / (ua - ia)
55 | elif rtn == 2:
56 | ia = np.minimum(ua - box_area, box_area)
57 | overlaps[n, k] = ia / (ua - ia)
58 | return overlaps
59 |
60 |
--------------------------------------------------------------------------------
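The Cython kernel above returns an `(N, K)` IoU matrix for two sets of rotated boxes given as `[xc, yc, w, h, angle(radian)]`; the angle is converted to degrees internally before `cv2.rotatedRectangleIntersection`. A minimal sketch with hypothetical boxes (both pairs chosen to intersect):

```python
# Minimal sketch of the compiled rotated-overlap kernel; the boxes are hypothetical.
import numpy as np
from utils.rotation_overlaps.rbox_overlaps import rbox_overlaps  # after build_ext --inplace

anchors = np.array([[100., 100., 60., 20., 0.0]], dtype=np.float32)
gts = np.array([[100., 100., 60., 20., np.pi / 6],
                [130., 105., 40., 10., 0.0]], dtype=np.float32)

ious = rbox_overlaps(anchors, gts)  # shape (1, 2)
print(ious)                         # large IoU with the first gt, small with the second
```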
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 | import torchvision.transforms as transforms
5 | import random
6 | import torch.nn as nn
7 |
8 |
9 | def set_random_seed(seed, deterministic=False):
10 | """Set random seed.
11 | Args:
12 |         seed (int): the seed to set.
13 |         deterministic (bool): whether to set torch.backends.cudnn.deterministic. Default is False.
14 | """
15 | print(f'[Info]: Set random seed to {seed}, deterministic: {deterministic}.')
16 | random.seed(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 |
21 | if deterministic:
22 | torch.backends.cudnn.deterministic = True
23 | torch.backends.cudnn.benchmark = False
24 |
25 |
26 | def xavier_init(module, gain=1, bias=0, distribution='normal'):
27 | assert distribution in ['uniform', 'normal']
28 | if hasattr(module, 'weight') and module.weight is not None:
29 | if distribution == 'uniform':
30 | nn.init.xavier_uniform_(module.weight, gain=gain)
31 | else:
32 | nn.init.xavier_normal_(module.weight, gain=gain)
33 |
34 | if hasattr(module, 'bias') and module.bias is not None:
35 | nn.init.constant_(module.bias, bias)
36 |
37 |
38 | def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0, distribution='normal'):
39 | assert distribution in ['uniform', 'normal']
40 | if hasattr(module, 'weight') and module.weight is not None:
41 | if distribution == 'uniform':
42 | nn.init.kaiming_uniform_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
43 | else:
44 | nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
45 | if hasattr(module, 'bias') and module.bias is not None:
46 | nn.init.constant_(module.bias, bias)
47 |
48 |
49 | def constant_init(module, val, bias=0):
50 | if hasattr(module, 'weight') and module.weight is not None:
51 | nn.init.constant_(module.weight, val)
52 | if hasattr(module, 'bias') and module.bias is not None:
53 | nn.init.constant_(module.bias, bias)
54 |
55 |
56 | def normal_init(module, mean=0, std=1, bias=0):
57 | if hasattr(module, 'weight') and module.weight is not None:
58 | nn.init.normal_(module.weight, mean, std)
59 | if hasattr(module, 'bias') and module.bias is not None:
60 | nn.init.constant_(module.bias, bias)
61 |
62 |
63 | def pretty_print(num_params, units=None, precision=2):
64 | if units is None:
65 | if num_params // 10**6 > 0:
66 | print(f'[Info]: Model Params = {str(round(num_params / 10**6, precision))}' + ' M')
67 | elif num_params // 10**3:
68 | print(f'[Info]: Model Params = {str(round(num_params / 10**3, precision))}' + ' k')
69 | else:
70 | print(f'[Info]: Model Params = {str(num_params)}')
71 |
72 |
73 | def count_param(model, units=None, precision=2):
74 | """Count Params"""
75 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
76 |     pretty_print(num_params, units=units, precision=precision)
77 |
78 |
79 | def show_args(args):
80 | print('=============== Show Args ===============')
81 | for k in list(vars(args).keys()):
82 | print('%s: %s' % (k, vars(args)[k]))
83 |
84 |
85 | def clip_boxes(boxes, ims):
86 | _, _, h, w = ims.shape
87 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
88 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
89 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=w)
90 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=h)
91 | return boxes
92 |
93 |
94 | # (num_boxes, 5) xyxya
95 | def min_area_square(rboxes):
96 | w = rboxes[:, 2] - rboxes[:, 0]
97 | h = rboxes[:, 3] - rboxes[:, 1]
98 | ctr_x = rboxes[:, 0] + w * 0.5
99 | ctr_y = rboxes[:, 1] + h * 0.5
100 | s = torch.max(w, h)
101 | return torch.stack((
102 | ctr_x - s * 0.5, ctr_y - s * 0.5,
103 | ctr_x + s * 0.5, ctr_y + s * 0.5),
104 | dim=1
105 | )
106 |
107 |
108 | def rbox_overlaps(boxes, query_boxes, indicator=None, thresh=1e-1):
109 |     # pure-Python version; rewritten in Cython in utils/rotation_overlaps/rbox_overlaps.pyx
110 | N = boxes.shape[0]
111 | K = query_boxes.shape[0]
112 |
113 | a_tt = boxes[:, 4]
114 | a_ws = boxes[:, 2] - boxes[:, 0]
115 | a_hs = boxes[:, 3] - boxes[:, 1]
116 | a_xx = boxes[:, 0] + a_ws * 0.5
117 | a_yy = boxes[:, 1] + a_hs * 0.5
118 |
119 | b_tt = query_boxes[:, 4]
120 | b_ws = query_boxes[:, 2] - query_boxes[:, 0]
121 | b_hs = query_boxes[:, 3] - query_boxes[:, 1]
122 | b_xx = query_boxes[:, 0] + b_ws * 0.5
123 | b_yy = query_boxes[:, 1] + b_hs * 0.5
124 |
125 | overlaps = np.zeros((N, K), dtype=np.float32)
126 | for k in range(K):
127 | box_area = b_ws[k] * b_hs[k]
128 | for n in range(N):
129 | if indicator is not None and indicator[n, k] < thresh:
130 | continue
131 | ua = a_ws[n] * a_hs[n] + box_area
132 | rtn, contours = cv2.rotatedRectangleIntersection(
133 | ((a_xx[n], a_yy[n]), (a_ws[n], a_hs[n]), a_tt[n]),
134 | ((b_xx[k], b_yy[k]), (b_ws[k], b_hs[k]), b_tt[k])
135 | )
136 | if rtn == 1:
137 | ia = cv2.contourArea(contours)
138 | overlaps[n, k] = ia / (ua - ia)
139 | elif rtn == 2:
140 | ia = np.minimum(ua - box_area, box_area)
141 | overlaps[n, k] = ia / (ua - ia)
142 | return overlaps
143 |
144 |
145 | def bbox_overlaps(boxes, query_boxes):
146 | """Calculate the horizontal overlaps
147 |
148 | Args:
149 | boxes: [xc, yc, w, h, angle]
150 | query_boxes: [xc, yc, w, h, pi/2]
151 | """
152 | if not isinstance(boxes, float): # apex
153 | boxes = boxes.float()
154 |
155 | # convert the [xc, yc, w, h, angle] to [x1, y1, x2, y2, angle]
156 | query_boxes[:, 0] = query_boxes[:, 0] - query_boxes[:, 2] / 2
157 | query_boxes[:, 1] = query_boxes[:, 1] - query_boxes[:, 3] / 2
158 | query_boxes[:, 2] = query_boxes[:, 0] + query_boxes[:, 2]
159 | query_boxes[:, 3] = query_boxes[:, 1] + query_boxes[:, 3]
160 |
161 | boxes[:, 0] = boxes[:, 0] - boxes[:, 2] / 2
162 | boxes[:, 1] = boxes[:, 1] - boxes[:, 3] / 2
163 | boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
164 | boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
165 |
166 | area = (query_boxes[:, 2] - query_boxes[:, 0]) * \
167 | (query_boxes[:, 3] - query_boxes[:, 1])
168 | iw = torch.min(torch.unsqueeze(boxes[:, 2], dim=1), query_boxes[:, 2]) - \
169 | torch.max(torch.unsqueeze(boxes[:, 0], 1), query_boxes[:, 0])
170 | ih = torch.min(torch.unsqueeze(boxes[:, 3], dim=1), query_boxes[:, 3]) - \
171 | torch.max(torch.unsqueeze(boxes[:, 1], 1), query_boxes[:, 1])
172 | iw = torch.clamp(iw, min=0)
173 | ih = torch.clamp(ih, min=0)
174 | ua = torch.unsqueeze((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), dim=1) + area - iw * ih
175 | ua = torch.clamp(ua, min=1e-8)
176 | intersection = iw * ih
177 | return intersection / ua
178 |
179 |
180 | def rescale(im, target_size, max_size, keep_ratio, multiple=32):
181 | im_shape = im.shape
182 | im_size_min = np.min(im_shape[0:2])
183 | im_size_max = np.max(im_shape[0:2])
184 | if keep_ratio:
185 | # scale method 1:
186 | # scale the shorter side to target size by the constraint of the max size
187 | im_scale = float(target_size) / float(im_size_min)
188 | if np.round(im_scale * im_size_max) > max_size:
189 | im_scale = float(max_size) / float(im_size_max)
190 | im_scale_x = np.floor(im.shape[1] * im_scale / multiple) * multiple / im.shape[1]
191 | im_scale_y = np.floor(im.shape[0] * im_scale / multiple) * multiple / im.shape[0]
192 | im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR)
193 | im_scale = np.array([im_scale_x, im_scale_y, im_scale_x, im_scale_y])
194 |
195 | # scale method 2:
196 | # scale the longer side to target size
197 | # im_scale = float(target_size) / float(im_size_max)
198 | # im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
199 | # im_scale = np.array([im_scale, im_scale, im_scale, im_scale])
200 |
201 | else:
202 | target_size = int(np.floor(float(target_size) / multiple) * multiple)
203 | im_scale_x = float(target_size) / float(im_shape[1])
204 | im_scale_y = float(target_size) / float(im_shape[0])
205 | im = cv2.resize(im, (target_size, target_size), interpolation=cv2.INTER_LINEAR)
206 | im_scale = np.array([im_scale_x, im_scale_y, im_scale_x, im_scale_y])
207 | return im, im_scale
208 |
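# Worked example (hypothetical numbers) for scale method 1 with keep_ratio=True:
# a 600x800 (HxW) image, target_size=512, max_size=2000, multiple=32 gives
#   im_scale   = 512 / 600 ~ 0.8533   (800 * 0.8533 ~ 683 <= 2000, so the cap is not hit)
#   im_scale_x = floor(800 * 0.8533 / 32) * 32 / 800 = 672 / 800 = 0.84
#   im_scale_y = floor(600 * 0.8533 / 32) * 32 / 600 = 512 / 600 ~ 0.8533
# so the resized image is 512 x 672 and both sides are multiples of 32:
#
#   im, im_scale = rescale(im, target_size=512, max_size=2000, keep_ratio=True)
#   # im.shape -> (512, 672, 3); im_scale -> array([0.84, 0.8533, 0.84, 0.8533])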
209 |
210 | class Rescale(object):
211 | def __init__(self, target_size, keep_ratio):
212 | self.target_size = target_size
213 | self.keep_ratio = keep_ratio
214 | self.max_size = 2000 # for scale method 1
215 |
216 | def __call__(self, image):
217 | im, im_scale = rescale(image, target_size=self.target_size, max_size=self.max_size,
218 | keep_ratio=self.keep_ratio)
219 | return im, im_scale
220 |
221 |
222 | class Normalize(object):
223 | def __init__(self):
224 | self._transform = transforms.Compose([
225 | transforms.ToTensor(),
226 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])  # ImageNet mean / std
227 |
228 | def __call__(self, im):
229 | im = self._transform(im)
230 | return im
231 |
232 |
233 | class Reshape(object):
234 | def __init__(self, unsqueeze=True):
235 | self._unsqueeze = unsqueeze
236 | return
237 |
238 | def __call__(self, ims):
239 | if not torch.is_tensor(ims):
240 | ims = torch.from_numpy(ims.transpose((2, 0, 1)))
241 | if self._unsqueeze:
242 | ims = ims.unsqueeze(0)
243 | return ims
244 |
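# Preprocessing sketch (hypothetical image path), chaining the three transforms
# above roughly the way a single-image inference script might:
#
#   import cv2
#   im = cv2.imread('demo.jpg')                           # HWC uint8 image
#   im, im_scale = Rescale(target_size=512, keep_ratio=True)(im)
#   im = Normalize()(im)                                  # ToTensor + ImageNet normalization -> CHW float tensor
#   im = Reshape(unsqueeze=True)(im)                      # add a batch dimension: (1, 3, H, W)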
245 |
--------------------------------------------------------------------------------
/warmup.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 |
5 | class WarmupLR(_LRScheduler):
6 | def __init__(self, scheduler, init_lr=1e-3, num_warmup=1, warmup_strategy='linear'):
7 | if warmup_strategy not in ['linear', 'cos', 'constant']:
8 | raise ValueError(
9 | "Expect warmup_strategy to be one of ['linear', 'cos', 'constant'] but got {}".format(warmup_strategy))
10 | self._scheduler = scheduler
11 | self._init_lr = init_lr
12 | self._num_warmup = num_warmup
13 | self._step_count = 0
14 |
15 | # Define the strategy to warm up learning rate
16 | self._warmup_strategy = warmup_strategy
17 | if warmup_strategy == 'cos':
18 | self._warmup_func = self._warmup_cos
19 | elif warmup_strategy == 'linear':
20 | self._warmup_func = self._warmup_linear
21 | else:
22 | self._warmup_func = self._warmup_const
23 |
24 | # save the initial learning rate of each param group;
25 | # only useful when param groups have different learning rates
26 | self._format_param()
27 |
28 | def __getattr__(self, name):
29 | return getattr(self._scheduler, name)
30 |
31 | def state_dict(self):
32 | """Returns the state of the scheduler as a :class:`dict`.
33 | It contains an entry for every variable in self.__dict__ which
34 | is not the optimizer.
35 | """
36 | wrapper_state_dict = {key: value for key, value in self.__dict__.items() if
37 | (key != 'optimizer' and key != '_scheduler')}
38 | wrapped_state_dict = {key: value for key, value in self._scheduler.__dict__.items() if key != 'optimizer'}
39 | return {'wrapped': wrapped_state_dict, 'wrapper': wrapper_state_dict}
40 |
41 | def load_state_dict(self, state_dict):
42 | """Loads the schedulers state.
43 | Arguments:
44 | state_dict (dict): scheduler state. Should be an object returned
45 | from a call to :meth:`state_dict`.
46 | """
47 | self.__dict__.update(state_dict['wrapper'])
48 | self._scheduler.__dict__.update(state_dict['wrapped'])
49 |
50 | def _format_param(self):
51 | # the learning rate of each param group will increase during warmup
52 | # from min(init_lr, lr) up to the group's original lr
53 | for group in self._scheduler.optimizer.param_groups:
54 | group['warmup_max_lr'] = group['lr']
55 | group['warmup_initial_lr'] = min(self._init_lr, group['lr'])
56 |
57 | def _warmup_cos(self, start, end, pct):
58 | """cosine annealing function:
59 | current = end + 0.5 * (start - end) * (1 + cos(pct * pi)), where pct = t_current / t_total. """
60 | cos_out = math.cos(math.pi * pct) + 1
61 | return end + (start - end) / 2.0 * cos_out
62 |
63 | def _warmup_const(self, start, end, pct):
64 | return start if pct < 0.9999 else end
65 |
66 | def _warmup_linear(self, start, end, pct):
67 | return (end - start) * pct + start
68 |
69 | def get_lr(self):
70 | lrs = []
71 | step_num = self._step_count
72 | # warm up learning rate
73 | if step_num <= self._num_warmup:
74 | for group in self._scheduler.optimizer.param_groups:
75 | computed_lr = self._warmup_func(group['warmup_initial_lr'],
76 | group['warmup_max_lr'],
77 | step_num / self._num_warmup)
78 | lrs.append(computed_lr)
79 | else:
80 | lrs = self._scheduler.get_lr()
81 | return lrs
82 |
83 | def step(self, *args):
84 | if self._step_count <= self._num_warmup:
85 | values = self.get_lr()
86 | for param_group, lr in zip(self._scheduler.optimizer.param_groups, values):
87 | param_group['lr'] = lr
88 | self._step_count += 1
89 | else:
90 | # method 1:
91 | # self._scheduler.step(epoch=self._step_count)
92 | # self._step_count += 1
93 |
94 | # method 2:
95 | self._scheduler._step_count = self._step_count + 1
96 | self._scheduler.last_epoch = self._step_count
97 | self._scheduler.step()
98 | self._step_count += 1
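
# Usage sketch (hypothetical model/optimizer), wrapping an existing PyTorch
# scheduler so that the warmup steps ramp the learning rate up before handing
# control back to the wrapped scheduler:
#
#   import torch
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
#   base = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)
#   scheduler = WarmupLR(base, init_lr=1e-4, num_warmup=5, warmup_strategy='cos')
#   for epoch in range(80):
#       # ... train for one epoch ...
#       scheduler.step()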
--------------------------------------------------------------------------------