├── README.md
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   ├── dior.py
│   │   │   ├── dota.py
│   │   │   ├── dota_coco.py
│   │   │   ├── dota_ms.py
│   │   │   ├── dota_qbox.py
│   │   │   ├── dotav15.py
│   │   │   ├── dotav2.py
│   │   │   ├── hrsc.py
│   │   │   ├── hrsid.py
│   │   │   ├── rsdd.py
│   │   │   ├── srsdd.py
│   │   │   └── ssdd.py
│   │   ├── default_runtime.py
│   │   └── schedules
│   │       ├── schedule_1x.py
│   │       ├── schedule_3x.py
│   │       ├── schedule_40e.py
│   │       └── schedule_6x.py
│   └── rotated_fcos
│       ├── README.md
│       ├── metafile.yml
│       ├── rotated-fcos-hbox-le90_r50_fpn_1x_dota.py
│       ├── rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py
│       ├── rotated-fcos-le90_r50_fpn_1x_dota.py
│       ├── rotated-fcos-le90_r50_fpn_kld_1x_dota.py
│       └── rotated-fcos-le90_r50_fpn_rr-6x_hrsc.py
├── data.py
├── engine.py
├── main_rdet-sam_dota.py
├── main_sam_dota.py
├── transforms.py
├── utils.py
└── visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 | # SAM-RBox
2 | This is an implementation of [SAM (Segment Anything Model)](https://github.com/facebookresearch/segment-anything) for generating rotated bounding boxes with [MMRotate](https://github.com/open-mmlab/mmrotate), which is a comparison method of [H2RBox-v2: Boosting HBox-supervised Oriented Object Detection via Symmetric Learning](https://arxiv.org/abs/2304.04403).
3 |
4 | **NOTE:** This project has been merged into OpenMMLab's new repo [**_PlayGround_**](https://github.com/open-mmlab/playground). For more details, please refer to [this document](https://github.com/open-mmlab/playground/blob/main/mmrotate_sam/README.md).
5 |
6 |
7 |
8 |
9 |
10 | Recently, [SAM](https://arxiv.org/abs/2304.02643) has demonstrated strong zero-shot capabilities, thanks to being trained on the largest segmentation dataset to date. We therefore use a trained horizontal FCOS detector to feed HBoxes into SAM as prompts, so that the corresponding masks can be generated zero-shot, and the rotated RBoxes are then obtained by applying the minimum circumscribed rectangle operation to the predicted masks (see the sketch below). Thanks to this powerful zero-shot capability, SAM-RBox based on ViT-B achieves 63.94% mAP. However, it is also limited by the time-consuming post-processing, running at only 1.7 FPS during inference.
11 |
12 |
13 | 
14 | 
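The core of this post-processing is simply a box-prompted SAM prediction followed by `cv2.minAreaRect`. Below is a minimal sketch, assuming the official `segment_anything` API and a ViT-B checkpoint; the function name and variables are illustrative, see `main_sam_dota.py` / `main_rdet-sam_dota.py` for the actual implementation:

```python
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Load SAM (ViT-B) and wrap it in the official predictor.
sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
predictor = SamPredictor(sam)

def hbox_to_rbox(image_rgb: np.ndarray, hbox_xyxy: np.ndarray):
    """Prompt SAM with one HBox and fit a minimum-area rotated rect."""
    predictor.set_image(image_rgb)
    masks, scores, _ = predictor.predict(
        box=hbox_xyxy,  # (4,) array in x1, y1, x2, y2 format
        multimask_output=False)
    ys, xs = np.nonzero(masks[0])
    points = np.stack([xs, ys], axis=1).astype(np.float32)
    # ((cx, cy), (w, h), angle): the minimum circumscribed rectangle.
    return cv2.minAreaRect(points)
```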
15 |
16 | ## Prepare Env
17 |
18 | The code is based on MMRotate 1.x and the official SAM API.
19 |
20 | Here are the installation commands for the recommended environment.
21 | ```bash
22 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
23 |
24 | pip install openmim
25 | mim install mmengine 'mmcv>=2.0.0rc0' 'mmrotate>=1.0.0rc0'
26 |
27 | pip install git+https://github.com/facebookresearch/segment-anything.git
28 | pip install opencv-python pycocotools matplotlib onnxruntime onnx
29 | ```
30 |
31 | ## Note
32 | 1. Prepare the DOTA dataset according to the MMRotate documentation.
33 | 2. Download the detector weights from the MMRotate model zoo.
34 | 3. `python main_sam_dota.py` prompts SAM with HBoxes obtained from an annotation file (e.g. DOTA trainval).
35 | 4. `python main_rdet-sam_dota.py` prompts SAM with HBoxes predicted by a well-trained detector, for non-annotated data (e.g. DOTA test).
36 | 5. Most configuration, including the pipeline (i.e. transforms), dataset, dataloader, evaluator, and visualizer, is set in `data.py`.
37 | 6. You can change the detector config and the corresponding weight path in `main_rdet-sam_dota.py` to any detector that can be built with MMRotate, as sketched below.
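A minimal sketch of swapping in another detector (hedged: `init_detector` comes from mmdet 3.x, which MMRotate 1.x builds on; the config and checkpoint names below are just the Rotated FCOS example from this repo):

```python
from mmdet.apis import init_detector
from mmrotate.utils import register_all_modules

# Register MMRotate's models/transforms so mmdet's APIs can build them.
register_all_modules(init_default_scope=True)

det_config = 'configs/rotated_fcos/rotated-fcos-le90_r50_fpn_1x_dota.py'
det_weight = 'rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth'
detector = init_detector(det_config, det_weight, device='cuda:0')
```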
38 |
39 | ## Citation
40 | ```
41 | @article{yu2023h2rboxv2,
42 | title={H2RBox-v2: Boosting HBox-supervised Oriented Object Detection via Symmetric Learning},
43 | author={Yu, Yi and Yang, Xue and Li, Qingyun and Zhou, Yue and Zhang, Gefan and Yan, Junchi and Da, Feipeng},
44 | journal={arXiv preprint arXiv:2304.04403},
45 | year={2023}
46 | }
47 |
48 | @inproceedings{yang2023h2rbox,
49 | title={H2RBox: Horizontal Box Annotation is All You Need for Oriented Object Detection},
50 | author={Yang, Xue and Zhang, Gefan and Li, Wentong and Wang, Xuehui and Zhou, Yue and Yan, Junchi},
51 | booktitle={International Conference on Learning Representations},
52 | year={2023}
53 | }
54 |
55 | @article{kirillov2023segany,
56 | title={Segment Anything},
57 | author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
58 | journal={arXiv:2304.02643},
59 | year={2023}
60 | }
61 | ```
62 |
63 | ### Other awesome SAM projects:
64 | - [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
65 | - [Zero-Shot Anomaly Detection](https://github.com/caoyunkang/GroundedSAM-zero-shot-anomaly-detection)
66 | - [EditAnything: ControlNet + StableDiffusion based on the SAM segmentation mask](https://github.com/sail-sg/EditAnything)
67 | - [IEA: Image Editing Anything](https://github.com/feizc/IEA)
68 | - [sam-with-mmdet](https://github.com/liuyanyi/sam-with-mmdet) (mmdet 3.0.0, provides RTMDet)
69 | - [Prompt-Segment-Anything](https://github.com/RockeyCoss/Prompt-Segment-Anything) (mmdet 3.0.0, H-DETR, DINO, Focal backbone)
70 | ......
71 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dior.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DIORDataset'
3 | data_root = 'data/DIOR/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
20 | # avoid bboxes being resized
21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 | dict(
24 | type='mmdet.PackDetInputs',
25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 | 'scale_factor'))
27 | ]
28 | test_pipeline = [
29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 | batch_size=2,
38 | num_workers=2,
39 | persistent_workers=True,
40 | sampler=dict(type='DefaultSampler', shuffle=True),
41 | batch_sampler=None,
42 | dataset=dict(
43 | type='ConcatDataset',
44 | ignore_keys=['DATASET_TYPE'],
45 | datasets=[
46 | dict(
47 | type=dataset_type,
48 | data_root=data_root,
49 | ann_file='ImageSets/Main/train.txt',
50 | data_prefix=dict(img_path='JPEGImages-trainval'),
51 | filter_cfg=dict(filter_empty_gt=True),
52 | pipeline=train_pipeline),
53 | dict(
54 | type=dataset_type,
55 | data_root=data_root,
56 | ann_file='ImageSets/Main/val.txt',
57 | data_prefix=dict(img_path='JPEGImages-trainval'),
58 | filter_cfg=dict(filter_empty_gt=True),
59 | pipeline=train_pipeline,
60 | backend_args=backend_args)
61 | ]))
62 | val_dataloader = dict(
63 | batch_size=1,
64 | num_workers=2,
65 | persistent_workers=True,
66 | drop_last=False,
67 | sampler=dict(type='DefaultSampler', shuffle=False),
68 | dataset=dict(
69 | type=dataset_type,
70 | data_root=data_root,
71 | ann_file='ImageSets/Main/test.txt',
72 | data_prefix=dict(img_path='JPEGImages-test'),
73 | test_mode=True,
74 | pipeline=val_pipeline,
75 | backend_args=backend_args))
76 | test_dataloader = val_dataloader
77 |
78 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
79 | test_evaluator = val_evaluator
80 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dota.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DOTADataset'
3 | data_root = 'data/split_ss_dota/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
20 | # avoid bboxes being resized
21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 | dict(
24 | type='mmdet.PackDetInputs',
25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 | 'scale_factor'))
27 | ]
28 | test_pipeline = [
29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 | batch_size=2,
38 | num_workers=2,
39 | persistent_workers=True,
40 | sampler=dict(type='DefaultSampler', shuffle=True),
41 | batch_sampler=None,
42 | dataset=dict(
43 | type=dataset_type,
44 | data_root=data_root,
45 | ann_file='trainval/annfiles/',
46 | data_prefix=dict(img_path='trainval/images/'),
47 | filter_cfg=dict(filter_empty_gt=True),
48 | pipeline=train_pipeline))
49 | val_dataloader = dict(
50 | batch_size=1,
51 | num_workers=2,
52 | persistent_workers=True,
53 | drop_last=False,
54 | sampler=dict(type='DefaultSampler', shuffle=False),
55 | dataset=dict(
56 | type=dataset_type,
57 | data_root=data_root,
58 | ann_file='trainval/annfiles/',
59 | data_prefix=dict(img_path='trainval/images/'),
60 | test_mode=True,
61 | pipeline=val_pipeline))
62 | test_dataloader = val_dataloader
63 |
64 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
65 | test_evaluator = val_evaluator
66 |
67 | # inference on test dataset and format the output results
68 | # for submission. Note: the test set has no annotation.
69 | # test_dataloader = dict(
70 | # batch_size=1,
71 | # num_workers=2,
72 | # persistent_workers=True,
73 | # drop_last=False,
74 | # sampler=dict(type='DefaultSampler', shuffle=False),
75 | # dataset=dict(
76 | # type=dataset_type,
77 | # data_root=data_root,
78 | # data_prefix=dict(img_path='test/images/'),
79 | # test_mode=True,
80 | # pipeline=test_pipeline))
81 | # test_evaluator = dict(
82 | # type='DOTAMetric',
83 | # format_only=True,
84 | # merge_patches=True,
85 | # outfile_prefix='./work_dirs/dota/Task1')
86 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dota_coco.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmdet.CocoDataset'
3 | data_root = 'data/split_ms_dota/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(
9 | type='mmdet.LoadAnnotations',
10 | with_bbox=True,
11 | with_mask=True,
12 | poly2mask=False),
13 | dict(type='ConvertMask2BoxType', box_type='rbox'),
14 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
15 | dict(
16 | type='mmdet.RandomFlip',
17 | prob=0.75,
18 | direction=['horizontal', 'vertical', 'diagonal']),
19 | dict(type='mmdet.PackDetInputs')
20 | ]
21 | val_pipeline = [
22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
23 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
24 | # avoid bboxes being resized
25 | dict(
26 | type='mmdet.LoadAnnotations',
27 | with_bbox=True,
28 | with_mask=True,
29 | poly2mask=False),
30 | dict(type='ConvertMask2BoxType', box_type='qbox'),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor', 'instances'))
35 | ]
36 | test_pipeline = [
37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
38 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
39 | dict(
40 | type='mmdet.PackDetInputs',
41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
42 | 'scale_factor'))
43 | ]
44 |
45 | metainfo = dict(
46 | classes=('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
47 | 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
48 | 'basketball-court', 'storage-tank', 'soccer-ball-field',
49 | 'roundabout', 'harbor', 'swimming-pool', 'helicopter'))
50 |
51 | train_dataloader = dict(
52 | batch_size=2,
53 | num_workers=2,
54 | persistent_workers=True,
55 | sampler=dict(type='DefaultSampler', shuffle=True),
56 | batch_sampler=None,
57 | dataset=dict(
58 | type=dataset_type,
59 | metainfo=metainfo,
60 | data_root=data_root,
61 | ann_file='train/train.json',
62 | data_prefix=dict(img='train/images/'),
63 | filter_cfg=dict(filter_empty_gt=True),
64 | pipeline=train_pipeline,
65 | backend_args=backend_args))
66 | val_dataloader = dict(
67 | batch_size=1,
68 | num_workers=2,
69 | persistent_workers=True,
70 | drop_last=False,
71 | sampler=dict(type='DefaultSampler', shuffle=False),
72 | dataset=dict(
73 | type=dataset_type,
74 | metainfo=metainfo,
75 | data_root=data_root,
76 | ann_file='val/val.json',
77 | data_prefix=dict(img='val/images/'),
78 | test_mode=True,
79 | pipeline=val_pipeline,
80 | backend_args=backend_args))
81 | test_dataloader = val_dataloader
82 |
83 | val_evaluator = dict(
84 | type='RotatedCocoMetric',
85 | metric='bbox',
86 | classwise=True,
87 | backend_args=backend_args)
88 |
89 | test_evaluator = val_evaluator
90 |
91 | # inference on test dataset and format the output results
92 | # for submission. Note: the test set has no annotation.
93 | # test_dataloader = dict(
94 | # batch_size=1,
95 | # num_workers=2,
96 | # persistent_workers=True,
97 | # drop_last=False,
98 | # sampler=dict(type='DefaultSampler', shuffle=False),
99 | # dataset=dict(
100 | # type=dataset_type,
101 | # ann_file='test/test.json',
102 | # data_prefix=dict(img='test/images/'),
103 | # test_mode=True,
104 | # pipeline=test_pipeline))
105 | # test_evaluator = dict(
106 | # type='DOTAMetric',
107 | # format_only=True,
108 | # merge_patches=True,
109 | # outfile_prefix='./work_dirs/dota/Task1')
110 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dota_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DOTADataset'
3 | data_root = 'data/split_ms_dota/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(
16 | type='RandomRotate',
17 | prob=0.5,
18 | angle_range=180,
19 | rect_obj_labels=[9, 11]),
20 | dict(type='mmdet.PackDetInputs')
21 | ]
22 | val_pipeline = [
23 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
24 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
25 | # avoid bboxes being resized
26 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
27 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
28 | dict(
29 | type='mmdet.PackDetInputs',
30 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
31 | 'scale_factor'))
32 | ]
33 | test_pipeline = [
34 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
35 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
36 | dict(
37 | type='mmdet.PackDetInputs',
38 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
39 | 'scale_factor'))
40 | ]
41 | train_dataloader = dict(
42 | batch_size=2,
43 | num_workers=2,
44 | persistent_workers=True,
45 | sampler=dict(type='DefaultSampler', shuffle=True),
46 | batch_sampler=None,
47 | dataset=dict(
48 | type=dataset_type,
49 | data_root=data_root,
50 | ann_file='trainval/annfiles/',
51 | data_prefix=dict(img_path='trainval/images/'),
52 | filter_cfg=dict(filter_empty_gt=True),
53 | pipeline=train_pipeline))
54 | val_dataloader = dict(
55 | batch_size=1,
56 | num_workers=2,
57 | persistent_workers=True,
58 | drop_last=False,
59 | sampler=dict(type='DefaultSampler', shuffle=False),
60 | dataset=dict(
61 | type=dataset_type,
62 | data_root=data_root,
63 | ann_file='trainval/annfiles/',
64 | data_prefix=dict(img_path='trainval/images/'),
65 | test_mode=True,
66 | pipeline=val_pipeline))
67 | test_dataloader = val_dataloader
68 |
69 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
70 | test_evaluator = val_evaluator
71 |
72 | # inference on test dataset and format the output results
73 | # for submission. Note: the test set has no annotation.
74 | # test_dataloader = dict(
75 | # batch_size=1,
76 | # num_workers=2,
77 | # persistent_workers=True,
78 | # drop_last=False,
79 | # sampler=dict(type='DefaultSampler', shuffle=False),
80 | # dataset=dict(
81 | # type=dataset_type,
82 | # data_root=data_root,
83 | # data_prefix=dict(img_path='test/images/'),
84 | # test_mode=True,
85 | # pipeline=test_pipeline))
86 | # test_evaluator = dict(
87 | # type='DOTAMetric',
88 | # format_only=True,
89 | # merge_patches=True,
90 | # outfile_prefix='./work_dirs/dota/Task1')
91 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dota_qbox.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DOTADataset'
3 | data_root = 'data/split_ss_dota/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
10 | dict(
11 | type='mmdet.RandomFlip',
12 | prob=0.75,
13 | direction=['horizontal', 'vertical', 'diagonal']),
14 | dict(type='mmdet.PackDetInputs')
15 | ]
16 | val_pipeline = [
17 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
18 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
19 | # avoid bboxes being resized
20 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
21 | dict(
22 | type='mmdet.PackDetInputs',
23 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
24 | 'scale_factor'))
25 | ]
26 | test_pipeline = [
27 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
28 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
29 | dict(
30 | type='mmdet.PackDetInputs',
31 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
32 | 'scale_factor'))
33 | ]
34 | train_dataloader = dict(
35 | batch_size=2,
36 | num_workers=2,
37 | persistent_workers=True,
38 | sampler=dict(type='DefaultSampler', shuffle=True),
39 | batch_sampler=None,
40 | dataset=dict(
41 | type=dataset_type,
42 | data_root=data_root,
43 | ann_file='trainval/annfiles/',
44 | data_prefix=dict(img_path='trainval/images/'),
45 | filter_cfg=dict(filter_empty_gt=True),
46 | pipeline=train_pipeline))
47 | val_dataloader = dict(
48 | batch_size=1,
49 | num_workers=2,
50 | persistent_workers=True,
51 | drop_last=False,
52 | sampler=dict(type='DefaultSampler', shuffle=False),
53 | dataset=dict(
54 | type=dataset_type,
55 | data_root=data_root,
56 | ann_file='trainval/annfiles/',
57 | data_prefix=dict(img_path='trainval/images/'),
58 | test_mode=True,
59 | pipeline=val_pipeline))
60 | test_dataloader = val_dataloader
61 |
62 | val_evaluator = dict(
63 | type='DOTAMetric', metric='mAP', iou_thrs=0.2, predict_box_type='qbox')
64 | test_evaluator = val_evaluator
65 |
66 | # inference on test dataset and format the output results
67 | # for submission. Note: the test set has no annotation.
68 | # test_dataloader = dict(
69 | # batch_size=1,
70 | # num_workers=2,
71 | # persistent_workers=True,
72 | # drop_last=False,
73 | # sampler=dict(type='DefaultSampler', shuffle=False),
74 | # dataset=dict(
75 | # type=dataset_type,
76 | # data_root=data_root,
77 | # data_prefix=dict(img_path='test/images/'),
78 | # test_mode=True,
79 | # pipeline=test_pipeline))
80 | # test_evaluator = dict(
81 | # type='DOTAMetric',
82 | # format_only=True,
83 | # merge_patches=True,
84 | # predict_box_type='qbox',
85 | # outfile_prefix='./work_dirs/dota/Task1')
86 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dotav15.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DOTAv15Dataset'
3 | data_root = 'data/split_ss_dota1_5/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
20 | # avoid bboxes being resized
21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 | dict(
24 | type='mmdet.PackDetInputs',
25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 | 'scale_factor'))
27 | ]
28 | test_pipeline = [
29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 | batch_size=2,
38 | num_workers=2,
39 | persistent_workers=True,
40 | sampler=dict(type='DefaultSampler', shuffle=True),
41 | batch_sampler=None,
42 | dataset=dict(
43 | type=dataset_type,
44 | data_root=data_root,
45 | ann_file='trainval/annfiles/',
46 | data_prefix=dict(img_path='trainval/images/'),
47 | filter_cfg=dict(filter_empty_gt=True),
48 | pipeline=train_pipeline))
49 | val_dataloader = dict(
50 | batch_size=1,
51 | num_workers=2,
52 | persistent_workers=True,
53 | drop_last=False,
54 | sampler=dict(type='DefaultSampler', shuffle=False),
55 | dataset=dict(
56 | type=dataset_type,
57 | data_root=data_root,
58 | ann_file='trainval/annfiles/',
59 | data_prefix=dict(img_path='trainval/images/'),
60 | test_mode=True,
61 | pipeline=val_pipeline))
62 | test_dataloader = val_dataloader
63 |
64 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
65 | test_evaluator = val_evaluator
66 |
67 | # inference on test dataset and format the output results
68 | # for submission. Note: the test set has no annotation.
69 | # test_dataloader = dict(
70 | # batch_size=1,
71 | # num_workers=2,
72 | # persistent_workers=True,
73 | # drop_last=False,
74 | # sampler=dict(type='DefaultSampler', shuffle=False),
75 | # dataset=dict(
76 | # type=dataset_type,
77 | # data_root=data_root,
78 | # data_prefix=dict(img_path='test/images/'),
79 | # test_mode=True,
80 | # pipeline=test_pipeline))
81 | # test_evaluator = dict(
82 | # type='DOTAMetric',
83 | # format_only=True,
84 | # merge_patches=True,
85 | # outfile_prefix='./work_dirs/dotav15/h2rbox-le90_r50_fpn_adamw-1x_dotav15/Task1')
86 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dotav2.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DOTAv2Dataset'
3 | data_root = 'data/split_ss_dota2_0/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
20 | # avoid bboxes being resized
21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 | dict(
24 | type='mmdet.PackDetInputs',
25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 | 'scale_factor'))
27 | ]
28 | test_pipeline = [
29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 | batch_size=2,
38 | num_workers=2,
39 | persistent_workers=True,
40 | sampler=dict(type='DefaultSampler', shuffle=True),
41 | batch_sampler=None,
42 | dataset=dict(
43 | type=dataset_type,
44 | data_root=data_root,
45 | ann_file='trainval/annfiles/',
46 | data_prefix=dict(img_path='trainval/images/'),
47 | filter_cfg=dict(filter_empty_gt=True),
48 | pipeline=train_pipeline))
49 | val_dataloader = dict(
50 | batch_size=1,
51 | num_workers=2,
52 | persistent_workers=True,
53 | drop_last=False,
54 | sampler=dict(type='DefaultSampler', shuffle=False),
55 | dataset=dict(
56 | type=dataset_type,
57 | data_root=data_root,
58 | ann_file='trainval/annfiles/',
59 | data_prefix=dict(img_path='trainval/images/'),
60 | test_mode=True,
61 | pipeline=val_pipeline))
62 | test_dataloader = val_dataloader
63 |
64 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
65 | test_evaluator = val_evaluator
66 |
67 | # inference on test dataset and format the output results
68 | # for submission. Note: the test set has no annotation.
69 | # test_dataloader = dict(
70 | # batch_size=1,
71 | # num_workers=2,
72 | # persistent_workers=True,
73 | # drop_last=False,
74 | # sampler=dict(type='DefaultSampler', shuffle=False),
75 | # dataset=dict(
76 | # type=dataset_type,
77 | # data_root=data_root,
78 | # data_prefix=dict(img_path='test/images/'),
79 | # test_mode=True,
80 | # pipeline=test_pipeline))
81 | # test_evaluator = dict(
82 | # type='DOTAMetric',
83 | # format_only=True,
84 | # merge_patches=True,
85 | # outfile_prefix='./work_dirs/dotav2/h2rbox-le90_r50_fpn_adamw-1x_dotav2/Task1')
86 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/hrsc.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'HRSCDataset'
3 | data_root = 'data/hrsc/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True),
11 | dict(
12 | type='mmdet.RandomFlip',
13 | prob=0.75,
14 | direction=['horizontal', 'vertical', 'diagonal']),
15 | dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True),
20 | # avoid bboxes being resized
21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 | dict(
24 | type='mmdet.PackDetInputs',
25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 | 'scale_factor'))
27 | ]
28 | test_pipeline = [
29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 | batch_size=2,
38 | num_workers=2,
39 | persistent_workers=True,
40 | sampler=dict(type='DefaultSampler', shuffle=True),
41 | batch_sampler=None,
42 | dataset=dict(
43 | type=dataset_type,
44 | data_root=data_root,
45 | ann_file='ImageSets/trainval.txt',
46 | data_prefix=dict(sub_data_root='FullDataSet/'),
47 | filter_cfg=dict(filter_empty_gt=True),
48 | pipeline=train_pipeline,
49 | backend_args=backend_args))
50 | val_dataloader = dict(
51 | batch_size=1,
52 | num_workers=2,
53 | persistent_workers=True,
54 | drop_last=False,
55 | sampler=dict(type='DefaultSampler', shuffle=False),
56 | dataset=dict(
57 | type=dataset_type,
58 | data_root=data_root,
59 | ann_file='ImageSets/test.txt',
60 | data_prefix=dict(sub_data_root='FullDataSet/'),
61 | test_mode=True,
62 | pipeline=val_pipeline,
63 | backend_args=backend_args))
64 | test_dataloader = val_dataloader
65 |
66 | val_evaluator = dict(type='DOTAMetric', metric='mAP')
67 | test_evaluator = val_evaluator
68 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/hrsid.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmdet.CocoDataset'
3 | data_root = 'data/HRSID_JPG/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(
9 | type='mmdet.LoadAnnotations',
10 | with_bbox=True,
11 | with_mask=True,
12 | poly2mask=False),
13 | dict(type='ConvertMask2BoxType', box_type='rbox'),
14 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
15 | dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
16 | dict(
17 | type='mmdet.RandomFlip',
18 | prob=0.75,
19 | direction=['horizontal', 'vertical', 'diagonal']),
20 | dict(type='mmdet.PackDetInputs')
21 | ]
22 | val_pipeline = [
23 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
24 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
25 | # avoid bboxes being resized
26 | dict(
27 | type='mmdet.LoadAnnotations',
28 | with_bbox=True,
29 | with_mask=True,
30 | poly2mask=False),
31 | dict(type='ConvertMask2BoxType', box_type='qbox'),
32 | dict(
33 | type='mmdet.PackDetInputs',
34 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
35 | 'scale_factor', 'instances'))
36 | ]
37 | test_pipeline = [
38 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
39 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
40 | dict(
41 | type='mmdet.PackDetInputs',
42 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
43 | 'scale_factor'))
44 | ]
45 |
46 | metainfo = dict(classes=('ship', ))
47 |
48 | train_dataloader = dict(
49 | batch_size=2,
50 | num_workers=2,
51 | persistent_workers=True,
52 | sampler=dict(type='DefaultSampler', shuffle=True),
53 | batch_sampler=None,
54 | dataset=dict(
55 | type=dataset_type,
56 | metainfo=metainfo,
57 | data_root=data_root,
58 | ann_file='annotations/train2017.json',
59 | data_prefix=dict(img='JPEGImages/'),
60 | filter_cfg=dict(filter_empty_gt=True),
61 | pipeline=train_pipeline,
62 | backend_args=backend_args))
63 | val_dataloader = dict(
64 | batch_size=1,
65 | num_workers=2,
66 | persistent_workers=True,
67 | drop_last=False,
68 | sampler=dict(type='DefaultSampler', shuffle=False),
69 | dataset=dict(
70 | type=dataset_type,
71 | metainfo=metainfo,
72 | data_root=data_root,
73 | ann_file='annotations/test2017.json',
74 | data_prefix=dict(img='JPEGImages/'),
75 | test_mode=True,
76 | pipeline=val_pipeline,
77 | backend_args=backend_args))
78 | test_dataloader = val_dataloader
79 |
80 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox')
81 |
82 | test_evaluator = val_evaluator
83 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/rsdd.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmdet.CocoDataset'
3 | data_root = 'data/rsdd/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(
9 | type='mmdet.LoadAnnotations',
10 | with_bbox=True,
11 | with_mask=True,
12 | poly2mask=False),
13 | dict(type='ConvertMask2BoxType', box_type='rbox'),
14 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
15 | dict(
16 | type='mmdet.RandomFlip',
17 | prob=0.75,
18 | direction=['horizontal', 'vertical', 'diagonal']),
19 | dict(type='mmdet.PackDetInputs')
20 | ]
21 | val_pipeline = [
22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
23 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
24 | # avoid bboxes being resized
25 | dict(
26 | type='mmdet.LoadAnnotations',
27 | with_bbox=True,
28 | with_mask=True,
29 | poly2mask=False),
30 | dict(type='ConvertMask2BoxType', box_type='qbox'),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor', 'instances'))
35 | ]
36 | test_pipeline = [
37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
38 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
39 | dict(
40 | type='mmdet.PackDetInputs',
41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
42 | 'scale_factor'))
43 | ]
44 |
45 | metainfo = dict(classes=('ship', ))
46 |
47 | train_dataloader = dict(
48 | batch_size=2,
49 | num_workers=2,
50 | persistent_workers=True,
51 | sampler=dict(type='DefaultSampler', shuffle=True),
52 | batch_sampler=None,
53 | dataset=dict(
54 | type=dataset_type,
55 | metainfo=metainfo,
56 | data_root=data_root,
57 | ann_file='ImageSets/train.json',
58 | data_prefix=dict(img='JPEGImages/'),
59 | filter_cfg=dict(filter_empty_gt=True),
60 | pipeline=train_pipeline,
61 | backend_args=backend_args))
62 | val_dataloader = dict(
63 | batch_size=1,
64 | num_workers=2,
65 | persistent_workers=True,
66 | drop_last=False,
67 | sampler=dict(type='DefaultSampler', shuffle=False),
68 | dataset=dict(
69 | type=dataset_type,
70 | metainfo=metainfo,
71 | data_root=data_root,
72 | ann_file='ImageSets/test.json',
73 | data_prefix=dict(img='JPEGImages/'),
74 | test_mode=True,
75 | pipeline=val_pipeline,
76 | backend_args=backend_args))
77 | test_dataloader = val_dataloader
78 |
79 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox')
80 |
81 | test_evaluator = val_evaluator
82 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/srsdd.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmdet.CocoDataset'
3 | data_root = 'data/srsdd/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(
9 | type='mmdet.LoadAnnotations',
10 | with_bbox=True,
11 | with_mask=True,
12 | poly2mask=False),
13 | dict(type='ConvertMask2BoxType', box_type='rbox'),
14 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
15 | dict(
16 | type='mmdet.RandomFlip',
17 | prob=0.75,
18 | direction=['horizontal', 'vertical', 'diagonal']),
19 | dict(type='mmdet.PackDetInputs')
20 | ]
21 | val_pipeline = [
22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
23 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
24 | # avoid bboxes being resized
25 | dict(
26 | type='mmdet.LoadAnnotations',
27 | with_bbox=True,
28 | with_mask=True,
29 | poly2mask=False),
30 | dict(type='ConvertMask2BoxType', box_type='qbox'),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor', 'instances'))
35 | ]
36 | test_pipeline = [
37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
38 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
39 | dict(
40 | type='mmdet.PackDetInputs',
41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
42 | 'scale_factor'))
43 | ]
44 |
45 | metainfo = dict(
46 | classes=('Container', 'Dredger', 'LawEnforce', 'Cell-Container', 'ore-oil',
47 | 'Fishing'))
48 |
49 | train_dataloader = dict(
50 | batch_size=2,
51 | num_workers=2,
52 | persistent_workers=True,
53 | sampler=dict(type='DefaultSampler', shuffle=True),
54 | batch_sampler=None,
55 | dataset=dict(
56 | type=dataset_type,
57 | metainfo=metainfo,
58 | data_root=data_root,
59 | ann_file='train/train.json',
60 | data_prefix=dict(img='train/images/'),
61 | filter_cfg=dict(filter_empty_gt=True),
62 | pipeline=train_pipeline,
63 | backend_args=backend_args))
64 | val_dataloader = dict(
65 | batch_size=1,
66 | num_workers=2,
67 | persistent_workers=True,
68 | drop_last=False,
69 | sampler=dict(type='DefaultSampler', shuffle=False),
70 | dataset=dict(
71 | type=dataset_type,
72 | metainfo=metainfo,
73 | data_root=data_root,
74 | ann_file='test/test.json',
75 | data_prefix=dict(img='test/images/'),
76 | test_mode=True,
77 | pipeline=val_pipeline,
78 | backend_args=backend_args))
79 | test_dataloader = val_dataloader
80 |
81 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox')
82 |
83 | test_evaluator = val_evaluator
84 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/ssdd.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmdet.CocoDataset'
3 | data_root = 'data/ssdd/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 | dict(
9 | type='mmdet.LoadAnnotations',
10 | with_bbox=True,
11 | with_mask=True,
12 | poly2mask=False),
13 | dict(type='ConvertMask2BoxType', box_type='rbox'),
14 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
15 | dict(
16 | type='mmdet.RandomFlip',
17 | prob=0.75,
18 | direction=['horizontal', 'vertical', 'diagonal']),
19 | dict(type='mmdet.PackDetInputs')
20 | ]
21 | val_pipeline = [
22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
23 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
24 | # avoid bboxes being resized
25 | dict(
26 | type='mmdet.LoadAnnotations',
27 | with_bbox=True,
28 | with_mask=True,
29 | poly2mask=False),
30 | dict(type='ConvertMask2BoxType', box_type='qbox'),
31 | dict(
32 | type='mmdet.PackDetInputs',
33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 | 'scale_factor', 'instances'))
35 | ]
36 | test_pipeline = [
37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
38 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True),
39 | dict(
40 | type='mmdet.PackDetInputs',
41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
42 | 'scale_factor'))
43 | ]
44 |
45 | metainfo = dict(classes=('ship', ))
46 |
47 | train_dataloader = dict(
48 | batch_size=2,
49 | num_workers=2,
50 | persistent_workers=True,
51 | sampler=dict(type='DefaultSampler', shuffle=True),
52 | batch_sampler=None,
53 | dataset=dict(
54 | type=dataset_type,
55 | metainfo=metainfo,
56 | data_root=data_root,
57 | ann_file='train/train.json',
58 | data_prefix=dict(img='train/images/'),
59 | filter_cfg=dict(filter_empty_gt=True),
60 | pipeline=train_pipeline,
61 | backend_args=backend_args))
62 | val_dataloader = dict(
63 | batch_size=1,
64 | num_workers=2,
65 | persistent_workers=True,
66 | drop_last=False,
67 | sampler=dict(type='DefaultSampler', shuffle=False),
68 | dataset=dict(
69 | type=dataset_type,
70 | metainfo=metainfo,
71 | data_root=data_root,
72 | ann_file='test/all/test.json',
73 | data_prefix=dict(img='test/all/images/'),
74 | test_mode=True,
75 | pipeline=val_pipeline,
76 | backend_args=backend_args))
77 | test_dataloader = val_dataloader
78 |
79 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox')
80 |
81 | test_evaluator = val_evaluator
82 |
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'mmrotate'
2 |
3 | default_hooks = dict(
4 | timer=dict(type='IterTimerHook'),
5 | logger=dict(type='LoggerHook', interval=50),
6 | param_scheduler=dict(type='ParamSchedulerHook'),
7 | checkpoint=dict(type='CheckpointHook', interval=1),
8 | sampler_seed=dict(type='DistSamplerSeedHook'),
9 | visualization=dict(type='mmdet.DetVisualizationHook'))
10 |
11 | env_cfg = dict(
12 | cudnn_benchmark=False,
13 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
14 | dist_cfg=dict(backend='nccl'),
15 | )
16 |
17 | vis_backends = [dict(type='LocalVisBackend')]
18 | visualizer = dict(
19 | type='RotLocalVisualizer', vis_backends=vis_backends, name='visualizer')
20 | log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
21 |
22 | log_level = 'INFO'
23 | load_from = None
24 | resume = False
25 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_1x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 1x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 | dict(
9 | type='LinearLR',
10 | start_factor=1.0 / 3,
11 | by_epoch=False,
12 | begin=0,
13 | end=500),
14 | dict(
15 | type='MultiStepLR',
16 | begin=0,
17 | end=12,
18 | by_epoch=True,
19 | milestones=[8, 11],
20 | gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 | type='OptimWrapper',
26 | optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 | clip_grad=dict(max_norm=35, norm_type=2))
28 |
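Read together, the two schedulers above amount to: linear warmup from `lr/3` to `lr` over the first 500 iterations, then step decay by 10x after epochs 8 and 11. A rough sketch of the implied learning rate (the function is illustrative, not part of the config API):

```python
def lr_for(global_iter: int, epoch: int, base_lr: float = 0.0025) -> float:
    # LinearLR: warm up from base_lr / 3 to base_lr over the first 500 iters.
    if global_iter < 500:
        start = base_lr / 3
        return start + (base_lr - start) * global_iter / 500
    # MultiStepLR: multiply by gamma=0.1 at each passed milestone epoch.
    decay_steps = sum(epoch >= m for m in (8, 11))
    return base_lr * 0.1 ** decay_steps
```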
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_3x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 3x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 | dict(
9 | type='LinearLR',
10 | start_factor=1.0 / 3,
11 | by_epoch=False,
12 | begin=0,
13 | end=500),
14 | dict(
15 | type='MultiStepLR',
16 | begin=0,
17 | end=36,
18 | by_epoch=True,
19 | milestones=[24, 33],
20 | gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 | type='OptimWrapper',
26 | optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 | clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_40e.py:
--------------------------------------------------------------------------------
1 | # training schedule for 40e
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 | dict(
9 | type='LinearLR',
10 | start_factor=1.0 / 3,
11 | by_epoch=False,
12 | begin=0,
13 | end=500),
14 | dict(
15 | type='MultiStepLR',
16 | begin=0,
17 | end=40,
18 | by_epoch=True,
19 | milestones=[24, 32, 38],
20 | gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 | type='OptimWrapper',
26 | optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 | clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_6x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 6x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=72, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 | dict(
9 | type='LinearLR',
10 | start_factor=1.0 / 3,
11 | by_epoch=False,
12 | begin=0,
13 | end=500),
14 | dict(
15 | type='MultiStepLR',
16 | begin=0,
17 | end=72,
18 | by_epoch=True,
19 | milestones=[48, 66],
20 | gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 | type='OptimWrapper',
26 | optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 | clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/rotated_fcos/README.md:
--------------------------------------------------------------------------------
1 | # Rotated FCOS
2 |
3 | > [FCOS: Fully Convolutional One-Stage Object Detection](https://arxiv.org/abs/1904.01355)
4 |
5 |
6 |
7 | ## Abstract
8 |
9 |
10 |
11 |
12 |
13 | We propose a fully convolutional one-stage object detector (FCOS) to solve object detection in a per-pixel prediction
14 | fashion, analogue to semantic segmentation. Almost all state-of-the-art object detectors such as RetinaNet, SSD, YOLOv3,
15 | and Faster R-CNN rely on pre-defined anchor boxes. In contrast, our proposed detector FCOS is anchor box free, as well
16 | as proposal free. By eliminating the predefined set of anchor boxes, FCOS completely avoids the complicated computation
17 | related to anchor boxes such as calculating overlapping during training. More importantly, we also avoid all
18 | hyper-parameters related to anchor boxes, which are often very sensitive to the final detection performance. With
19 | non-maximum suppression (NMS) as the only post-processing, FCOS with ResNeXt-64x4d-101 achieves 44.7% in AP with single-model
20 | and single-scale testing, surpassing previous one-stage detectors with the advantage of being much simpler. For the
21 | first time, we demonstrate a much simpler and flexible detection framework achieving improved detection accuracy. We
22 | hope that the proposed FCOS framework can serve as a simple and strong alternative for many other instance-level tasks.
23 |
24 | ## Results and Models
25 |
26 | DOTA1.0
27 |
28 | | Backbone | mAP | Angle | Separate Angle | Tricks | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download |
29 | | :----------------------: | :---: | :---: | :------------: | :----: | :-----: | :------: | :------------: | :-: | :--------: | :-------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
30 | | ResNet50 (1024,1024,200) | 70.70 | le90 | Y | Y | 1x | 4.18 | 26.4 | - | 2 | [rotated-fcos-hbox-le90_r50_fpn_1x_dota](rotated-fcos-hbox-le90_r50_fpn_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90_20220409_023250.log.json) |
31 | | ResNet50 (1024,1024,200) | 71.28 | le90 | N | Y | 1x | 4.18 | 25.9 | - | 2 | [rotated-fcos-le90_r50_fpn_1x_dota](rotated-fcos-le90_r50_fpn_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90_20220413_163526.log.json) |
32 | | ResNet50 (1024,1024,200) | 71.76 | le90 | Y | Y | 1x | 4.23 | 25.7 | - | 2 | [rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota](rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90_20220409_080616.log.json) |
33 | | ResNet50 (1024,1024,200) | 71.89 | le90 | N | Y | 1x | 4.18 | 26.2 | - | 2 | [rotated-fcos-le90_r50_fpn_kld_1x_dota](rotated-fcos-le90_r50_fpn_kld_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90_20220409_202939.log.json) |
34 |
35 | **Notes:**
36 |
37 | - `MS` means multiple-scale image splitting.
38 | - `RR` means random rotation.
39 | - `Rotated IoU Loss` requires mmcv version 1.5.0 or above.
40 | - `Separate Angle` means the angle loss is calculated separately,
41 | in which case the bbox loss uses a horizontal bbox loss such as `IoULoss` or `GIoULoss`.
42 | - `Tricks` means setting `norm_on_bbox`, `centerness_on_reg` and `center_sampling` to `True`.
43 | - Inf time was tested on a single RTX 3090.
44 |
45 | ## Citation
46 |
47 | ```
48 | @article{tian2019fcos,
49 | title={FCOS: Fully Convolutional One-Stage Object Detection},
50 | author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
51 | journal={arXiv preprint arXiv:1904.01355},
52 | year={2019}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/configs/rotated_fcos/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: rotated_fcos
3 | Metadata:
4 | Training Data: DOTAv1.0
5 | Training Techniques:
6 | - SGD with Momentum
7 | - Weight Decay
8 | Training Resources: 1x Tesla V100
9 | Architecture:
10 | - ResNet
11 | Paper:
12 | URL: https://arxiv.org/abs/1904.01355
13 | Title: 'FCOS: Fully Convolutional One-Stage Object Detection'
14 | README: configs/rotated_fcos/README.md
15 |
16 | Models:
17 | - Name: rotated-fcos-hbox-le90_r50_fpn_1x_dota
18 | In Collection: rotated_fcos
19 | Config: configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py
20 | Metadata:
21 | Training Data: DOTAv1.0
22 | Results:
23 | - Task: Oriented Object Detection
24 | Dataset: DOTAv1.0
25 | Metrics:
26 | mAP: 70.70
27 | Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth
28 |
29 | - Name: rotated-fcos-le90_r50_fpn_1x_dota
30 | In Collection: rotated_fcos
31 | Config: configs/rotated_fcos/rotated-fcos-le90_r50_fpn_1x_dota.py
32 | Metadata:
33 | Training Data: DOTAv1.0
34 | Results:
35 | - Task: Oriented Object Detection
36 | Dataset: DOTAv1.0
37 | Metrics:
38 | mAP: 71.28
39 | Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth
40 |
41 | - Name: rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota
42 | In Collection: rotated_fcos
43 | Config: configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py
44 | Metadata:
45 | Training Data: DOTAv1.0
46 | Results:
47 | - Task: Oriented Object Detection
48 | Dataset: DOTAv1.0
49 | Metrics:
50 | mAP: 71.76
51 | Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth
52 |
53 | - Name: rotated-fcos-le90_r50_fpn_kld_1x_dota
54 | In Collection: rotated_fcos
55 | Config: configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py
56 | Metadata:
57 | Training Data: DOTAv1.0
58 | Results:
59 | - Task: Oriented Object Detection
60 | Dataset: DOTAv1.0
61 | Metrics:
62 | mAP: 71.89
63 | Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth
64 |
--------------------------------------------------------------------------------
/configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py:
--------------------------------------------------------------------------------
1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py'
2 |
3 | model = dict(
4 | bbox_head=dict(
5 | use_hbbox_loss=True,
6 | scale_angle=True,
7 | angle_coder=dict(type='PseudoAngleCoder'),
8 | loss_angle=dict(_delete_=True, type='mmdet.L1Loss', loss_weight=0.2),
9 | loss_bbox=dict(type='mmdet.IoULoss', loss_weight=1.0),
10 | ))
11 |
--------------------------------------------------------------------------------
/configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py:
--------------------------------------------------------------------------------
1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py'
2 |
3 | angle_version = {{_base_.angle_version}}
4 |
5 | # model settings
6 | model = dict(
7 | bbox_head=dict(
8 | use_hbbox_loss=True,
9 | scale_angle=False,
10 | angle_coder=dict(
11 | type='CSLCoder',
12 | angle_version=angle_version,
13 | omega=1,
14 | window='gaussian',
15 | radius=1),
16 | loss_angle=dict(
17 | _delete_=True,
18 | type='SmoothFocalLoss',
19 | gamma=2.0,
20 | alpha=0.25,
21 | loss_weight=0.2),
22 | loss_bbox=dict(type='mmdet.IoULoss', loss_weight=1.0),
23 | ))
24 |
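For intuition, `CSLCoder` turns angle regression into classification over discretized angle bins, smoothed by a circular window. A minimal sketch of gaussian-window CSL encoding under the settings above (`omega=1`, `radius=1`); this is an illustration of the idea, not MMRotate's exact implementation:

```python
import numpy as np

def csl_encode(angle_deg: float, omega: int = 1, radius: float = 1.0) -> np.ndarray:
    """Encode an angle in [-90, 90) as a circular smooth label."""
    num_bins = 180 // omega
    target = int((angle_deg + 90) // omega) % num_bins
    bins = np.arange(num_bins)
    # Circular distance from each bin to the target bin.
    dist = np.minimum(np.abs(bins - target), num_bins - np.abs(bins - target))
    return np.exp(-dist**2 / (2 * radius**2))  # gaussian window
```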
--------------------------------------------------------------------------------
/configs/rotated_fcos/rotated-fcos-le90_r50_fpn_1x_dota.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/datasets/dota.py', '../_base_/schedules/schedule_1x.py',
3 | '../_base_/default_runtime.py'
4 | ]
5 | angle_version = 'le90'
6 |
7 | # model settings
8 | model = dict(
9 | type='mmdet.FCOS',
10 | data_preprocessor=dict(
11 | type='mmdet.DetDataPreprocessor',
12 | mean=[123.675, 116.28, 103.53],
13 | std=[58.395, 57.12, 57.375],
14 | bgr_to_rgb=True,
15 | pad_size_divisor=32,
16 | boxtype2tensor=False),
17 | backbone=dict(
18 | type='mmdet.ResNet',
19 | depth=50,
20 | num_stages=4,
21 | out_indices=(0, 1, 2, 3),
22 | frozen_stages=1,
23 | norm_cfg=dict(type='BN', requires_grad=True),
24 | norm_eval=True,
25 | style='pytorch',
26 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
27 | neck=dict(
28 | type='mmdet.FPN',
29 | in_channels=[256, 512, 1024, 2048],
30 | out_channels=256,
31 | start_level=1,
32 | add_extra_convs='on_output',
33 | num_outs=5,
34 | relu_before_extra_convs=True),
35 | bbox_head=dict(
36 | type='RotatedFCOSHead',
37 | num_classes=15,
38 | in_channels=256,
39 | stacked_convs=4,
40 | feat_channels=256,
41 | strides=[8, 16, 32, 64, 128],
42 | center_sampling=True,
43 | center_sample_radius=1.5,
44 | norm_on_bbox=True,
45 | centerness_on_reg=True,
46 | use_hbbox_loss=False,
47 | scale_angle=True,
48 | bbox_coder=dict(
49 | type='DistanceAnglePointCoder', angle_version=angle_version),
50 | loss_cls=dict(
51 | type='mmdet.FocalLoss',
52 | use_sigmoid=True,
53 | gamma=2.0,
54 | alpha=0.25,
55 | loss_weight=1.0),
56 | loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0),
57 | loss_angle=None,
58 | loss_centerness=dict(
59 | type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
60 | # training and testing settings
61 | train_cfg=None,
62 | test_cfg=dict(
63 | nms_pre=2000,
64 | min_bbox_size=0,
65 | score_thr=0.05,
66 | nms=dict(type='nms_rotated', iou_threshold=0.1),
67 | max_per_img=2000))
68 |
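`angle_version = 'le90'` is MMRotate's long-edge angle definition with theta normalized into [-pi/2, pi/2). A hand-rolled normalization sketch of that convention (illustrative only; MMRotate's box types handle this internally):

```python
import numpy as np

def to_le90(cx, cy, w, h, theta):
    # Long-edge convention: keep w as the longer side, fold theta into [-pi/2, pi/2).
    if w < h:
        w, h, theta = h, w, theta + np.pi / 2
    theta = (theta + np.pi / 2) % np.pi - np.pi / 2
    return cx, cy, w, h, theta
```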
--------------------------------------------------------------------------------
/configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py:
--------------------------------------------------------------------------------
1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py'
2 |
3 | model = dict(
4 | bbox_head=dict(
5 | loss_bbox=dict(
6 | _delete_=True,
7 | type='GDLoss_v1',
8 | loss_type='kld',
9 | fun='log1p',
10 | tau=1,
11 | loss_weight=1.0)))
12 |
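The KLD variant swaps the rotated-IoU loss for a Gaussian distribution distance: each rbox is viewed as a 2D Gaussian, and the regression loss is a function of the KL divergence between the two Gaussians (`fun='log1p'` and `tau=1` control the final rescaling). A sketch of the underlying math, not the `GDLoss_v1` source:

```python
import numpy as np

def rbox_to_gaussian(cx, cy, w, h, a):
    # Map an rbox (angle in radians) to N(mu, Sigma): rotate a diagonal
    # covariance built from the half-extents.
    R = np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
    S = np.diag([w ** 2 / 4.0, h ** 2 / 4.0])
    return np.array([cx, cy]), R @ S @ R.T

def kld(mu_p, sig_p, mu_q, sig_q):
    # KL(N_p || N_q) for 2D Gaussians.
    inv_q = np.linalg.inv(sig_q)
    d = mu_q - mu_p
    return 0.5 * (np.trace(inv_q @ sig_p) + d @ inv_q @ d - 2
                  + np.log(np.linalg.det(sig_q) / np.linalg.det(sig_p)))
```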
--------------------------------------------------------------------------------
/configs/rotated_fcos/rotated-fcos-le90_r50_fpn_rr-6x_hrsc.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/datasets/hrsc.py', '../_base_/schedules/schedule_6x.py',
3 | '../_base_/default_runtime.py'
4 | ]
5 | angle_version = 'le90'
6 |
7 | # model settings
8 | model = dict(
9 | type='mmdet.FCOS',
10 | data_preprocessor=dict(
11 | type='mmdet.DetDataPreprocessor',
12 | mean=[123.675, 116.28, 103.53],
13 | std=[58.395, 57.12, 57.375],
14 | bgr_to_rgb=True,
15 | pad_size_divisor=32,
16 | boxtype2tensor=False),
17 | backbone=dict(
18 | type='mmdet.ResNet',
19 | depth=50,
20 | num_stages=4,
21 | out_indices=(0, 1, 2, 3),
22 | frozen_stages=1,
23 | norm_cfg=dict(type='BN', requires_grad=True),
24 | norm_eval=True,
25 | style='pytorch',
26 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
27 | neck=dict(
28 | type='mmdet.FPN',
29 | in_channels=[256, 512, 1024, 2048],
30 | out_channels=256,
31 | start_level=1,
32 | add_extra_convs='on_output',
33 | num_outs=5,
34 | relu_before_extra_convs=True),
35 | bbox_head=dict(
36 | type='RotatedFCOSHead',
37 | num_classes=1,
38 | in_channels=256,
39 | stacked_convs=4,
40 | feat_channels=256,
41 | strides=[8, 16, 32, 64, 128],
42 | center_sampling=True,
43 | center_sample_radius=1.5,
44 | norm_on_bbox=True,
45 | centerness_on_reg=True,
46 | use_hbbox_loss=False,
47 | scale_angle=True,
48 | bbox_coder=dict(
49 | type='DistanceAnglePointCoder', angle_version=angle_version),
50 | loss_cls=dict(
51 | type='mmdet.FocalLoss',
52 | use_sigmoid=True,
53 | gamma=2.0,
54 | alpha=0.25,
55 | loss_weight=1.0),
56 | loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0),
57 | loss_angle=None,
58 | loss_centerness=dict(
59 | type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
60 | # training and testing settings
61 | train_cfg=None,
62 | test_cfg=dict(
63 | nms_pre=2000,
64 | min_bbox_size=0,
65 | score_thr=0.05,
66 | nms=dict(type='nms_rotated', iou_threshold=0.1),
67 | max_per_img=2000))
68 |
69 | train_pipeline = [
70 | dict(type='mmdet.LoadImageFromFile', backend_args={{_base_.backend_args}}),
71 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
72 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
73 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True),
74 | dict(
75 | type='mmdet.RandomFlip',
76 | prob=0.75,
77 | direction=['horizontal', 'vertical', 'diagonal']),
78 | dict(type='RandomRotate', prob=0.5, angle_range=180),
79 | dict(type='mmdet.PackDetInputs')
80 | ]
81 |
82 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
83 |
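The `{{_base_.backend_args}}` placeholder in the pipeline above is MMEngine's cross-file variable reference: it is substituted with the value defined along the `_base_` chain at parse time. One way to check the resolved value (assumes the stock `hrsc.py` base, which sets `backend_args = None`):

```python
from mmengine import Config

cfg = Config.fromfile('configs/rotated_fcos/rotated-fcos-le90_r50_fpn_rr-6x_hrsc.py')
# The placeholder has been replaced by the base config's value once parsing finishes.
print(cfg.train_dataloader.dataset.pipeline[0]['backend_args'])  # None with the stock base
```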
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from functools import partial
4 | from typing import Dict, Optional, Union, List
5 |
6 |
7 | from mmengine.evaluator import Evaluator
8 | from mmengine.dataset import worker_init_fn
9 | from mmengine.dist import get_rank
10 | from mmengine.logging import print_log
11 | from mmengine.registry import DATA_SAMPLERS, FUNCTIONS, EVALUATOR, VISUALIZERS
12 | from mmengine.utils import digit_version
13 | from mmengine.utils.dl_utils import TORCH_VERSION
14 |
15 | import transforms
16 | import visualizer
17 |
18 |
19 | from torch.utils.data import DataLoader
20 |
21 | from mmrotate.registry import DATASETS
22 |
23 |
24 | def build_data_loader(data_name=None):
25 | if data_name is None or data_name == 'trainval_with_hbox':
26 | return MMEngine_build_dataloader(dataloader=naive_trainval_dataloader)
27 | elif data_name == 'test_without_hbox':
28 | return MMEngine_build_dataloader(dataloader=naive_test_dataloader)
29 | else:
30 | raise NotImplementedError()
31 |
32 |
33 | def build_evaluator(merge_patches=True, format_only=False):
34 | naive_evaluator.update(dict(
35 | merge_patches=merge_patches, format_only=format_only))
36 | return MMEngine_build_evaluator(evaluator=naive_evaluator)
37 |
38 |
39 | def build_visualizer():
40 | vis_backends = [dict(type='LocalVisBackend')]
41 | visualizer = dict(
42 | type='RotLocalVisualizerMaskThenBox', vis_backends=vis_backends,
43 | name='sammrotate', save_dir='./rbbox_vis')
44 | return VISUALIZERS.build(visualizer)
45 |
46 |
47 | # dataset settings
48 | dataset_type = 'DOTADataset'
49 | data_root = 'data/split_ss_dota/'
50 | backend_args = None
51 |
52 | naive_trainval_pipeline = [
53 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
54 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
55 | # avoid bboxes being resized
56 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
57 | # Horizontal GTBox, (x1,y1,x2,y2)
58 | dict(type='AddConvertedGTBox', box_type_mapping=dict(h_gt_bboxes='hbox')),
59 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
60 |     # The conversion above yields Rotated GTBox, (x, y, w, h, theta);
61 |     # the horizontal copy is kept separately under 'h_gt_bboxes'.
62 | dict(
63 | type='mmdet.PackDetInputs',
64 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
65 | 'scale_factor', 'h_gt_bboxes'))
66 | ]
67 |
68 | naive_test_pipeline = [
69 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
70 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True),
71 | dict(
72 | type='mmdet.PackDetInputs',
73 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
74 | 'scale_factor'))
75 | ]
76 |
77 | naive_trainval_dataset = dict(
78 | type=dataset_type,
79 | data_root=data_root,
80 |     # ann_file='trainval/annfiles/',  # full trainval split
81 |     # ann_file='trainval/annfiles-1sample/',
82 |     # ann_file='trainval/annfiles-3sample/',
83 |     # ann_file='trainval/annfiles-10sample/',
84 |     # ann_file='trainval/annfiles-30sample/',
85 |     # ann_file='trainval/annfiles-100sample/',
86 |     ann_file='trainval/annfiles-1000sample/',  # subsampled split, used here for speed
87 | data_prefix=dict(img_path='trainval/images/'),
88 |     test_mode=True,  # inference only; SAM is never trained here
89 | pipeline=naive_trainval_pipeline)
90 |
91 | naive_test_dataset = dict(
92 | type=dataset_type,
93 | data_root=data_root,
94 | data_prefix=dict(img_path='test/images/'),
95 | test_mode=True,
96 | pipeline=naive_test_pipeline)
97 |
98 | naive_trainval_dataloader = dict(
99 | batch_size=1,
100 | # num_workers=0, # For debug
101 | num_workers=2,
102 | # persistent_workers=False, # For debug
103 | persistent_workers=True,
104 | drop_last=False,
105 | sampler=dict(type='DefaultSampler', shuffle=False),
106 | dataset=naive_trainval_dataset)
107 |
108 | naive_test_dataloader = dict(
109 | batch_size=1,
110 | # num_workers=0, # For debug
111 | num_workers=2,
112 | # persistent_workers=False, # For debug
113 | persistent_workers=True,
114 | drop_last=False,
115 | sampler=dict(type='DefaultSampler', shuffle=False),
116 | dataset=naive_test_dataset)
117 |
118 | naive_evaluator = dict(
119 | type='DOTAMetric', metric='mAP', outfile_prefix='./work_dirs/dota/Task1')
120 |
121 |
122 | def MMEngine_build_dataloader(dataloader: Union[DataLoader, Dict],
123 | seed: Optional[int] = None,
124 | diff_rank_seed: bool = False) -> DataLoader:
125 | """Build dataloader.
126 |
127 | The method builds three components:
128 |
129 | - Dataset
130 | - Sampler
131 | - Dataloader
132 |
133 | An example of ``dataloader``::
134 |
135 | dataloader = dict(
136 | dataset=dict(type='ToyDataset'),
137 | sampler=dict(type='DefaultSampler', shuffle=True),
138 | batch_size=1,
139 | num_workers=9
140 | )
141 |
142 | Args:
143 | dataloader (DataLoader or dict): A Dataloader object or a dict to
144 | build Dataloader object. If ``dataloader`` is a Dataloader
145 | object, just returns itself.
146 | seed (int, optional): Random seed. Defaults to None.
147 |         diff_rank_seed (bool): Whether or not to set different seeds for
148 |             different ranks. If True, the seed passed to the sampler is
149 |             set to None, so each rank derives its own seed instead of
150 |             sharing a synchronized one.
151 |
152 |
153 | Returns:
154 | Dataloader: DataLoader build from ``dataloader_cfg``.
155 | """
156 | if isinstance(dataloader, DataLoader):
157 | return dataloader
158 |
159 | dataloader_cfg = copy.deepcopy(dataloader)
160 |
161 | # build dataset
162 | dataset_cfg = dataloader_cfg.pop('dataset')
163 | if isinstance(dataset_cfg, dict):
164 | dataset = DATASETS.build(dataset_cfg)
165 | if hasattr(dataset, 'full_init'):
166 | dataset.full_init()
167 | else:
168 | # fallback to raise error in dataloader
169 | # if `dataset_cfg` is not a valid type
170 | dataset = dataset_cfg
171 |
172 | # build sampler
173 | sampler_cfg = dataloader_cfg.pop('sampler')
174 | if isinstance(sampler_cfg, dict):
175 | sampler_seed = None if diff_rank_seed else seed
176 | sampler = DATA_SAMPLERS.build(
177 | sampler_cfg,
178 | default_args=dict(dataset=dataset, seed=sampler_seed))
179 | else:
180 | # fallback to raise error in dataloader
181 | # if `sampler_cfg` is not a valid type
182 | sampler = sampler_cfg
183 |
184 | # build batch sampler
185 | batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
186 | if batch_sampler_cfg is None:
187 | batch_sampler = None
188 | elif isinstance(batch_sampler_cfg, dict):
189 | batch_sampler = DATA_SAMPLERS.build(
190 | batch_sampler_cfg,
191 | default_args=dict(
192 | sampler=sampler,
193 | batch_size=dataloader_cfg.pop('batch_size')))
194 | else:
195 | # fallback to raise error in dataloader
196 | # if `batch_sampler_cfg` is not a valid type
197 | batch_sampler = batch_sampler_cfg
198 |
199 | # build dataloader
200 | init_fn: Optional[partial]
201 |
202 | if seed is not None:
203 | disable_subprocess_warning = dataloader_cfg.pop(
204 | 'disable_subprocess_warning', False)
205 | assert isinstance(
206 | disable_subprocess_warning,
207 | bool), ('disable_subprocess_warning should be a bool, but got '
208 | f'{type(disable_subprocess_warning)}')
209 | init_fn = partial(
210 | worker_init_fn,
211 | num_workers=dataloader_cfg.get('num_workers'),
212 | rank=get_rank(),
213 | seed=seed,
214 | disable_subprocess_warning=disable_subprocess_warning)
215 | else:
216 | init_fn = None
217 |
218 | # `persistent_workers` requires pytorch version >= 1.7
219 | if ('persistent_workers' in dataloader_cfg
220 | and digit_version(TORCH_VERSION) < digit_version('1.7.0')):
221 | print_log(
222 | '`persistent_workers` is only available when '
223 | 'pytorch version >= 1.7',
224 | logger='current',
225 | level=logging.WARNING)
226 | dataloader_cfg.pop('persistent_workers')
227 |
228 |     # The default behavior of `collate_fn` in dataloader is to
229 | # merge a list of samples to form a mini-batch of Tensor(s).
230 | # However, in mmengine, if `collate_fn` is not defined in
231 | # dataloader_cfg, `pseudo_collate` will only convert the list of
232 | # samples into a dict without stacking the batch tensor.
233 | collate_fn_cfg = dataloader_cfg.pop('collate_fn',
234 | dict(type='pseudo_collate'))
235 | collate_fn_type = collate_fn_cfg.pop('type')
236 | collate_fn = FUNCTIONS.get(collate_fn_type)
237 | collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
238 | data_loader = DataLoader(
239 | dataset=dataset,
240 | sampler=sampler if batch_sampler is None else None,
241 | batch_sampler=batch_sampler,
242 | collate_fn=collate_fn,
243 | worker_init_fn=init_fn,
244 | **dataloader_cfg)
245 | return data_loader
246 |
247 |
248 | def MMEngine_build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
249 | """Build evaluator.
250 |
251 | Examples of ``evaluator``::
252 |
253 | # evaluator could be a built Evaluator instance
254 | evaluator = Evaluator(metrics=[ToyMetric()])
255 |
256 | # evaluator can also be a list of dict
257 | evaluator = [
258 | dict(type='ToyMetric1'),
259 | dict(type='ToyEvaluator2')
260 | ]
261 |
262 | # evaluator can also be a list of built metric
263 | evaluator = [ToyMetric1(), ToyMetric2()]
264 |
265 | # evaluator can also be a dict with key metrics
266 | evaluator = dict(metrics=ToyMetric())
267 | # metric is a list
268 | evaluator = dict(metrics=[ToyMetric()])
269 |
270 | Args:
271 | evaluator (Evaluator or dict or list): An Evaluator object or a
272 | config dict or list of config dict used to build an Evaluator.
273 |
274 | Returns:
275 | Evaluator: Evaluator build from ``evaluator``.
276 | """
277 | if isinstance(evaluator, Evaluator):
278 | return evaluator
279 | elif isinstance(evaluator, dict):
280 |         # if `metrics` is among the dict keys, build a customized evaluator
281 | if 'metrics' in evaluator:
282 | evaluator.setdefault('type', 'Evaluator')
283 | return EVALUATOR.build(evaluator)
284 |         # otherwise, the default evaluator will be built
285 | else:
286 | return Evaluator(evaluator) # type: ignore
287 | elif isinstance(evaluator, list):
288 | # use the default `Evaluator`
289 | return Evaluator(evaluator) # type: ignore
290 | else:
291 | raise TypeError(
292 | 'evaluator should be one of dict, list of dict, and Evaluator'
293 | f', but got {evaluator}')
294 |
295 |
296 |
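A minimal driver sketch for the three helpers above, mirroring how the two main scripts wire them together (assumes the DOTA split under `data/split_ss_dota/` is prepared):

```python
from data import build_data_loader, build_evaluator

dataloader = build_data_loader('trainval_with_hbox')   # or 'test_without_hbox'
evaluator = build_evaluator(merge_patches=True, format_only=False)
evaluator.dataset_meta = dataloader.dataset.metainfo   # both main scripts set this
```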
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pathlib import Path
4 | from copy import deepcopy
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import cv2
8 |
9 | from mmrotate.structures import RotatedBoxes
10 | from mmdet.models.utils import samplelist_boxtype2tensor
11 | from mmengine.runner import load_checkpoint
12 | from utils import show_box, show_mask
13 |
14 | from mmengine.structures import InstanceData
15 | from data import build_visualizer
16 |
17 |
18 | RESULT_WITH_MASK = True
19 | MAX_BATCH_NUM_PRED = 100
20 | VIS_SCORE_THR = 0.3
21 |
22 |
23 | @torch.no_grad()
24 | def single_sample_step(img_id, data, model, predictor, evaluator, dataloader, device, SHOW):
25 | copied_data = deepcopy(data) # for sam
26 | for item in data.values():
27 |         item[0] = item[0].to(device)  # .to() is not in-place
28 |
29 | # Stage 1
30 | # data['inputs'][0] = torch.flip(data['inputs'][0], dims=[0])
31 | with torch.no_grad():
32 | pred_results = model.test_step(data)
33 | pred_r_bboxes = pred_results[0].pred_instances.bboxes
34 | pred_r_bboxes = RotatedBoxes(pred_r_bboxes)
35 | h_bboxes = pred_r_bboxes.convert_to('hbox').tensor
36 | labels = pred_results[0].pred_instances.labels
37 | scores = pred_results[0].pred_instances.scores
38 |
39 | # Stage 2
40 | if len(h_bboxes) == 0:
41 |         qualities = h_bboxes[:, 0]  # empty scores tensor
42 |         masks = h_bboxes.new_zeros((0, *data['inputs'][0].shape[-2:]))  # empty (0, H, W); inputs are CHW
43 | data_samples = data['data_samples']
44 | r_bboxes = []
45 | else:
46 | img = copied_data['inputs'][0].permute(1, 2, 0).numpy()[:, :, ::-1]
47 | data_samples = copied_data['data_samples']
48 | data_sample = data_samples[0]
49 | data_sample = data_sample.to(device=device)
50 |
51 | predictor.set_image(img)
52 |
53 | # Too many predictions may result in OOM, hence,
54 | # we process the predictions in multiple batches.
55 | masks = []
56 | num_pred = len(h_bboxes)
57 | num_batches = int(np.ceil(num_pred / MAX_BATCH_NUM_PRED))
58 | for i in range(num_batches):
59 | left_index = i * MAX_BATCH_NUM_PRED
60 | right_index = (i + 1) * MAX_BATCH_NUM_PRED
61 | if i == num_batches - 1:
62 | batch_boxes = h_bboxes[left_index:]
63 | else:
64 | batch_boxes = h_bboxes[left_index: right_index]
65 |
66 | transformed_boxes = predictor.transform.apply_boxes_torch(batch_boxes, img.shape[:2])
67 | batch_masks, qualities, lr_logits = predictor.predict_torch(
68 | point_coords=None,
69 | point_labels=None,
70 | boxes=transformed_boxes,
71 | multimask_output=False)
72 | batch_masks = batch_masks.squeeze(1).cpu()
73 | masks.extend([*batch_masks])
74 | masks = torch.stack(masks, dim=0)
75 | r_bboxes = [mask2rbox(mask.numpy()) for mask in masks]
76 |
77 | results_list = get_instancedata_resultlist(r_bboxes, labels, masks, scores)
78 | data_samples = add_pred_to_datasample(results_list, data_samples)
79 |
80 | evaluator.process(data_samples=data_samples, data_batch=data)
81 |
82 | if SHOW:
83 | if len(h_bboxes) != 0 and img_id < 100:
84 | img_name = data_samples[0].img_id
85 | show_results(img, masks, h_bboxes, results_list, img_id, img_name, dataloader)
86 |
87 | return evaluator
88 |
89 |
90 | def mask2rbox(mask):
91 | y, x = np.nonzero(mask)
92 | points = np.stack([x, y], axis=-1)
93 | (cx, cy), (w, h), a = cv2.minAreaRect(points)
94 | r_bbox = np.array([cx, cy, w, h, a / 180 * np.pi])
95 | return r_bbox
96 |
97 | def show_results(img, masks, h_bboxes, results_list, i, img_name, dataloader):
98 | output_dir = './output_vis/'
99 | Path(output_dir).mkdir(exist_ok=True, parents=True)
100 |
101 | results = results_list[0]
102 |
103 | # vis first stage
104 | # plt.figure(figsize=(10, 10))
105 | # plt.imshow(img)
106 | # for mask in masks:
107 | # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
108 | # for box in h_bboxes:
109 | # show_box(box.cpu().numpy(), plt.gca())
110 | # plt.axis('off')
111 | # # plt.show()
112 | # plt.savefig(f'./out_mask_{i}.png')
113 | # plt.close()
114 |
115 | # draw rbox with mmrotate
116 | visualizer = build_visualizer()
117 | visualizer.dataset_meta = dataloader.dataset.metainfo
118 |
119 | scores = results.scores
120 | keep_results = results[scores >= VIS_SCORE_THR]
121 | out_img = visualizer._draw_instances(
122 | img, keep_results,
123 | dataloader.dataset.metainfo['classes'],
124 | dataloader.dataset.metainfo['palette'],
125 | box_alpha=0.9, mask_alpha=0.3)
126 | # visualizer.show()
127 | # cv2.imwrite(os.path.join(output_dir, f'out_rbox_{i}.png'), out_img[:, :, ::-1])
128 | cv2.imwrite(os.path.join(output_dir, f'rdet-sam_{img_name}.png'),
129 | out_img[:, :, ::-1])
130 |
131 |
132 | def add_pred_to_datasample(results_list, data_samples):
133 | for data_sample, pred_instances in zip(data_samples, results_list):
134 | data_sample.pred_instances = pred_instances
135 | samplelist_boxtype2tensor(data_samples)
136 | return data_samples
137 |
138 |
139 | def get_instancedata_resultlist(r_bboxes, labels, masks, scores):
140 | results = InstanceData()
141 | results.bboxes = RotatedBoxes(r_bboxes)
142 | # results.scores = qualities
143 | results.scores = scores
144 | results.labels = labels
145 | if RESULT_WITH_MASK:
146 | results.masks = masks.cpu().numpy()
147 | results_list = [results]
148 | return results_list
149 |
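A quick sanity check of `mask2rbox` on a synthetic mask (illustrative; depending on your OpenCV build you may need to cast the points to int32/float32 before `cv2.minAreaRect`):

```python
import numpy as np
from engine import mask2rbox

mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 10:50] = 1                 # a 40 x 20 axis-aligned rectangle
cx, cy, w, h, a = mask2rbox(mask)
print(cx, cy, w, h, a)                 # center near (29.5, 29.5); angle in radians
```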
--------------------------------------------------------------------------------
/main_rdet-sam_dota.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 |
4 | from mmrotate.utils import register_all_modules
5 | from data import build_data_loader, build_evaluator, build_visualizer
6 |
7 | from segment_anything import sam_model_registry, SamPredictor
8 | from mmrotate.registry import MODELS
9 |
10 | from mmengine import Config
11 | from mmengine.runner.checkpoint import _load_checkpoint
12 |
13 | from engine import single_sample_step
14 |
15 |
16 | register_all_modules(init_default_scope=True)
17 |
18 | SHOW = True
19 | FORMAT_ONLY = True
20 | MERGE_PATCHES = True
21 | SET_MIN_BOX = False
22 |
23 |
24 | if __name__ == '__main__':
25 |
26 | sam_checkpoint = r"../segment-anything/checkpoints/sam_vit_b_01ec64.pth"
27 | model_type = "vit_b"
28 | device = "cuda"
29 |
30 | ckpt_path = './rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth'
31 | model_cfg_path = 'configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py'
32 | # ckpt_path = './rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth'
33 | # model_cfg_path = 'configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py'
34 |
35 | model_cfg = Config.fromfile(model_cfg_path).model
36 | if SET_MIN_BOX:
37 | model_cfg.test_cfg['min_bbox_size'] = 10
38 |
39 | model = MODELS.build(model_cfg)
40 | model.init_weights()
41 | checkpoint = _load_checkpoint(ckpt_path, map_location='cpu')
42 | sd = checkpoint.get('state_dict', checkpoint)
43 | print(model.load_state_dict(sd))
44 |
45 | dataloader = build_data_loader('test_without_hbox')
46 | # dataloader = build_data_loader('trainval_with_hbox')
47 | evaluator = build_evaluator(MERGE_PATCHES, FORMAT_ONLY)
48 | evaluator.dataset_meta = dataloader.dataset.metainfo
49 |
50 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
51 |
52 | model = model.to(device=device)
53 | sam = sam.to(device=device)
54 |
55 | predictor = SamPredictor(sam)
56 |
57 | model.eval()
58 | for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
59 |
60 | evaluator = single_sample_step(i, data, model, predictor, evaluator, dataloader, device, SHOW)
61 |
62 | torch.save(evaluator, './evaluator.pth')
63 |
64 | metrics = evaluator.evaluate(len(dataloader.dataset))
65 |
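Because the evaluator is serialized at the end of the loop, metrics can be recomputed later without re-running the detector and SAM. A reload sketch (fine on the torch 1.10 pinned in the README; newer torch may need `weights_only=False`):

```python
import torch
from data import build_data_loader

evaluator = torch.load('./evaluator.pth')
dataloader = build_data_loader('test_without_hbox')
metrics = evaluator.evaluate(len(dataloader.dataset))  # same size as during processing
```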
--------------------------------------------------------------------------------
/main_sam_dota.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import numpy as np
4 | import cv2
5 | from mmrotate.utils import register_all_modules
6 | from data import build_data_loader, build_evaluator, build_visualizer
7 | from utils import show_box, show_mask
8 | import matplotlib.pyplot as plt
9 | from mmengine.structures import InstanceData
10 | from segment_anything import sam_model_registry, SamPredictor
11 | from mmrotate.structures import RotatedBoxes
12 |
13 | from mmdet.models.utils import samplelist_boxtype2tensor
14 |
15 |
16 | register_all_modules(init_default_scope=True)
17 |
18 | SHOW = False
19 | FORMAT_ONLY = False
20 | MERGE_PATCHES = False
21 |
22 |
23 | if __name__ == '__main__':
24 |
25 |
26 | dataloader = build_data_loader('trainval_with_hbox')
27 | evaluator = build_evaluator(MERGE_PATCHES, FORMAT_ONLY)
28 | evaluator.dataset_meta = dataloader.dataset.metainfo
29 |
30 | sam_checkpoint = r"../segment-anything/checkpoints/sam_vit_b_01ec64.pth"
31 | model_type = "vit_b"
32 | device = "cuda"
33 |
34 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
35 |
36 | sam = sam.to(device=device)
37 |
38 | predictor = SamPredictor(sam)
39 |
40 | for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
41 |
42 | img = data['inputs'][0].permute(1, 2, 0).numpy()[:, :, ::-1]
43 | data_samples = data['data_samples']
44 | data_sample = data_samples[0]
45 | data_sample = data_sample.to(device=device)
46 |
47 | h_bboxes = data_sample.h_gt_bboxes.tensor.to(device=device)
48 | labels = data_sample.gt_instances.labels.to(device=device)
49 |
50 | r_bboxes = []
51 | if len(h_bboxes) == 0:
52 |         qualities = h_bboxes[:, 0]  # empty scores tensor
53 |         masks = h_bboxes.new_zeros((0, *img.shape[:2]))  # empty (0, H, W) mask stack
54 | else:
55 | predictor.set_image(img)
56 | transformed_boxes = predictor.transform.apply_boxes_torch(h_bboxes, img.shape[:2])
57 | masks, qualities, lr_logits = predictor.predict_torch(
58 | point_coords=None,
59 | point_labels=None,
60 | boxes=transformed_boxes,
61 | multimask_output=False)
62 | masks = masks.squeeze(1)
63 | qualities = qualities.squeeze(-1)
64 | for mask in masks:
65 | y, x = np.nonzero(mask.cpu().numpy())
66 | points = np.stack([x, y], axis=-1)
67 | (cx, cy), (w, h), a = cv2.minAreaRect(points)
68 | r_bboxes.append(np.array([cx, cy, w, h, a/180*np.pi]))
69 |
70 | results = InstanceData()
71 | results.bboxes = RotatedBoxes(r_bboxes)
72 | results.scores = qualities
73 | results.labels = labels
74 | results.masks = masks.cpu().numpy()
75 | results_list = [results]
76 |
77 | # add_pred_to_datasample
78 | for data_sample, pred_instances in zip(data_samples, results_list):
79 | data_sample.pred_instances = pred_instances
80 | samplelist_boxtype2tensor(data_samples)
81 |
82 | evaluator.process(data_samples=data_samples, data_batch=data)
83 |
84 | if SHOW:
85 | plt.figure(figsize=(10, 10))
86 | plt.imshow(img)
87 | for mask in masks:
88 | show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
89 | for box in h_bboxes:
90 | show_box(box.cpu().numpy(), plt.gca())
91 | plt.axis('off')
92 | # plt.show()
93 | plt.savefig(f'./out_mask_{i}.png')
94 |
95 | # draw rbox with mmrotate
96 | visualizer = build_visualizer()
97 | visualizer.dataset_meta = dataloader.dataset.metainfo
98 | out_img = visualizer._draw_instances(
99 | img, results,
100 | dataloader.dataset.metainfo['classes'],
101 | dataloader.dataset.metainfo['palette'])
102 | # visualizer.show()
103 | cv2.imwrite(f'./out_rbox_{i}.png', out_img[:, :, ::-1])
104 |
105 | metrics = evaluator.evaluate(len(dataloader.dataset))
106 |
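One caveat with the `cv2.minAreaRect` call above: it returns the angle in degrees (hence the `a / 180 * np.pi` conversion), and its angle convention changed around OpenCV 4.5 to `(0, 90]`. A quick probe of your installation:

```python
import cv2
import numpy as np

pts = np.array([[10, 10], [50, 10], [50, 30], [10, 30]], dtype=np.float32)
(cx, cy), (w, h), a = cv2.minAreaRect(pts)
print((cx, cy), (w, h), a)  # inspect the angle and side ordering your version returns
```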
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | from mmcv.transforms import BaseTransform
2 | from mmrotate.registry import TRANSFORMS
3 |
4 |
5 | @TRANSFORMS.register_module()
6 | class AddConvertedGTBox(BaseTransform):
7 | """Convert boxes in results to a certain box type."""
8 |
9 | def __init__(self, box_type_mapping: dict) -> None:
10 | self.box_type_mapping = box_type_mapping
11 |
12 | def transform(self, results: dict) -> dict:
13 | """The transform function."""
14 | for key, dst_box_type in self.box_type_mapping.items():
15 | assert key != 'gt_bboxes'
16 | gt_bboxes = results['gt_bboxes']
17 | results[key] = gt_bboxes.convert_to(dst_box_type)
18 | return results
19 |
20 | def __repr__(self):
21 | repr_str = self.__class__.__name__
22 | repr_str += f'(box_type_mapping={self.box_type_mapping})'
23 | return repr_str
24 |
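A standalone sketch of the transform, matching how `naive_trainval_pipeline` in `data.py` uses it: the original qbox annotations stay under `gt_bboxes`, and a horizontal view is added under a new key.

```python
import torch
from mmrotate.structures.bbox import QuadriBoxes
from transforms import AddConvertedGTBox

t = AddConvertedGTBox(box_type_mapping=dict(h_gt_bboxes='hbox'))
results = dict(gt_bboxes=QuadriBoxes(torch.tensor([[0., 0., 4., 0., 4., 2., 0., 2.]])))
results = t.transform(results)
print(results['h_gt_bboxes'])  # HorizontalBoxes wrapping [[0., 0., 4., 2.]]
```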
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from the segment-anything demo notebook helpers.
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def show_mask(mask, ax, random_color=False):
7 | if random_color:
8 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
9 | else:
10 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
11 | h, w = mask.shape[-2:]
12 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
13 | ax.imshow(mask_image)
14 |
15 |
16 | def show_points(coords, labels, ax, marker_size=375):
17 | pos_points = coords[labels == 1]
18 | neg_points = coords[labels == 0]
19 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
20 | s=marker_size, edgecolor='white', linewidth=1.25)
21 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
22 | s=marker_size, edgecolor='white', linewidth=1.25)
23 |
24 |
25 | def show_box(box, ax):
26 | x0, y0 = box[0], box[1]
27 | w, h = box[2] - box[0], box[3] - box[1]
28 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
29 | facecolor=(0, 0, 0, 0), lw=2))
30 |
31 | def show_r_box(box, ax):
32 |     # box: (cx, cy, w, h, theta) with theta in radians
33 |     cx, cy, w, h, a = box[0], box[1], box[2], box[3], box[4]
34 |     cos_a, sin_a = np.cos(a), np.sin(a)
35 |     dx = np.array([-w, w, w, -w]) / 2
36 |     dy = np.array([-h, -h, h, h]) / 2
37 |     corners = np.stack([cx + dx * cos_a - dy * sin_a,
38 |                         cy + dx * sin_a + dy * cos_a], axis=-1)
39 |     ax.add_patch(plt.Polygon(corners, closed=True, edgecolor='green',
40 |                              facecolor=(0, 0, 0, 0), lw=2))
41 |
--------------------------------------------------------------------------------
/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List, Optional
3 |
4 | import numpy as np
5 | import torch
6 | from torch import Tensor
7 |
8 | from mmdet.structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon
9 | from mmdet.visualization import DetLocalVisualizer, jitter_color
10 | from mmdet.visualization.palette import _get_adaptive_scales
11 | from mmengine.structures import InstanceData
12 |
13 | from mmrotate.registry import VISUALIZERS
14 | from mmrotate.structures.bbox import QuadriBoxes, RotatedBoxes
15 | from mmrotate.visualization.palette import get_palette
16 |
17 |
18 | @VISUALIZERS.register_module()
19 | class RotLocalVisualizerMaskThenBox(DetLocalVisualizer):
20 | """MMRotate Local Visualizer.
21 |
22 | Args:
23 | name (str): Name of the instance. Defaults to 'visualizer'.
24 | image (np.ndarray, optional): the origin image to draw. The format
25 | should be RGB. Defaults to None.
26 | vis_backends (list, optional): Visual backend config list.
27 | Defaults to None.
28 | save_dir (str, optional): Save file dir for all storage backends.
29 | If it is None, the backend storage will not save any data.
30 | bbox_color (str, tuple(int), optional): Color of bbox lines.
31 | The tuple of color should be in BGR order. Defaults to None.
32 | text_color (str, tuple(int), optional): Color of texts.
33 | The tuple of color should be in BGR order.
34 | Defaults to (200, 200, 200).
35 | mask_color (str, tuple(int), optional): Color of masks.
36 | The tuple of color should be in BGR order.
37 | Defaults to None.
38 | line_width (int, float): The linewidth of lines.
39 | Defaults to 3.
40 | alpha (int, float): The transparency of bboxes or mask.
41 | Defaults to 0.8.
42 | """
43 |
44 |     def _draw_instances(self, image: np.ndarray, instances: InstanceData,
45 | classes: Optional[List[str]],
46 | palette: Optional[List[tuple]],
47 | box_alpha=None, mask_alpha=None) -> np.ndarray:
48 | """Draw instances of GT or prediction.
49 |
50 | Args:
51 |             image (np.ndarray): The image to draw.
52 |             instances (:obj:`InstanceData`): Instance annotations or predictions.
53 |             classes (List[str], optional): Category information.
54 |             palette (List[tuple], optional): Per-category palette.
55 |             box_alpha / mask_alpha (float, optional): Opacity for boxes
56 |                 and masks; both default to ``self.alpha``.
57 |         Returns:
58 |             np.ndarray: The drawn image, in RGB channel order.
59 |         """
60 | if box_alpha is None:
61 | box_alpha = self.alpha
62 | if mask_alpha is None:
63 | mask_alpha = self.alpha
64 |
65 | self.set_image(image)
66 |
67 | if 'masks' in instances:
68 | labels = instances.labels
69 | masks = instances.masks
70 | if isinstance(masks, torch.Tensor):
71 | masks = masks.numpy()
72 | elif isinstance(masks, (PolygonMasks, BitmapMasks)):
73 | masks = masks.to_ndarray()
74 |
75 | masks = masks.astype(bool)
76 |
77 | max_label = int(max(labels) if len(labels) > 0 else 0)
78 | mask_color = palette if self.mask_color is None \
79 | else self.mask_color
80 | mask_palette = get_palette(mask_color, max_label + 1)
81 | colors = [jitter_color(mask_palette[label]) for label in labels]
82 | text_palette = get_palette(self.text_color, max_label + 1)
83 | text_colors = [text_palette[label] for label in labels]
84 |
85 | polygons = []
86 | for i, mask in enumerate(masks):
87 | contours, _ = bitmap_to_polygon(mask)
88 | polygons.extend(contours)
89 | self.draw_polygons(polygons, edge_colors='w', alpha=mask_alpha)
90 | self.draw_binary_masks(masks, colors=colors, alphas=mask_alpha)
91 |
92 | if 'bboxes' in instances:
93 | bboxes = instances.bboxes
94 | labels = instances.labels
95 |
96 | max_label = int(max(labels) if len(labels) > 0 else 0)
97 | text_palette = get_palette(self.text_color, max_label + 1)
98 | text_colors = [text_palette[label] for label in labels]
99 |
100 | bbox_color = palette if self.bbox_color is None \
101 | else self.bbox_color
102 | bbox_palette = get_palette(bbox_color, max_label + 1)
103 | colors = [bbox_palette[label] for label in labels]
104 |
105 | if isinstance(bboxes, Tensor):
106 | if bboxes.size(-1) == 5:
107 | bboxes = RotatedBoxes(bboxes)
108 | elif bboxes.size(-1) == 8:
109 | bboxes = QuadriBoxes(bboxes)
110 | else:
111 | raise TypeError(
112 | 'Require the shape of `bboxes` to be (n, 5) '
113 | 'or (n, 8), but get `bboxes` with shape being '
114 | f'{bboxes.shape}.')
115 |
116 | bboxes = bboxes.cpu()
117 | polygons = bboxes.convert_to('qbox').tensor
118 | polygons = polygons.reshape(-1, 4, 2)
119 | polygons = [p for p in polygons]
120 | self.draw_polygons(
121 | polygons,
122 | edge_colors=colors,
123 | alpha=box_alpha,
124 | line_widths=self.line_width)
125 |
126 | positions = bboxes.centers + self.line_width
127 | scales = _get_adaptive_scales(bboxes.areas)
128 |
129 | for i, (pos, label) in enumerate(zip(positions, labels)):
130 | label_text = classes[
131 | label] if classes is not None else f'class {label}'
132 | if 'scores' in instances:
133 | score = round(float(instances.scores[i]) * 100, 1)
134 | label_text += f': {score}'
135 |
136 | self.draw_texts(
137 | label_text,
138 | pos,
139 | colors=text_colors[i],
140 | font_sizes=int(13 * scales[i]),
141 | bboxes=[{
142 | 'facecolor': 'black',
143 | 'alpha': 0.8,
144 | 'pad': 0.7,
145 | 'edgecolor': 'none'
146 | }])
147 | return self.get_image()
148 |
--------------------------------------------------------------------------------