├── README.md
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   ├── dior.py
│   │   │   ├── dota.py
│   │   │   ├── dota_coco.py
│   │   │   ├── dota_ms.py
│   │   │   ├── dota_qbox.py
│   │   │   ├── dotav15.py
│   │   │   ├── dotav2.py
│   │   │   ├── hrsc.py
│   │   │   ├── hrsid.py
│   │   │   ├── rsdd.py
│   │   │   ├── srsdd.py
│   │   │   └── ssdd.py
│   │   ├── default_runtime.py
│   │   └── schedules
│   │       ├── schedule_1x.py
│   │       ├── schedule_3x.py
│   │       ├── schedule_40e.py
│   │       └── schedule_6x.py
│   └── rotated_fcos
│       ├── README.md
│       ├── metafile.yml
│       ├── rotated-fcos-hbox-le90_r50_fpn_1x_dota.py
│       ├── rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py
│       ├── rotated-fcos-le90_r50_fpn_1x_dota.py
│       ├── rotated-fcos-le90_r50_fpn_kld_1x_dota.py
│       └── rotated-fcos-le90_r50_fpn_rr-6x_hrsc.py
├── data.py
├── engine.py
├── main_rdet-sam_dota.py
├── main_sam_dota.py
├── transforms.py
├── utils.py
└── visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 | # SAM-RBox
2 | This is an implementation of [SAM (Segment Anything Model)](https://github.com/facebookresearch/segment-anything) for generating rotated bounding boxes with [MMRotate](https://github.com/open-mmlab/mmrotate), which serves as a comparison method in [H2RBox-v2: Boosting HBox-supervised Oriented Object Detection via Symmetric Learning](https://arxiv.org/abs/2304.04403).
3 |
4 | **NOTE:** This project has been incorporated into OpenMMLab's new repo [**_PlayGround_**](https://github.com/open-mmlab/playground). For more details, please refer to [this README](https://github.com/open-mmlab/playground/blob/main/mmrotate_sam/README.md).
5 |
6 |
7 |
8 |
9 |
10 | Recently, [SAM](https://arxiv.org/abs/2304.02643) has demonstrated strong zero-shot capabilities, having been trained on the largest segmentation dataset to date. We therefore use a trained horizontal FCOS detector to provide HBoxes to SAM as prompts, let SAM generate the corresponding masks zero-shot, and finally obtain the RBoxes by applying a minimum-circumscribed-rectangle operation to the predicted masks (a sketch of this conversion follows the notes below). Thanks to the powerful zero-shot capability, SAM-RBox based on ViT-B achieves 63.94% mAP. However, it is limited by the time-consuming post-processing and runs at only 1.7 FPS during inference.
11 |
12 |
13 | ![image](https://user-images.githubusercontent.com/79644233/230732578-649086b4-7720-4450-9e87-25873bec07cb.png)
14 | ![image](https://user-images.githubusercontent.com/29257168/230749605-f6584336-a69b-47e8-95ab-87669ca9baf0.png)
15 |
16 | ## Prepare Env
17 |
18 | The code is based on MMRotate 1.x and the official SAM API.
19 |
20 | Here are the installation commands for the recommended environment.
21 | ```bash
22 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
23 |
24 | pip install openmim
25 | mim install mmengine 'mmcv>=2.0.0rc0' 'mmrotate>=1.0.0rc0'
26 |
27 | pip install git+https://github.com/facebookresearch/segment-anything.git
28 | pip install opencv-python pycocotools matplotlib onnxruntime onnx
29 | ```
30 |
31 | ## Note
32 | 1. Prepare the DOTA dataset according to the MMRotate documentation.
33 | 2. Download the detector weights from the MMRotate model zoo.
34 | 3. `python main_sam_dota.py` prompts SAM with HBoxes obtained from an annotation file (such as DOTA trainval).
35 | 4. `python main_rdet-sam_dota.py` prompts SAM with HBoxes predicted by a well-trained detector for non-annotated data (such as DOTA test).
36 | 5. Most configs, including the pipeline (i.e. transforms), dataset, dataloader, evaluator, and visualizer, are set in `data.py`.
37 | 6. You can change the detector config and the corresponding weight path in `main_rdet-sam_dota.py` to any detector that can be built with MMRotate.
38 |
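The HBox → Mask → RBox conversion at the core of both scripts is small enough to sketch. The snippet below is a minimal illustration rather than the exact code of `main_sam_dota.py`: the checkpoint name, image path, and box values are placeholders, and it assumes the official `segment-anything` package and OpenCV are installed.

```python
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Load a ViT-B SAM checkpoint and wrap it in the promptable predictor.
sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
predictor = SamPredictor(sam)

# SAM expects an RGB uint8 image.
image = cv2.cvtColor(cv2.imread('demo.png'), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Prompt SAM with one horizontal box (x1, y1, x2, y2), e.g. a detector output.
hbox = np.array([100, 150, 300, 280])
masks, scores, _ = predictor.predict(box=hbox, multimask_output=False)

# Minimum circumscribed rectangle of the predicted mask -> rotated box.
ys, xs = np.nonzero(masks[0])
points = np.stack([xs, ys], axis=1).astype(np.float32)
(cx, cy), (w, h), angle = cv2.minAreaRect(points)  # angle in degrees
print(f'RBox: cx={cx:.1f} cy={cy:.1f} w={w:.1f} h={h:.1f} angle={angle:.1f}')
```

Running this conversion for every box of every DOTA patch is exactly the post-processing that limits inference to about 1.7 FPS.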
39 | ## Citation
40 | ```
41 | @article{yu2023h2rboxv2,
42 |   title={H2RBox-v2: Boosting HBox-supervised Oriented Object Detection via Symmetric Learning},
43 |   author={Yu, Yi and Yang, Xue and Li, Qingyun and Zhou, Yue and Zhang, Gefan and Yan, Junchi and Da, Feipeng},
44 |   journal={arXiv preprint arXiv:2304.04403},
45 |   year={2023}
46 | }
47 |
48 | @inproceedings{yang2023h2rbox,
49 |   title={H2RBox: Horizontal Box Annotation is All You Need for Oriented Object Detection},
50 |   author={Yang, Xue and Zhang, Gefan and Li, Wentong and Wang, Xuehui and Zhou, Yue and Yan, Junchi},
51 |   booktitle={International Conference on Learning Representations},
52 |   year={2023}
53 | }
54 |
55 | @article{kirillov2023segany,
56 |   title={Segment Anything},
57 |   author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
58 |   journal={arXiv preprint arXiv:2304.02643},
59 |   year={2023}
60 | }
61 | ```
62 |
63 | ### Other awesome SAM projects:
64 | - [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
65 | - [Zero-Shot Anomaly Detection](https://github.com/caoyunkang/GroundedSAM-zero-shot-anomaly-detection)
66 | - [EditAnything: ControlNet + StableDiffusion based on the SAM segmentation mask](https://github.com/sail-sg/EditAnything)
67 | - [IEA: Image Editing Anything](https://github.com/feizc/IEA)
68 | - [sam-with-mmdet](https://github.com/liuyanyi/sam-with-mmdet) (mmdet 3.0.0, provides RTMDet)
69 | - [Prompt-Segment-Anything](https://github.com/RockeyCoss/Prompt-Segment-Anything) (mmdet 3.0.0, H-DETR, DINO, Focal backbone)
70 | ......
71 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/dior.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'DIORDataset'
3 | data_root = 'data/DIOR/'
4 | backend_args = None
5 |
6 | train_pipeline = [
7 |     dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
8 |     dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
9 |     dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
10 |     dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
11 |     dict(
12 |         type='mmdet.RandomFlip',
13 |         prob=0.75,
14 |         direction=['horizontal', 'vertical', 'diagonal']),
15 |     dict(type='mmdet.PackDetInputs')
16 | ]
17 | val_pipeline = [
18 |     dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
19 |     dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
20 |     # avoid bboxes being resized
21 |     dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
22 |     dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
23 |     dict(
24 |         type='mmdet.PackDetInputs',
25 |         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
26 |                    'scale_factor'))
27 | ]
28 | test_pipeline = [
29 |     dict(type='mmdet.LoadImageFromFile', backend_args=backend_args),
30 |     dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True),
31 |     dict(
32 |         type='mmdet.PackDetInputs',
33 |         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
34 |                    'scale_factor'))
35 | ]
36 | train_dataloader = dict(
37 |     batch_size=2,
38 |     num_workers=2,
39 |     persistent_workers=True,
40 |     sampler=dict(type='DefaultSampler', shuffle=True),
41 |     batch_sampler=None,
42 |     dataset=dict(
43 |         type='ConcatDataset',
44 |         ignore_keys=['DATASET_TYPE'],
45 |         datasets=[
46 |             dict(
47 |                 type=dataset_type,
48 |                 data_root=data_root,
49 |                 ann_file='ImageSets/Main/train.txt',
50 |                 data_prefix=dict(img_path='JPEGImages-trainval'),
51 |                 filter_cfg=dict(filter_empty_gt=True),
52 |                 pipeline=train_pipeline),
53 |             dict(
54 |                 type=dataset_type,
55 |                 data_root=data_root,
56 |                 ann_file='ImageSets/Main/val.txt',
57 |                 data_prefix=dict(img_path='JPEGImages-trainval'),
58 |                 filter_cfg=dict(filter_empty_gt=True),
59 |                 pipeline=train_pipeline,
60 |                 backend_args=backend_args)
61 |         ]))
62 | val_dataloader = dict(
63 |     batch_size=1,
64 |     num_workers=2,
65 |     persistent_workers=True,
66 |     drop_last=False,
67 |     sampler=dict(type='DefaultSampler', shuffle=False),
68 |     dataset=dict(
69 |         type=dataset_type,
70 |         data_root=data_root,
71 |         ann_file='ImageSets/Main/test.txt',
72 |         data_prefix=dict(img_path='JPEGImages-test'),
73 |         test_mode=True,
74 |         pipeline=val_pipeline,
75 |
backend_args=backend_args)) 76 | test_dataloader = val_dataloader 77 | 78 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 79 | test_evaluator = val_evaluator 80 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dota.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DOTADataset' 3 | data_root = 'data/split_ss_dota/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 11 | dict( 12 | type='mmdet.RandomFlip', 13 | prob=0.75, 14 | direction=['horizontal', 'vertical', 'diagonal']), 15 | dict(type='mmdet.PackDetInputs') 16 | ] 17 | val_pipeline = [ 18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 20 | # avoid bboxes being resized 21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 23 | dict( 24 | type='mmdet.PackDetInputs', 25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 26 | 'scale_factor')) 27 | ] 28 | test_pipeline = [ 29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor')) 35 | ] 36 | train_dataloader = dict( 37 | batch_size=2, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type='DefaultSampler', shuffle=True), 41 | batch_sampler=None, 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file='trainval/annfiles/', 46 | data_prefix=dict(img_path='trainval/images/'), 47 | filter_cfg=dict(filter_empty_gt=True), 48 | pipeline=train_pipeline)) 49 | val_dataloader = dict( 50 | batch_size=1, 51 | num_workers=2, 52 | persistent_workers=True, 53 | drop_last=False, 54 | sampler=dict(type='DefaultSampler', shuffle=False), 55 | dataset=dict( 56 | type=dataset_type, 57 | data_root=data_root, 58 | ann_file='trainval/annfiles/', 59 | data_prefix=dict(img_path='trainval/images/'), 60 | test_mode=True, 61 | pipeline=val_pipeline)) 62 | test_dataloader = val_dataloader 63 | 64 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 65 | test_evaluator = val_evaluator 66 | 67 | # inference on test dataset and format the output results 68 | # for submission. Note: the test set has no annotation. 
69 | # test_dataloader = dict( 70 | # batch_size=1, 71 | # num_workers=2, 72 | # persistent_workers=True, 73 | # drop_last=False, 74 | # sampler=dict(type='DefaultSampler', shuffle=False), 75 | # dataset=dict( 76 | # type=dataset_type, 77 | # data_root=data_root, 78 | # data_prefix=dict(img_path='test/images/'), 79 | # test_mode=True, 80 | # pipeline=test_pipeline)) 81 | # test_evaluator = dict( 82 | # type='DOTAMetric', 83 | # format_only=True, 84 | # merge_patches=True, 85 | # outfile_prefix='./work_dirs/dota/Task1') 86 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dota_coco.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmdet.CocoDataset' 3 | data_root = 'data/split_ms_dota/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict( 9 | type='mmdet.LoadAnnotations', 10 | with_bbox=True, 11 | with_mask=True, 12 | poly2mask=False), 13 | dict(type='ConvertMask2BoxType', box_type='rbox'), 14 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 15 | dict( 16 | type='mmdet.RandomFlip', 17 | prob=0.75, 18 | direction=['horizontal', 'vertical', 'diagonal']), 19 | dict(type='mmdet.PackDetInputs') 20 | ] 21 | val_pipeline = [ 22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 23 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 24 | # avoid bboxes being resized 25 | dict( 26 | type='mmdet.LoadAnnotations', 27 | with_bbox=True, 28 | with_mask=True, 29 | poly2mask=False), 30 | dict(type='ConvertMask2BoxType', box_type='qbox'), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor', 'instances')) 35 | ] 36 | test_pipeline = [ 37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 38 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 39 | dict( 40 | type='mmdet.PackDetInputs', 41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 42 | 'scale_factor')) 43 | ] 44 | 45 | metainfo = dict( 46 | classes=('plane', 'baseball-diamond', 'bridge', 'ground-track-field', 47 | 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 48 | 'basketball-court', 'storage-tank', 'soccer-ball-field', 49 | 'roundabout', 'harbor', 'swimming-pool', 'helicopter')) 50 | 51 | train_dataloader = dict( 52 | batch_size=2, 53 | num_workers=2, 54 | persistent_workers=True, 55 | sampler=dict(type='DefaultSampler', shuffle=True), 56 | batch_sampler=None, 57 | dataset=dict( 58 | type=dataset_type, 59 | metainfo=metainfo, 60 | data_root=data_root, 61 | ann_file='train/train.json', 62 | data_prefix=dict(img='train/images/'), 63 | filter_cfg=dict(filter_empty_gt=True), 64 | pipeline=train_pipeline, 65 | backend_args=backend_args)) 66 | val_dataloader = dict( 67 | batch_size=1, 68 | num_workers=2, 69 | persistent_workers=True, 70 | drop_last=False, 71 | sampler=dict(type='DefaultSampler', shuffle=False), 72 | dataset=dict( 73 | type=dataset_type, 74 | metainfo=metainfo, 75 | data_root=data_root, 76 | ann_file='val/val.json', 77 | data_prefix=dict(img='val/images/'), 78 | test_mode=True, 79 | pipeline=val_pipeline, 80 | backend_args=backend_args)) 81 | test_dataloader = val_dataloader 82 | 83 | val_evaluator = dict( 84 | type='RotatedCocoMetric', 85 | metric='bbox', 86 | classwise=True, 87 | backend_args=backend_args) 88 | 89 | test_evaluator = val_evaluator 90 | 91 | # 
inference on test dataset and format the output results 92 | # for submission. Note: the test set has no annotation. 93 | # test_dataloader = dict( 94 | # batch_size=1, 95 | # num_workers=2, 96 | # persistent_workers=True, 97 | # drop_last=False, 98 | # sampler=dict(type='DefaultSampler', shuffle=False), 99 | # dataset=dict( 100 | # type=dataset_type, 101 | # ann_file='test/test.json', 102 | # data_prefix=dict(img='test/images/'), 103 | # test_mode=True, 104 | # pipeline=test_pipeline)) 105 | # test_evaluator = dict( 106 | # type='DOTAMetric', 107 | # format_only=True, 108 | # merge_patches=True, 109 | # outfile_prefix='./work_dirs/dota/Task1') 110 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dota_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DOTADataset' 3 | data_root = 'data/split_ms_dota/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 11 | dict( 12 | type='mmdet.RandomFlip', 13 | prob=0.75, 14 | direction=['horizontal', 'vertical', 'diagonal']), 15 | dict( 16 | type='RandomRotate', 17 | prob=0.5, 18 | angle_range=180, 19 | rect_obj_labels=[9, 11]), 20 | dict(type='mmdet.PackDetInputs') 21 | ] 22 | val_pipeline = [ 23 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 24 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 25 | # avoid bboxes being resized 26 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 27 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 28 | dict( 29 | type='mmdet.PackDetInputs', 30 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 31 | 'scale_factor')) 32 | ] 33 | test_pipeline = [ 34 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 35 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 36 | dict( 37 | type='mmdet.PackDetInputs', 38 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 39 | 'scale_factor')) 40 | ] 41 | train_dataloader = dict( 42 | batch_size=2, 43 | num_workers=2, 44 | persistent_workers=True, 45 | sampler=dict(type='DefaultSampler', shuffle=True), 46 | batch_sampler=None, 47 | dataset=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | ann_file='trainval/annfiles/', 51 | data_prefix=dict(img_path='trainval/images/'), 52 | filter_cfg=dict(filter_empty_gt=True), 53 | pipeline=train_pipeline)) 54 | val_dataloader = dict( 55 | batch_size=1, 56 | num_workers=2, 57 | persistent_workers=True, 58 | drop_last=False, 59 | sampler=dict(type='DefaultSampler', shuffle=False), 60 | dataset=dict( 61 | type=dataset_type, 62 | data_root=data_root, 63 | ann_file='trainval/annfiles/', 64 | data_prefix=dict(img_path='trainval/images/'), 65 | test_mode=True, 66 | pipeline=val_pipeline)) 67 | test_dataloader = val_dataloader 68 | 69 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 70 | test_evaluator = val_evaluator 71 | 72 | # inference on test dataset and format the output results 73 | # for submission. Note: the test set has no annotation. 
74 | # test_dataloader = dict( 75 | # batch_size=1, 76 | # num_workers=2, 77 | # persistent_workers=True, 78 | # drop_last=False, 79 | # sampler=dict(type='DefaultSampler', shuffle=False), 80 | # dataset=dict( 81 | # type=dataset_type, 82 | # data_root=data_root, 83 | # data_prefix=dict(img_path='test/images/'), 84 | # test_mode=True, 85 | # pipeline=test_pipeline)) 86 | # test_evaluator = dict( 87 | # type='DOTAMetric', 88 | # format_only=True, 89 | # merge_patches=True, 90 | # outfile_prefix='./work_dirs/dota/Task1') 91 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dota_qbox.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DOTADataset' 3 | data_root = 'data/split_ss_dota/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 10 | dict( 11 | type='mmdet.RandomFlip', 12 | prob=0.75, 13 | direction=['horizontal', 'vertical', 'diagonal']), 14 | dict(type='mmdet.PackDetInputs') 15 | ] 16 | val_pipeline = [ 17 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 18 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 19 | # avoid bboxes being resized 20 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 21 | dict( 22 | type='mmdet.PackDetInputs', 23 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 24 | 'scale_factor')) 25 | ] 26 | test_pipeline = [ 27 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 28 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 29 | dict( 30 | type='mmdet.PackDetInputs', 31 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 32 | 'scale_factor')) 33 | ] 34 | train_dataloader = dict( 35 | batch_size=2, 36 | num_workers=2, 37 | persistent_workers=True, 38 | sampler=dict(type='DefaultSampler', shuffle=True), 39 | batch_sampler=None, 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | ann_file='trainval/annfiles/', 44 | data_prefix=dict(img_path='trainval/images/'), 45 | filter_cfg=dict(filter_empty_gt=True), 46 | pipeline=train_pipeline)) 47 | val_dataloader = dict( 48 | batch_size=1, 49 | num_workers=2, 50 | persistent_workers=True, 51 | drop_last=False, 52 | sampler=dict(type='DefaultSampler', shuffle=False), 53 | dataset=dict( 54 | type=dataset_type, 55 | data_root=data_root, 56 | ann_file='trainval/annfiles/', 57 | data_prefix=dict(img_path='trainval/images/'), 58 | test_mode=True, 59 | pipeline=val_pipeline)) 60 | test_dataloader = val_dataloader 61 | 62 | val_evaluator = dict( 63 | type='DOTAMetric', metric='mAP', iou_thrs=0.2, predict_box_type='qbox') 64 | test_evaluator = val_evaluator 65 | 66 | # inference on test dataset and format the output results 67 | # for submission. Note: the test set has no annotation. 
68 | # test_dataloader = dict( 69 | # batch_size=1, 70 | # num_workers=2, 71 | # persistent_workers=True, 72 | # drop_last=False, 73 | # sampler=dict(type='DefaultSampler', shuffle=False), 74 | # dataset=dict( 75 | # type=dataset_type, 76 | # data_root=data_root, 77 | # data_prefix=dict(img_path='test/images/'), 78 | # test_mode=True, 79 | # pipeline=test_pipeline)) 80 | # test_evaluator = dict( 81 | # type='DOTAMetric', 82 | # format_only=True, 83 | # merge_patches=True, 84 | # predict_box_type='qbox', 85 | # outfile_prefix='./work_dirs/dota/Task1') 86 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dotav15.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DOTAv15Dataset' 3 | data_root = 'data/split_ss_dota1_5/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 11 | dict( 12 | type='mmdet.RandomFlip', 13 | prob=0.75, 14 | direction=['horizontal', 'vertical', 'diagonal']), 15 | dict(type='mmdet.PackDetInputs') 16 | ] 17 | val_pipeline = [ 18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 20 | # avoid bboxes being resized 21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 23 | dict( 24 | type='mmdet.PackDetInputs', 25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 26 | 'scale_factor')) 27 | ] 28 | test_pipeline = [ 29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor')) 35 | ] 36 | train_dataloader = dict( 37 | batch_size=2, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type='DefaultSampler', shuffle=True), 41 | batch_sampler=None, 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file='trainval/annfiles/', 46 | data_prefix=dict(img_path='trainval/images/'), 47 | filter_cfg=dict(filter_empty_gt=True), 48 | pipeline=train_pipeline)) 49 | val_dataloader = dict( 50 | batch_size=1, 51 | num_workers=2, 52 | persistent_workers=True, 53 | drop_last=False, 54 | sampler=dict(type='DefaultSampler', shuffle=False), 55 | dataset=dict( 56 | type=dataset_type, 57 | data_root=data_root, 58 | ann_file='trainval/annfiles/', 59 | data_prefix=dict(img_path='trainval/images/'), 60 | test_mode=True, 61 | pipeline=val_pipeline)) 62 | test_dataloader = val_dataloader 63 | 64 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 65 | test_evaluator = val_evaluator 66 | 67 | # inference on test dataset and format the output results 68 | # for submission. Note: the test set has no annotation. 
69 | # test_dataloader = dict( 70 | # batch_size=1, 71 | # num_workers=2, 72 | # persistent_workers=True, 73 | # drop_last=False, 74 | # sampler=dict(type='DefaultSampler', shuffle=False), 75 | # dataset=dict( 76 | # type=dataset_type, 77 | # data_root=data_root, 78 | # data_prefix=dict(img_path='test/images/'), 79 | # test_mode=True, 80 | # pipeline=test_pipeline)) 81 | # test_evaluator = dict( 82 | # type='DOTAMetric', 83 | # format_only=True, 84 | # merge_patches=True, 85 | # outfile_prefix='./work_dirs/dotav15/h2rbox-le90_r50_fpn_adamw-1x_dotav15/Task1') 86 | -------------------------------------------------------------------------------- /configs/_base_/datasets/dotav2.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DOTAv2Dataset' 3 | data_root = 'data/split_ss_dota2_0/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 10 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 11 | dict( 12 | type='mmdet.RandomFlip', 13 | prob=0.75, 14 | direction=['horizontal', 'vertical', 'diagonal']), 15 | dict(type='mmdet.PackDetInputs') 16 | ] 17 | val_pipeline = [ 18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 19 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 20 | # avoid bboxes being resized 21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 23 | dict( 24 | type='mmdet.PackDetInputs', 25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 26 | 'scale_factor')) 27 | ] 28 | test_pipeline = [ 29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 30 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor')) 35 | ] 36 | train_dataloader = dict( 37 | batch_size=2, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type='DefaultSampler', shuffle=True), 41 | batch_sampler=None, 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file='trainval/annfiles/', 46 | data_prefix=dict(img_path='trainval/images/'), 47 | filter_cfg=dict(filter_empty_gt=True), 48 | pipeline=train_pipeline)) 49 | val_dataloader = dict( 50 | batch_size=1, 51 | num_workers=2, 52 | persistent_workers=True, 53 | drop_last=False, 54 | sampler=dict(type='DefaultSampler', shuffle=False), 55 | dataset=dict( 56 | type=dataset_type, 57 | data_root=data_root, 58 | ann_file='trainval/annfiles/', 59 | data_prefix=dict(img_path='trainval/images/'), 60 | test_mode=True, 61 | pipeline=val_pipeline)) 62 | test_dataloader = val_dataloader 63 | 64 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 65 | test_evaluator = val_evaluator 66 | 67 | # inference on test dataset and format the output results 68 | # for submission. Note: the test set has no annotation. 
69 | # test_dataloader = dict( 70 | # batch_size=1, 71 | # num_workers=2, 72 | # persistent_workers=True, 73 | # drop_last=False, 74 | # sampler=dict(type='DefaultSampler', shuffle=False), 75 | # dataset=dict( 76 | # type=dataset_type, 77 | # data_root=data_root, 78 | # data_prefix=dict(img_path='test/images/'), 79 | # test_mode=True, 80 | # pipeline=test_pipeline)) 81 | # test_evaluator = dict( 82 | # type='DOTAMetric', 83 | # format_only=True, 84 | # merge_patches=True, 85 | # outfile_prefix='./work_dirs/dotav2/h2rbox-le90_r50_fpn_adamw-1x_dotav2/Task1') 86 | -------------------------------------------------------------------------------- /configs/_base_/datasets/hrsc.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'HRSCDataset' 3 | data_root = 'data/hrsc/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 9 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 10 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True), 11 | dict( 12 | type='mmdet.RandomFlip', 13 | prob=0.75, 14 | direction=['horizontal', 'vertical', 'diagonal']), 15 | dict(type='mmdet.PackDetInputs') 16 | ] 17 | val_pipeline = [ 18 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 19 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True), 20 | # avoid bboxes being resized 21 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 22 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 23 | dict( 24 | type='mmdet.PackDetInputs', 25 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 26 | 'scale_factor')) 27 | ] 28 | test_pipeline = [ 29 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 30 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor')) 35 | ] 36 | train_dataloader = dict( 37 | batch_size=2, 38 | num_workers=2, 39 | persistent_workers=True, 40 | sampler=dict(type='DefaultSampler', shuffle=True), 41 | batch_sampler=None, 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | ann_file='ImageSets/trainval.txt', 46 | data_prefix=dict(sub_data_root='FullDataSet/'), 47 | filter_cfg=dict(filter_empty_gt=True), 48 | pipeline=train_pipeline, 49 | backend_args=backend_args)) 50 | val_dataloader = dict( 51 | batch_size=1, 52 | num_workers=2, 53 | persistent_workers=True, 54 | drop_last=False, 55 | sampler=dict(type='DefaultSampler', shuffle=False), 56 | dataset=dict( 57 | type=dataset_type, 58 | data_root=data_root, 59 | ann_file='ImageSets/test.txt', 60 | data_prefix=dict(sub_data_root='FullDataSet/'), 61 | test_mode=True, 62 | pipeline=val_pipeline, 63 | backend_args=backend_args)) 64 | test_dataloader = val_dataloader 65 | 66 | val_evaluator = dict(type='DOTAMetric', metric='mAP') 67 | test_evaluator = val_evaluator 68 | -------------------------------------------------------------------------------- /configs/_base_/datasets/hrsid.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmdet.CocoDataset' 3 | data_root = 'data/HRSID_JPG/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict( 9 | 
type='mmdet.LoadAnnotations', 10 | with_bbox=True, 11 | with_mask=True, 12 | poly2mask=False), 13 | dict(type='ConvertMask2BoxType', box_type='rbox'), 14 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True), 15 | dict(type='mmdet.FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), 16 | dict( 17 | type='mmdet.RandomFlip', 18 | prob=0.75, 19 | direction=['horizontal', 'vertical', 'diagonal']), 20 | dict(type='mmdet.PackDetInputs') 21 | ] 22 | val_pipeline = [ 23 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 24 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True), 25 | # avoid bboxes being resized 26 | dict( 27 | type='mmdet.LoadAnnotations', 28 | with_bbox=True, 29 | with_mask=True, 30 | poly2mask=False), 31 | dict(type='ConvertMask2BoxType', box_type='qbox'), 32 | dict( 33 | type='mmdet.PackDetInputs', 34 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 35 | 'scale_factor', 'instances')) 36 | ] 37 | test_pipeline = [ 38 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 39 | dict(type='mmdet.Resize', scale=(800, 800), keep_ratio=True), 40 | dict( 41 | type='mmdet.PackDetInputs', 42 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 43 | 'scale_factor')) 44 | ] 45 | 46 | metainfo = dict(classes=('ship', )) 47 | 48 | train_dataloader = dict( 49 | batch_size=2, 50 | num_workers=2, 51 | persistent_workers=True, 52 | sampler=dict(type='DefaultSampler', shuffle=True), 53 | batch_sampler=None, 54 | dataset=dict( 55 | type=dataset_type, 56 | metainfo=metainfo, 57 | data_root=data_root, 58 | ann_file='annotations/train2017.json', 59 | data_prefix=dict(img='JPEGImages/'), 60 | filter_cfg=dict(filter_empty_gt=True), 61 | pipeline=train_pipeline, 62 | backend_args=backend_args)) 63 | val_dataloader = dict( 64 | batch_size=1, 65 | num_workers=2, 66 | persistent_workers=True, 67 | drop_last=False, 68 | sampler=dict(type='DefaultSampler', shuffle=False), 69 | dataset=dict( 70 | type=dataset_type, 71 | metainfo=metainfo, 72 | data_root=data_root, 73 | ann_file='annotations/test2017.json', 74 | data_prefix=dict(img='JPEGImages/'), 75 | test_mode=True, 76 | pipeline=val_pipeline, 77 | backend_args=backend_args)) 78 | test_dataloader = val_dataloader 79 | 80 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox') 81 | 82 | test_evaluator = val_evaluator 83 | -------------------------------------------------------------------------------- /configs/_base_/datasets/rsdd.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmdet.CocoDataset' 3 | data_root = 'data/rsdd/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict( 9 | type='mmdet.LoadAnnotations', 10 | with_bbox=True, 11 | with_mask=True, 12 | poly2mask=False), 13 | dict(type='ConvertMask2BoxType', box_type='rbox'), 14 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 15 | dict( 16 | type='mmdet.RandomFlip', 17 | prob=0.75, 18 | direction=['horizontal', 'vertical', 'diagonal']), 19 | dict(type='mmdet.PackDetInputs') 20 | ] 21 | val_pipeline = [ 22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 23 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 24 | # avoid bboxes being resized 25 | dict( 26 | type='mmdet.LoadAnnotations', 27 | with_bbox=True, 28 | with_mask=True, 29 | poly2mask=False), 30 | dict(type='ConvertMask2BoxType', box_type='qbox'), 31 | dict( 32 | 
type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor', 'instances')) 35 | ] 36 | test_pipeline = [ 37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 38 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 39 | dict( 40 | type='mmdet.PackDetInputs', 41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 42 | 'scale_factor')) 43 | ] 44 | 45 | metainfo = dict(classes=('ship', )) 46 | 47 | train_dataloader = dict( 48 | batch_size=2, 49 | num_workers=2, 50 | persistent_workers=True, 51 | sampler=dict(type='DefaultSampler', shuffle=True), 52 | batch_sampler=None, 53 | dataset=dict( 54 | type=dataset_type, 55 | metainfo=metainfo, 56 | data_root=data_root, 57 | ann_file='ImageSets/train.json', 58 | data_prefix=dict(img='JPEGImages/'), 59 | filter_cfg=dict(filter_empty_gt=True), 60 | pipeline=train_pipeline, 61 | backend_args=backend_args)) 62 | val_dataloader = dict( 63 | batch_size=1, 64 | num_workers=2, 65 | persistent_workers=True, 66 | drop_last=False, 67 | sampler=dict(type='DefaultSampler', shuffle=False), 68 | dataset=dict( 69 | type=dataset_type, 70 | metainfo=metainfo, 71 | data_root=data_root, 72 | ann_file='ImageSets/test.json', 73 | data_prefix=dict(img='JPEGImages/'), 74 | test_mode=True, 75 | pipeline=val_pipeline, 76 | backend_args=backend_args)) 77 | test_dataloader = val_dataloader 78 | 79 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox') 80 | 81 | test_evaluator = val_evaluator 82 | -------------------------------------------------------------------------------- /configs/_base_/datasets/srsdd.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmdet.CocoDataset' 3 | data_root = 'data/srsdd/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict( 9 | type='mmdet.LoadAnnotations', 10 | with_bbox=True, 11 | with_mask=True, 12 | poly2mask=False), 13 | dict(type='ConvertMask2BoxType', box_type='rbox'), 14 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 15 | dict( 16 | type='mmdet.RandomFlip', 17 | prob=0.75, 18 | direction=['horizontal', 'vertical', 'diagonal']), 19 | dict(type='mmdet.PackDetInputs') 20 | ] 21 | val_pipeline = [ 22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 23 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 24 | # avoid bboxes being resized 25 | dict( 26 | type='mmdet.LoadAnnotations', 27 | with_bbox=True, 28 | with_mask=True, 29 | poly2mask=False), 30 | dict(type='ConvertMask2BoxType', box_type='qbox'), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor', 'instances')) 35 | ] 36 | test_pipeline = [ 37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 38 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 39 | dict( 40 | type='mmdet.PackDetInputs', 41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 42 | 'scale_factor')) 43 | ] 44 | 45 | metainfo = dict( 46 | classes=('Container', 'Dredger', 'LawEnforce', 'Cell-Container', 'ore-oil', 47 | 'Fishing')) 48 | 49 | train_dataloader = dict( 50 | batch_size=2, 51 | num_workers=2, 52 | persistent_workers=True, 53 | sampler=dict(type='DefaultSampler', shuffle=True), 54 | batch_sampler=None, 55 | dataset=dict( 56 | type=dataset_type, 57 | metainfo=metainfo, 58 | data_root=data_root, 59 | 
ann_file='train/train.json', 60 | data_prefix=dict(img='train/images/'), 61 | filter_cfg=dict(filter_empty_gt=True), 62 | pipeline=train_pipeline, 63 | backend_args=backend_args)) 64 | val_dataloader = dict( 65 | batch_size=1, 66 | num_workers=2, 67 | persistent_workers=True, 68 | drop_last=False, 69 | sampler=dict(type='DefaultSampler', shuffle=False), 70 | dataset=dict( 71 | type=dataset_type, 72 | metainfo=metainfo, 73 | data_root=data_root, 74 | ann_file='test/test.json', 75 | data_prefix=dict(img='test/images/'), 76 | test_mode=True, 77 | pipeline=val_pipeline, 78 | backend_args=backend_args)) 79 | test_dataloader = val_dataloader 80 | 81 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox') 82 | 83 | test_evaluator = val_evaluator 84 | -------------------------------------------------------------------------------- /configs/_base_/datasets/ssdd.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmdet.CocoDataset' 3 | data_root = 'data/ssdd/' 4 | backend_args = None 5 | 6 | train_pipeline = [ 7 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 8 | dict( 9 | type='mmdet.LoadAnnotations', 10 | with_bbox=True, 11 | with_mask=True, 12 | poly2mask=False), 13 | dict(type='ConvertMask2BoxType', box_type='rbox'), 14 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 15 | dict( 16 | type='mmdet.RandomFlip', 17 | prob=0.75, 18 | direction=['horizontal', 'vertical', 'diagonal']), 19 | dict(type='mmdet.PackDetInputs') 20 | ] 21 | val_pipeline = [ 22 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 23 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 24 | # avoid bboxes being resized 25 | dict( 26 | type='mmdet.LoadAnnotations', 27 | with_bbox=True, 28 | with_mask=True, 29 | poly2mask=False), 30 | dict(type='ConvertMask2BoxType', box_type='qbox'), 31 | dict( 32 | type='mmdet.PackDetInputs', 33 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 34 | 'scale_factor', 'instances')) 35 | ] 36 | test_pipeline = [ 37 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 38 | dict(type='mmdet.Resize', scale=(512, 512), keep_ratio=True), 39 | dict( 40 | type='mmdet.PackDetInputs', 41 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 42 | 'scale_factor')) 43 | ] 44 | 45 | metainfo = dict(classes=('ship', )) 46 | 47 | train_dataloader = dict( 48 | batch_size=2, 49 | num_workers=2, 50 | persistent_workers=True, 51 | sampler=dict(type='DefaultSampler', shuffle=True), 52 | batch_sampler=None, 53 | dataset=dict( 54 | type=dataset_type, 55 | metainfo=metainfo, 56 | data_root=data_root, 57 | ann_file='train/train.json', 58 | data_prefix=dict(img='train/images/'), 59 | filter_cfg=dict(filter_empty_gt=True), 60 | pipeline=train_pipeline, 61 | backend_args=backend_args)) 62 | val_dataloader = dict( 63 | batch_size=1, 64 | num_workers=2, 65 | persistent_workers=True, 66 | drop_last=False, 67 | sampler=dict(type='DefaultSampler', shuffle=False), 68 | dataset=dict( 69 | type=dataset_type, 70 | metainfo=metainfo, 71 | data_root=data_root, 72 | ann_file='test/all/test.json', 73 | data_prefix=dict(img='test/all/images/'), 74 | test_mode=True, 75 | pipeline=val_pipeline, 76 | backend_args=backend_args)) 77 | test_dataloader = val_dataloader 78 | 79 | val_evaluator = dict(type='RotatedCocoMetric', metric='bbox') 80 | 81 | test_evaluator = val_evaluator 82 | -------------------------------------------------------------------------------- 
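All of the dataset files above are plain MMEngine config modules; the `rotated_fcos` configs later in this listing pull them in via `_base_`. They can also be loaded and inspected on their own, which is a quick way to sanity-check paths and pipelines. A minimal sketch, assuming an environment with `mmengine` installed and the working directory at the repo root:

```python
from mmengine.config import Config

# Load one of the _base_ dataset configs directly.
cfg = Config.fromfile('configs/_base_/datasets/dota.py')

print(cfg.dataset_type)                 # -> 'DOTADataset'
print(cfg.train_dataloader.batch_size)  # -> 2
print(cfg.train_pipeline[3]['type'])    # -> 'mmdet.Resize'
```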
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'mmrotate'
2 |
3 | default_hooks = dict(
4 |     timer=dict(type='IterTimerHook'),
5 |     logger=dict(type='LoggerHook', interval=50),
6 |     param_scheduler=dict(type='ParamSchedulerHook'),
7 |     checkpoint=dict(type='CheckpointHook', interval=1),
8 |     sampler_seed=dict(type='DistSamplerSeedHook'),
9 |     visualization=dict(type='mmdet.DetVisualizationHook'))
10 |
11 | env_cfg = dict(
12 |     cudnn_benchmark=False,
13 |     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
14 |     dist_cfg=dict(backend='nccl'),
15 | )
16 |
17 | vis_backends = [dict(type='LocalVisBackend')]
18 | visualizer = dict(
19 |     type='RotLocalVisualizer', vis_backends=vis_backends, name='visualizer')
20 | log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
21 |
22 | log_level = 'INFO'
23 | load_from = None
24 | resume = False
25 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_1x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 1x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 |     dict(
9 |         type='LinearLR',
10 |         start_factor=1.0 / 3,
11 |         by_epoch=False,
12 |         begin=0,
13 |         end=500),
14 |     dict(
15 |         type='MultiStepLR',
16 |         begin=0,
17 |         end=12,
18 |         by_epoch=True,
19 |         milestones=[8, 11],
20 |         gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 |     type='OptimWrapper',
26 |     optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 |     clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_3x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 3x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 |     dict(
9 |         type='LinearLR',
10 |         start_factor=1.0 / 3,
11 |         by_epoch=False,
12 |         begin=0,
13 |         end=500),
14 |     dict(
15 |         type='MultiStepLR',
16 |         begin=0,
17 |         end=36,
18 |         by_epoch=True,
19 |         milestones=[24, 33],
20 |         gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 |     type='OptimWrapper',
26 |     optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 |     clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_40e.py:
--------------------------------------------------------------------------------
1 | # training schedule for 40e
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 |     dict(
9 |         type='LinearLR',
10 |         start_factor=1.0 / 3,
11 |         by_epoch=False,
12 |         begin=0,
13 |         end=500),
14 |     dict(
15 |         type='MultiStepLR',
16 |         begin=0,
17 |         end=40,
18 |         by_epoch=True,
19 |         milestones=[24, 32, 38],
20 |         gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 |     type='OptimWrapper',
26 |     optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 |     clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_6x.py:
--------------------------------------------------------------------------------
1 | # training schedule for 6x
2 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=72, val_interval=1)
3 | val_cfg = dict(type='ValLoop')
4 | test_cfg = dict(type='TestLoop')
5 |
6 | # learning rate
7 | param_scheduler = [
8 |     dict(
9 |         type='LinearLR',
10 |         start_factor=1.0 / 3,
11 |         by_epoch=False,
12 |         begin=0,
13 |         end=500),
14 |     dict(
15 |         type='MultiStepLR',
16 |         begin=0,
17 |         end=72,
18 |         by_epoch=True,
19 |         milestones=[48, 66],
20 |         gamma=0.1)
21 | ]
22 |
23 | # optimizer
24 | optim_wrapper = dict(
25 |     type='OptimWrapper',
26 |     optimizer=dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001),
27 |     clip_grad=dict(max_norm=35, norm_type=2))
28 |
--------------------------------------------------------------------------------
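All four schedules share the same shape: a 500-iteration `LinearLR` warmup starting from one third of the base learning rate, followed by an epoch-based `MultiStepLR` decay of 0.1 at each milestone; only `max_epochs` and the milestones differ between the 1x/3x/40e/6x variants. As a worked illustration, here is a hypothetical helper (not part of the repo) that reproduces the effective learning rate of the 1x schedule:

```python
def lr_at(epoch: int, iter_in_epoch: int, iters_per_epoch: int,
          base_lr: float = 0.0025) -> float:
    """Effective LR of schedule_1x.py at a given point in training."""
    it = epoch * iters_per_epoch + iter_in_epoch
    # LinearLR: ramp from base_lr / 3 to base_lr over the first 500 iterations.
    scale = 1.0 if it >= 500 else 1.0 / 3 + (1.0 - 1.0 / 3) * it / 500
    # MultiStepLR: multiply by 0.1 upon reaching epochs 8 and 11.
    for milestone in (8, 11):
        if epoch >= milestone:
            scale *= 0.1
    return base_lr * scale

print(lr_at(0, 0, 1000))   # ~0.000833 (warmup start, base_lr / 3)
print(lr_at(1, 0, 1000))   # 0.0025    (warmup finished)
print(lr_at(11, 0, 1000))  # 0.000025  (after both decay steps)
```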
/configs/rotated_fcos/README.md:
--------------------------------------------------------------------------------
1 | # Rotated FCOS
2 |
3 | > [FCOS: Fully Convolutional One-Stage Object Detection](https://arxiv.org/abs/1904.01355)
4 |
5 |
6 |
7 | ## Abstract
8 |
9 |
10 |
11 |
12 | 13 | We propose a fully convolutional one-stage object detector (FCOS) to solve object detection in a per-pixel prediction 14 | fashion, analogue to semantic segmentation. Almost all state-of-the-art object detectors such as RetinaNet, SSD, YOLOv3, 15 | and Faster R-CNN rely on pre-defined anchor boxes. In contrast, our proposed detector FCOS is anchor box free, as well 16 | as proposal free. By eliminating the predefined set of anchor boxes, FCOS completely avoids the complicated computation 17 | related to anchor boxes such as calculating overlapping during training. More importantly, we also avoid all 18 | hyper-parameters related to anchor boxes, which are often very sensitive to the final detection performance. With the 19 | only post-processing non-maximum suppression (NMS), FCOS with ResNeXt-64x4d-101 achieves 44.7% in AP with single-model 20 | and single-scale testing, surpassing previous one-stage detectors with the advantage of being much simpler. For the 21 | first time, we demonstrate a much simpler and flexible detection framework achieving improved detection accuracy. We 22 | hope that the proposed FCOS framework can serve as a simple and strong alternative for many other instance-level tasks. 23 | 24 | ## Results and Models 25 | 26 | DOTA1.0 27 | 28 | | Backbone | mAP | Angle | Separate Angle | Tricks | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | 29 | | :----------------------: | :---: | :---: | :------------: | :----: | :-----: | :------: | :------------: | :-: | :--------: | :-------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | 30 | | ResNet50 (1024,1024,200) | 70.70 | le90 | Y | Y | 1x | 4.18 | 26.4 | - | 2 | [rotated-fcos-hbox-le90_r50_fpn_1x_dota](rotated-fcos-hbox-le90_r50_fpn_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90_20220409_023250.log.json) | 31 | | ResNet50 (1024,1024,200) | 71.28 | le90 | N | Y | 1x | 4.18 | 25.9 | - | 2 | [rotated-fcos-le90_r50_fpn_1x_dota](rotated-fcos-le90_r50_fpn_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90_20220413_163526.log.json) | 32 | | ResNet50 (1024,1024,200) | 71.76 | le90 | Y | Y | 1x | 4.23 | 25.7 | - | 2 | [rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota](rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth) \| 
[log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90_20220409_080616.log.json) |
33 | | ResNet50 (1024,1024,200) | 71.89 | le90 | N | Y | 1x | 4.18 | 26.2 | - | 2 | [rotated-fcos-le90_r50_fpn_kld_1x_dota](rotated-fcos-le90_r50_fpn_kld_1x_dota.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90_20220409_202939.log.json) |
34 |
35 | **Notes:**
36 |
37 | - `MS` means multi-scale image split.
38 | - `RR` means random rotation.
39 | - `Rotated IoU Loss` needs mmcv version 1.5.0 or above.
40 | - `Separate Angle` means the angle loss is calculated separately.
41 |   In that case the bbox loss is a horizontal bbox loss such as `IoULoss` or `GIoULoss`.
42 | - `Tricks` means setting `norm_on_bbox`, `centerness_on_reg` and `center_sampling` to `True`.
43 | - Inference time was tested on a single RTX 3090.
44 |
45 | ## Citation
46 |
47 | ```
48 | @article{tian2019fcos,
49 |   title={FCOS: Fully Convolutional One-Stage Object Detection},
50 |   author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
51 |   journal={arXiv preprint arXiv:1904.01355},
52 |   year={2019}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/configs/rotated_fcos/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 |   - Name: rotated_fcos
3 |     Metadata:
4 |       Training Data: DOTAv1.0
5 |       Training Techniques:
6 |         - SGD with Momentum
7 |         - Weight Decay
8 |       Training Resources: 1x Tesla V100
9 |       Architecture:
10 |         - ResNet
11 |     Paper:
12 |       URL: https://arxiv.org/abs/1904.01355
13 |       Title: 'FCOS: Fully Convolutional One-Stage Object Detection'
14 |     README: configs/rotated_fcos/README.md
15 |
16 | Models:
17 |   - Name: rotated-fcos-hbox-le90_r50_fpn_1x_dota
18 |     In Collection: rotated_fcos
19 |     Config: configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py
20 |     Metadata:
21 |       Training Data: DOTAv1.0
22 |     Results:
23 |       - Task: Oriented Object Detection
24 |         Dataset: DOTAv1.0
25 |         Metrics:
26 |           mAP: 70.70
27 |     Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth
28 |
29 |   - Name: rotated-fcos-le90_r50_fpn_1x_dota
30 |     In Collection: rotated_fcos
31 |     Config: configs/rotated_fcos/rotated-fcos-le90_r50_fpn_1x_dota.py
32 |     Metadata:
33 |       Training Data: DOTAv1.0
34 |     Results:
35 |       - Task: Oriented Object Detection
36 |         Dataset: DOTAv1.0
37 |         Metrics:
38 |           mAP: 71.28
39 |     Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth
40 |
41 |   - Name: rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota
42 |     In Collection: rotated_fcos
43 |     Config: configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py
44 |     Metadata:
45 |       Training Data: DOTAv1.0
46 |     Results:
47 |       - Task: Oriented Object Detection
48 |         Dataset: DOTAv1.0
49 |         Metrics:
50 |           mAP: 71.76
51 |     Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth
52 |
53 |   - Name:
rotated-fcos-le90_r50_fpn_kld_1x_dota 54 | In Collection: rotated_fcos 55 | Config: configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py 56 | Metadata: 57 | Training Data: DOTAv1.0 58 | Results: 59 | - Task: Oriented Object Detection 60 | Dataset: DOTAv1.0 61 | Metrics: 62 | mAP: 71.89 63 | Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth 64 | -------------------------------------------------------------------------------- /configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py: -------------------------------------------------------------------------------- 1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py' 2 | 3 | model = dict( 4 | bbox_head=dict( 5 | use_hbbox_loss=True, 6 | scale_angle=True, 7 | angle_coder=dict(type='PseudoAngleCoder'), 8 | loss_angle=dict(_delete_=True, type='mmdet.L1Loss', loss_weight=0.2), 9 | loss_bbox=dict(type='mmdet.IoULoss', loss_weight=1.0), 10 | )) 11 | -------------------------------------------------------------------------------- /configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_csl-gaussian_1x_dota.py: -------------------------------------------------------------------------------- 1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py' 2 | 3 | angle_version = {{_base_.angle_version}} 4 | 5 | # model settings 6 | model = dict( 7 | bbox_head=dict( 8 | use_hbbox_loss=True, 9 | scale_angle=False, 10 | angle_coder=dict( 11 | type='CSLCoder', 12 | angle_version=angle_version, 13 | omega=1, 14 | window='gaussian', 15 | radius=1), 16 | loss_angle=dict( 17 | _delete_=True, 18 | type='SmoothFocalLoss', 19 | gamma=2.0, 20 | alpha=0.25, 21 | loss_weight=0.2), 22 | loss_bbox=dict(type='mmdet.IoULoss', loss_weight=1.0), 23 | )) 24 | -------------------------------------------------------------------------------- /configs/rotated_fcos/rotated-fcos-le90_r50_fpn_1x_dota.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/dota.py', '../_base_/schedules/schedule_1x.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | angle_version = 'le90' 6 | 7 | # model settings 8 | model = dict( 9 | type='mmdet.FCOS', 10 | data_preprocessor=dict( 11 | type='mmdet.DetDataPreprocessor', 12 | mean=[123.675, 116.28, 103.53], 13 | std=[58.395, 57.12, 57.375], 14 | bgr_to_rgb=True, 15 | pad_size_divisor=32, 16 | boxtype2tensor=False), 17 | backbone=dict( 18 | type='mmdet.ResNet', 19 | depth=50, 20 | num_stages=4, 21 | out_indices=(0, 1, 2, 3), 22 | frozen_stages=1, 23 | norm_cfg=dict(type='BN', requires_grad=True), 24 | norm_eval=True, 25 | style='pytorch', 26 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 27 | neck=dict( 28 | type='mmdet.FPN', 29 | in_channels=[256, 512, 1024, 2048], 30 | out_channels=256, 31 | start_level=1, 32 | add_extra_convs='on_output', 33 | num_outs=5, 34 | relu_before_extra_convs=True), 35 | bbox_head=dict( 36 | type='RotatedFCOSHead', 37 | num_classes=15, 38 | in_channels=256, 39 | stacked_convs=4, 40 | feat_channels=256, 41 | strides=[8, 16, 32, 64, 128], 42 | center_sampling=True, 43 | center_sample_radius=1.5, 44 | norm_on_bbox=True, 45 | centerness_on_reg=True, 46 | use_hbbox_loss=False, 47 | scale_angle=True, 48 | bbox_coder=dict( 49 | type='DistanceAnglePointCoder', angle_version=angle_version), 50 | loss_cls=dict( 51 | type='mmdet.FocalLoss', 52 | use_sigmoid=True, 53 | gamma=2.0, 54 | alpha=0.25, 55 | loss_weight=1.0), 56 | 
loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0), 57 | loss_angle=None, 58 | loss_centerness=dict( 59 | type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), 60 | # training and testing settings 61 | train_cfg=None, 62 | test_cfg=dict( 63 | nms_pre=2000, 64 | min_bbox_size=0, 65 | score_thr=0.05, 66 | nms=dict(type='nms_rotated', iou_threshold=0.1), 67 | max_per_img=2000)) 68 | -------------------------------------------------------------------------------- /configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py: -------------------------------------------------------------------------------- 1 | _base_ = 'rotated-fcos-le90_r50_fpn_1x_dota.py' 2 | 3 | model = dict( 4 | bbox_head=dict( 5 | loss_bbox=dict( 6 | _delete_=True, 7 | type='GDLoss_v1', 8 | loss_type='kld', 9 | fun='log1p', 10 | tau=1, 11 | loss_weight=1.0))) 12 | -------------------------------------------------------------------------------- /configs/rotated_fcos/rotated-fcos-le90_r50_fpn_rr-6x_hrsc.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/hrsc.py', '../_base_/schedules/schedule_6x.py', 3 | '../_base_/default_runtime.py' 4 | ] 5 | angle_version = 'le90' 6 | 7 | # model settings 8 | model = dict( 9 | type='mmdet.FCOS', 10 | data_preprocessor=dict( 11 | type='mmdet.DetDataPreprocessor', 12 | mean=[123.675, 116.28, 103.53], 13 | std=[58.395, 57.12, 57.375], 14 | bgr_to_rgb=True, 15 | pad_size_divisor=32, 16 | boxtype2tensor=False), 17 | backbone=dict( 18 | type='mmdet.ResNet', 19 | depth=50, 20 | num_stages=4, 21 | out_indices=(0, 1, 2, 3), 22 | frozen_stages=1, 23 | norm_cfg=dict(type='BN', requires_grad=True), 24 | norm_eval=True, 25 | style='pytorch', 26 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 27 | neck=dict( 28 | type='mmdet.FPN', 29 | in_channels=[256, 512, 1024, 2048], 30 | out_channels=256, 31 | start_level=1, 32 | add_extra_convs='on_output', 33 | num_outs=5, 34 | relu_before_extra_convs=True), 35 | bbox_head=dict( 36 | type='RotatedFCOSHead', 37 | num_classes=1, 38 | in_channels=256, 39 | stacked_convs=4, 40 | feat_channels=256, 41 | strides=[8, 16, 32, 64, 128], 42 | center_sampling=True, 43 | center_sample_radius=1.5, 44 | norm_on_bbox=True, 45 | centerness_on_reg=True, 46 | use_hbbox_loss=False, 47 | scale_angle=True, 48 | bbox_coder=dict( 49 | type='DistanceAnglePointCoder', angle_version=angle_version), 50 | loss_cls=dict( 51 | type='mmdet.FocalLoss', 52 | use_sigmoid=True, 53 | gamma=2.0, 54 | alpha=0.25, 55 | loss_weight=1.0), 56 | loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0), 57 | loss_angle=None, 58 | loss_centerness=dict( 59 | type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), 60 | # training and testing settings 61 | train_cfg=None, 62 | test_cfg=dict( 63 | nms_pre=2000, 64 | min_bbox_size=0, 65 | score_thr=0.05, 66 | nms=dict(type='nms_rotated', iou_threshold=0.1), 67 | max_per_img=2000)) 68 | 69 | train_pipeline = [ 70 | dict(type='mmdet.LoadImageFromFile', backend_args={{_base_.backend_args}}), 71 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 72 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 73 | dict(type='mmdet.Resize', scale=(800, 512), keep_ratio=True), 74 | dict( 75 | type='mmdet.RandomFlip', 76 | prob=0.75, 77 | direction=['horizontal', 'vertical', 'diagonal']), 78 | dict(type='RandomRotate', prob=0.5, angle_range=180), 79 | dict(type='mmdet.PackDetInputs') 80 | ] 81 | 82 | 
train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 83 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from functools import partial 4 | from typing import Dict, Optional, Union, List 5 | 6 | from mmengine.runner import Runner 7 | from mmengine.evaluator import Evaluator 8 | from mmengine.dataset import worker_init_fn 9 | from mmengine.dist import get_rank 10 | from mmengine.logging import print_log 11 | from mmengine.registry import DATA_SAMPLERS, FUNCTIONS, EVALUATOR, VISUALIZERS 12 | from mmengine.utils import digit_version 13 | from mmengine.utils.dl_utils import TORCH_VERSION 14 | 15 | import transforms  # imported for its registry side effects (AddConvertedGTBox) 16 | import visualizer  # imported for its registry side effects (RotLocalVisualizerMaskThenBox) 17 | 18 | 19 | from torch.utils.data import DataLoader 20 | 21 | from mmrotate.registry import DATASETS 22 | 23 | 24 | def build_data_loader(data_name=None): 25 | if data_name is None or data_name == 'trainval_with_hbox': 26 | return MMEngine_build_dataloader(dataloader=naive_trainval_dataloader) 27 | elif data_name == 'test_without_hbox': 28 | return MMEngine_build_dataloader(dataloader=naive_test_dataloader) 29 | else: 30 | raise NotImplementedError() 31 | 32 | 33 | def build_evaluator(merge_patches=True, format_only=False): 34 | naive_evaluator.update(dict( 35 | merge_patches=merge_patches, format_only=format_only)) 36 | return MMEngine_build_evaluator(evaluator=naive_evaluator) 37 | 38 | 39 | def build_visualizer(): 40 | vis_backends = [dict(type='LocalVisBackend')] 41 | visualizer = dict( 42 | type='RotLocalVisualizerMaskThenBox', vis_backends=vis_backends, 43 | name='sammrotate', save_dir='./rbbox_vis') 44 | return VISUALIZERS.build(visualizer) 45 | 46 | 47 | # dataset settings 48 | dataset_type = 'DOTADataset' 49 | data_root = 'data/split_ss_dota/' 50 | backend_args = None 51 | 52 | naive_trainval_pipeline = [ 53 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 54 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 55 | # load annotations after Resize on purpose, so the boxes are not rescaled 56 | dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'), 57 | # Horizontal GTBox, (x1,y1,x2,y2) 58 | dict(type='AddConvertedGTBox', box_type_mapping=dict(h_gt_bboxes='hbox')), 59 | dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 60 | # # Horizontal GTBox, (x,y,w,h,theta) 61 | # dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')), 62 | dict( 63 | type='mmdet.PackDetInputs', 64 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 65 | 'scale_factor', 'h_gt_bboxes')) 66 | ] 67 | 68 | naive_test_pipeline = [ 69 | dict(type='mmdet.LoadImageFromFile', backend_args=backend_args), 70 | dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), 71 | dict( 72 | type='mmdet.PackDetInputs', 73 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 74 | 'scale_factor')) 75 | ] 76 | 77 | naive_trainval_dataset = dict( 78 | type=dataset_type, 79 | data_root=data_root, 80 | # ann_file='trainval/annfiles/', 81 | # ann_file='trainval/annfiles-1sample/', 82 | # ann_file='trainval/annfiles-3sample/', 83 | # ann_file='trainval/annfiles-10sample/', 84 | # ann_file='trainval/annfiles-30sample/', 85 | # ann_file='trainval/annfiles-100sample/', 86 | ann_file='trainval/annfiles-1000sample/', 87 | data_prefix=dict(img_path='trainval/images/'), 88 | test_mode=True, # we only run inference with SAM, no training 89 | pipeline=naive_trainval_pipeline) 90 | 91 | naive_test_dataset =
dict( 92 | type=dataset_type, 93 | data_root=data_root, 94 | data_prefix=dict(img_path='test/images/'), 95 | test_mode=True, 96 | pipeline=naive_test_pipeline) 97 | 98 | naive_trainval_dataloader = dict( 99 | batch_size=1, 100 | # num_workers=0, # For debug 101 | num_workers=2, 102 | # persistent_workers=False, # For debug 103 | persistent_workers=True, 104 | drop_last=False, 105 | sampler=dict(type='DefaultSampler', shuffle=False), 106 | dataset=naive_trainval_dataset) 107 | 108 | naive_test_dataloader = dict( 109 | batch_size=1, 110 | # num_workers=0, # For debug 111 | num_workers=2, 112 | # persistent_workers=False, # For debug 113 | persistent_workers=True, 114 | drop_last=False, 115 | sampler=dict(type='DefaultSampler', shuffle=False), 116 | dataset=naive_test_dataset) 117 | 118 | naive_evaluator = dict( 119 | type='DOTAMetric', metric='mAP', outfile_prefix='./work_dirs/dota/Task1') 120 | 121 | 122 | def MMEngine_build_dataloader(dataloader: Union[DataLoader, Dict], 123 | seed: Optional[int] = None, 124 | diff_rank_seed: bool = False) -> DataLoader: 125 | """Build dataloader. 126 | 127 | The method builds three components: 128 | 129 | - Dataset 130 | - Sampler 131 | - Dataloader 132 | 133 | An example of ``dataloader``:: 134 | 135 | dataloader = dict( 136 | dataset=dict(type='ToyDataset'), 137 | sampler=dict(type='DefaultSampler', shuffle=True), 138 | batch_size=1, 139 | num_workers=9 140 | ) 141 | 142 | Args: 143 | dataloader (DataLoader or dict): A Dataloader object or a dict to 144 | build a Dataloader object. If ``dataloader`` is a Dataloader 145 | object, just returns itself. 146 | seed (int, optional): Random seed. Defaults to None. 147 | diff_rank_seed (bool): Whether or not to set different seeds for 148 | different ranks. If True, the seed passed to sampler is set 149 | to None, in order to synchronize the seeds used in samplers 150 | across different ranks. 151 | 152 | 153 | Returns: 154 | Dataloader: DataLoader built from ``dataloader_cfg``.
155 | """ 156 | if isinstance(dataloader, DataLoader): 157 | return dataloader 158 | 159 | dataloader_cfg = copy.deepcopy(dataloader) 160 | 161 | # build dataset 162 | dataset_cfg = dataloader_cfg.pop('dataset') 163 | if isinstance(dataset_cfg, dict): 164 | dataset = DATASETS.build(dataset_cfg) 165 | if hasattr(dataset, 'full_init'): 166 | dataset.full_init() 167 | else: 168 | # fallback to raise error in dataloader 169 | # if `dataset_cfg` is not a valid type 170 | dataset = dataset_cfg 171 | 172 | # build sampler 173 | sampler_cfg = dataloader_cfg.pop('sampler') 174 | if isinstance(sampler_cfg, dict): 175 | sampler_seed = None if diff_rank_seed else seed 176 | sampler = DATA_SAMPLERS.build( 177 | sampler_cfg, 178 | default_args=dict(dataset=dataset, seed=sampler_seed)) 179 | else: 180 | # fallback to raise error in dataloader 181 | # if `sampler_cfg` is not a valid type 182 | sampler = sampler_cfg 183 | 184 | # build batch sampler 185 | batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None) 186 | if batch_sampler_cfg is None: 187 | batch_sampler = None 188 | elif isinstance(batch_sampler_cfg, dict): 189 | batch_sampler = DATA_SAMPLERS.build( 190 | batch_sampler_cfg, 191 | default_args=dict( 192 | sampler=sampler, 193 | batch_size=dataloader_cfg.pop('batch_size'))) 194 | else: 195 | # fallback to raise error in dataloader 196 | # if `batch_sampler_cfg` is not a valid type 197 | batch_sampler = batch_sampler_cfg 198 | 199 | # build dataloader 200 | init_fn: Optional[partial] 201 | 202 | if seed is not None: 203 | disable_subprocess_warning = dataloader_cfg.pop( 204 | 'disable_subprocess_warning', False) 205 | assert isinstance( 206 | disable_subprocess_warning, 207 | bool), ('disable_subprocess_warning should be a bool, but got ' 208 | f'{type(disable_subprocess_warning)}') 209 | init_fn = partial( 210 | worker_init_fn, 211 | num_workers=dataloader_cfg.get('num_workers'), 212 | rank=get_rank(), 213 | seed=seed, 214 | disable_subprocess_warning=disable_subprocess_warning) 215 | else: 216 | init_fn = None 217 | 218 | # `persistent_workers` requires pytorch version >= 1.7 219 | if ('persistent_workers' in dataloader_cfg 220 | and digit_version(TORCH_VERSION) < digit_version('1.7.0')): 221 | print_log( 222 | '`persistent_workers` is only available when ' 223 | 'pytorch version >= 1.7', 224 | logger='current', 225 | level=logging.WARNING) 226 | dataloader_cfg.pop('persistent_workers') 227 | 228 | # The default behavior of `collate_fn` in dataloader is to 229 | # merge a list of samples to form a mini-batch of Tensor(s). 230 | # However, in mmengine, if `collate_fn` is not defined in 231 | # dataloader_cfg, `pseudo_collate` will only convert the list of 232 | # samples into a dict without stacking the batch tensor. 233 | collate_fn_cfg = dataloader_cfg.pop('collate_fn', 234 | dict(type='pseudo_collate')) 235 | collate_fn_type = collate_fn_cfg.pop('type') 236 | collate_fn = FUNCTIONS.get(collate_fn_type) 237 | collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore 238 | data_loader = DataLoader( 239 | dataset=dataset, 240 | sampler=sampler if batch_sampler is None else None, 241 | batch_sampler=batch_sampler, 242 | collate_fn=collate_fn, 243 | worker_init_fn=init_fn, 244 | **dataloader_cfg) 245 | return data_loader 246 | 247 | 248 | def MMEngine_build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator: 249 | """Build evaluator.
250 | 251 | Examples of ``evaluator``:: 252 | 253 | # evaluator could be a built Evaluator instance 254 | evaluator = Evaluator(metrics=[ToyMetric()]) 255 | 256 | # evaluator can also be a list of dict 257 | evaluator = [ 258 | dict(type='ToyMetric1'), 259 | dict(type='ToyEvaluator2') 260 | ] 261 | 262 | # evaluator can also be a list of built metric 263 | evaluator = [ToyMetric1(), ToyMetric2()] 264 | 265 | # evaluator can also be a dict with key metrics 266 | evaluator = dict(metrics=ToyMetric()) 267 | # metric is a list 268 | evaluator = dict(metrics=[ToyMetric()]) 269 | 270 | Args: 271 | evaluator (Evaluator or dict or list): An Evaluator object or a 272 | config dict or list of config dicts used to build an Evaluator. 273 | 274 | Returns: 275 | Evaluator: Evaluator built from ``evaluator``. 276 | """ 277 | if isinstance(evaluator, Evaluator): 278 | return evaluator 279 | elif isinstance(evaluator, dict): 280 | # if `metrics` is in the dict keys, build a customized evaluator 281 | if 'metrics' in evaluator: 282 | evaluator.setdefault('type', 'Evaluator') 283 | return EVALUATOR.build(evaluator) 284 | # otherwise, the default evaluator will be built 285 | else: 286 | return Evaluator(evaluator) # type: ignore 287 | elif isinstance(evaluator, list): 288 | # use the default `Evaluator` 289 | return Evaluator(evaluator) # type: ignore 290 | else: 291 | raise TypeError( 292 | 'evaluator should be one of dict, list of dict, and Evaluator' 293 | f', but got {evaluator}') 294 | 295 | 296 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pathlib import Path 4 | from copy import deepcopy 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import cv2 8 | 9 | from mmrotate.structures import RotatedBoxes 10 | from mmdet.models.utils import samplelist_boxtype2tensor 11 | from mmengine.runner import load_checkpoint 12 | from utils import show_box, show_mask 13 | 14 | from mmengine.structures import InstanceData 15 | from data import build_visualizer 16 | 17 | 18 | RESULT_WITH_MASK = True 19 | MAX_BATCH_NUM_PRED = 100 20 | VIS_SCORE_THR = 0.3 21 | 22 | 23 | @torch.no_grad() 24 | def single_sample_step(img_id, data, model, predictor, evaluator, dataloader, device, SHOW): 25 | copied_data = deepcopy(data) # for sam 26 | for item in data.values(): 27 | item[0] = item[0].to(device)  # .to() is not in-place, so rebind the moved copy 28 | 29 | # Stage 1 30 | # data['inputs'][0] = torch.flip(data['inputs'][0], dims=[0]) 31 | with torch.no_grad(): 32 | pred_results = model.test_step(data) 33 | pred_r_bboxes = pred_results[0].pred_instances.bboxes 34 | pred_r_bboxes = RotatedBoxes(pred_r_bboxes) 35 | h_bboxes = pred_r_bboxes.convert_to('hbox').tensor 36 | labels = pred_results[0].pred_instances.labels 37 | scores = pred_results[0].pred_instances.scores 38 | 39 | # Stage 2 40 | if len(h_bboxes) == 0: 41 | qualities = h_bboxes[:, 0] 42 | masks = h_bboxes.new_zeros((0, *data['inputs'][0].shape[-2:]))  # empty (0, H, W) masks; inputs are CHW 43 | data_samples = data['data_samples'] 44 | r_bboxes = [] 45 | else: 46 | img = copied_data['inputs'][0].permute(1, 2, 0).numpy()[:, :, ::-1] 47 | data_samples = copied_data['data_samples'] 48 | data_sample = data_samples[0] 49 | data_sample = data_sample.to(device=device) 50 | 51 | predictor.set_image(img) 52 | 53 | # Too many predictions may result in OOM, hence, 54 | # we process the predictions in multiple batches.
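# For example, with MAX_BATCH_NUM_PRED = 100 an image with 230 predicted
# boxes is split into ceil(230 / 100) = 3 prompt batches of 100, 100 and 30.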
55 | masks = [] 56 | num_pred = len(h_bboxes) 57 | num_batches = int(np.ceil(num_pred / MAX_BATCH_NUM_PRED)) 58 | for i in range(num_batches): 59 | left_index = i * MAX_BATCH_NUM_PRED 60 | right_index = (i + 1) * MAX_BATCH_NUM_PRED 61 | if i == num_batches - 1: 62 | batch_boxes = h_bboxes[left_index:] 63 | else: 64 | batch_boxes = h_bboxes[left_index: right_index] 65 | 66 | transformed_boxes = predictor.transform.apply_boxes_torch(batch_boxes, img.shape[:2]) 67 | batch_masks, qualities, lr_logits = predictor.predict_torch( 68 | point_coords=None, 69 | point_labels=None, 70 | boxes=transformed_boxes, 71 | multimask_output=False) 72 | batch_masks = batch_masks.squeeze(1).cpu() 73 | masks.extend([*batch_masks]) 74 | masks = torch.stack(masks, dim=0) 75 | r_bboxes = [mask2rbox(mask.numpy()) for mask in masks] 76 | 77 | results_list = get_instancedata_resultlist(r_bboxes, labels, masks, scores) 78 | data_samples = add_pred_to_datasample(results_list, data_samples) 79 | 80 | evaluator.process(data_samples=data_samples, data_batch=data) 81 | 82 | if SHOW: 83 | if len(h_bboxes) != 0 and img_id < 100: 84 | img_name = data_samples[0].img_id 85 | show_results(img, masks, h_bboxes, results_list, img_id, img_name, dataloader) 86 | 87 | return evaluator 88 | 89 | 90 | def mask2rbox(mask): 91 | y, x = np.nonzero(mask) 92 | points = np.stack([x, y], axis=-1) 93 | (cx, cy), (w, h), a = cv2.minAreaRect(points) 94 | r_bbox = np.array([cx, cy, w, h, a / 180 * np.pi]) 95 | return r_bbox 96 | 97 | def show_results(img, masks, h_bboxes, results_list, i, img_name, dataloader): 98 | output_dir = './output_vis/' 99 | Path(output_dir).mkdir(exist_ok=True, parents=True) 100 | 101 | results = results_list[0] 102 | 103 | # vis first stage 104 | # plt.figure(figsize=(10, 10)) 105 | # plt.imshow(img) 106 | # for mask in masks: 107 | # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True) 108 | # for box in h_bboxes: 109 | # show_box(box.cpu().numpy(), plt.gca()) 110 | # plt.axis('off') 111 | # # plt.show() 112 | # plt.savefig(f'./out_mask_{i}.png') 113 | # plt.close() 114 | 115 | # draw rbox with mmrotate 116 | visualizer = build_visualizer() 117 | visualizer.dataset_meta = dataloader.dataset.metainfo 118 | 119 | scores = results.scores 120 | keep_results = results[scores >= VIS_SCORE_THR] 121 | out_img = visualizer._draw_instances( 122 | img, keep_results, 123 | dataloader.dataset.metainfo['classes'], 124 | dataloader.dataset.metainfo['palette'], 125 | box_alpha=0.9, mask_alpha=0.3) 126 | # visualizer.show() 127 | # cv2.imwrite(os.path.join(output_dir, f'out_rbox_{i}.png'), out_img[:, :, ::-1]) 128 | cv2.imwrite(os.path.join(output_dir, f'rdet-sam_{img_name}.png'), 129 | out_img[:, :, ::-1]) 130 | 131 | 132 | def add_pred_to_datasample(results_list, data_samples): 133 | for data_sample, pred_instances in zip(data_samples, results_list): 134 | data_sample.pred_instances = pred_instances 135 | samplelist_boxtype2tensor(data_samples) 136 | return data_samples 137 | 138 | 139 | def get_instancedata_resultlist(r_bboxes, labels, masks, scores): 140 | results = InstanceData() 141 | results.bboxes = RotatedBoxes(r_bboxes) 142 | # results.scores = qualities 143 | results.scores = scores 144 | results.labels = labels 145 | if RESULT_WITH_MASK: 146 | results.masks = masks.cpu().numpy() 147 | results_list = [results] 148 | return results_list 149 | -------------------------------------------------------------------------------- /main_rdet-sam_dota.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | from mmrotate.utils import register_all_modules 5 | from data import build_data_loader, build_evaluator, build_visualizer 6 | 7 | from segment_anything import sam_model_registry, SamPredictor 8 | from mmrotate.registry import MODELS 9 | 10 | from mmengine import Config 11 | from mmengine.runner.checkpoint import _load_checkpoint 12 | 13 | from engine import single_sample_step 14 | 15 | 16 | register_all_modules(init_default_scope=True) 17 | 18 | SHOW = True 19 | FORMAT_ONLY = True 20 | MERGE_PATCHES = True 21 | SET_MIN_BOX = False 22 | 23 | 24 | if __name__ == '__main__': 25 | 26 | sam_checkpoint = r"../segment-anything/checkpoints/sam_vit_b_01ec64.pth" 27 | model_type = "vit_b" 28 | device = "cuda" 29 | 30 | ckpt_path = './rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth' 31 | model_cfg_path = 'configs/rotated_fcos/rotated-fcos-hbox-le90_r50_fpn_1x_dota.py' 32 | # ckpt_path = './rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth' 33 | # model_cfg_path = 'configs/rotated_fcos/rotated-fcos-le90_r50_fpn_kld_1x_dota.py' 34 | 35 | model_cfg = Config.fromfile(model_cfg_path).model 36 | if SET_MIN_BOX: 37 | model_cfg.test_cfg['min_bbox_size'] = 10 38 | 39 | model = MODELS.build(model_cfg) 40 | model.init_weights() 41 | checkpoint = _load_checkpoint(ckpt_path, map_location='cpu') 42 | sd = checkpoint.get('state_dict', checkpoint) 43 | print(model.load_state_dict(sd)) 44 | 45 | dataloader = build_data_loader('test_without_hbox') 46 | # dataloader = build_data_loader('trainval_with_hbox') 47 | evaluator = build_evaluator(MERGE_PATCHES, FORMAT_ONLY) 48 | evaluator.dataset_meta = dataloader.dataset.metainfo 49 | 50 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 51 | 52 | model = model.to(device=device) 53 | sam = sam.to(device=device) 54 | 55 | predictor = SamPredictor(sam) 56 | 57 | model.eval() 58 | for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 59 | 60 | evaluator = single_sample_step(i, data, model, predictor, evaluator, dataloader, device, SHOW) 61 | 62 | torch.save(evaluator, './evaluator.pth') 63 | 64 | metrics = evaluator.evaluate(len(dataloader.dataset)) 65 | -------------------------------------------------------------------------------- /main_sam_dota.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | import cv2 5 | from mmrotate.utils import register_all_modules 6 | from data import build_data_loader, build_evaluator, build_visualizer 7 | from utils import show_box, show_mask 8 | import matplotlib.pyplot as plt 9 | from mmengine.structures import InstanceData 10 | from segment_anything import sam_model_registry, SamPredictor 11 | from mmrotate.structures import RotatedBoxes 12 | from mmengine import ProgressBar 13 | from mmdet.models.utils import samplelist_boxtype2tensor 14 | 15 | 16 | register_all_modules(init_default_scope=True) 17 | 18 | SHOW = False 19 | FORMAT_ONLY = False 20 | MERGE_PATCHES = False 21 | 22 | 23 | if __name__ == '__main__': 24 | 25 | 26 | dataloader = build_data_loader('trainval_with_hbox') 27 | evaluator = build_evaluator(MERGE_PATCHES, FORMAT_ONLY) 28 | evaluator.dataset_meta = dataloader.dataset.metainfo 29 | 30 | sam_checkpoint = r"../segment-anything/checkpoints/sam_vit_b_01ec64.pth" 31 | model_type = "vit_b" 32 | device = "cuda" 33 | 34 | sam = 
sam_model_registry[model_type](checkpoint=sam_checkpoint) 35 | 36 | sam = sam.to(device=device) 37 | 38 | predictor = SamPredictor(sam) 39 | 40 | for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 41 | 42 | img = data['inputs'][0].permute(1, 2, 0).numpy()[:, :, ::-1] 43 | data_samples = data['data_samples'] 44 | data_sample = data_samples[0] 45 | data_sample = data_sample.to(device=device) 46 | 47 | h_bboxes = data_sample.h_gt_bboxes.tensor.to(device=device) 48 | labels = data_sample.gt_instances.labels.to(device=device) 49 | 50 | r_bboxes = [] 51 | if len(h_bboxes) == 0: 52 | qualities = h_bboxes[:, 0] 53 | masks = h_bboxes.new_zeros((0, *img.shape[:2]))  # empty (0, H, W) masks 54 | else: 55 | predictor.set_image(img) 56 | transformed_boxes = predictor.transform.apply_boxes_torch(h_bboxes, img.shape[:2]) 57 | masks, qualities, lr_logits = predictor.predict_torch( 58 | point_coords=None, 59 | point_labels=None, 60 | boxes=transformed_boxes, 61 | multimask_output=False) 62 | masks = masks.squeeze(1) 63 | qualities = qualities.squeeze(-1) 64 | for mask in masks: 65 | y, x = np.nonzero(mask.cpu().numpy()) 66 | points = np.stack([x, y], axis=-1) 67 | (cx, cy), (w, h), a = cv2.minAreaRect(points) 68 | r_bboxes.append(np.array([cx, cy, w, h, a / 180 * np.pi]))  # degrees -> radians 69 | 70 | results = InstanceData() 71 | results.bboxes = RotatedBoxes(r_bboxes) 72 | results.scores = qualities 73 | results.labels = labels 74 | results.masks = masks.cpu().numpy() 75 | results_list = [results] 76 | 77 | # add_pred_to_datasample 78 | for data_sample, pred_instances in zip(data_samples, results_list): 79 | data_sample.pred_instances = pred_instances 80 | samplelist_boxtype2tensor(data_samples) 81 | 82 | evaluator.process(data_samples=data_samples, data_batch=data) 83 | 84 | if SHOW: 85 | plt.figure(figsize=(10, 10)) 86 | plt.imshow(img) 87 | for mask in masks: 88 | show_mask(mask.cpu().numpy(), plt.gca(), random_color=True) 89 | for box in h_bboxes: 90 | show_box(box.cpu().numpy(), plt.gca()) 91 | plt.axis('off') 92 | # plt.show() 93 | plt.savefig(f'./out_mask_{i}.png') 94 | 95 | # draw rbox with mmrotate 96 | visualizer = build_visualizer() 97 | visualizer.dataset_meta = dataloader.dataset.metainfo 98 | out_img = visualizer._draw_instances( 99 | img, results, 100 | dataloader.dataset.metainfo['classes'], 101 | dataloader.dataset.metainfo['palette']) 102 | # visualizer.show() 103 | cv2.imwrite(f'./out_rbox_{i}.png', out_img[:, :, ::-1]) 104 | 105 | metrics = evaluator.evaluate(len(dataloader.dataset)) 106 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from mmcv.transforms import BaseTransform 2 | from mmrotate.registry import TRANSFORMS 3 | 4 | 5 | @TRANSFORMS.register_module() 6 | class AddConvertedGTBox(BaseTransform): 7 | """Convert boxes in results to a certain box type.""" 8 | 9 | def __init__(self, box_type_mapping: dict) -> None: 10 | self.box_type_mapping = box_type_mapping 11 | 12 | def transform(self, results: dict) -> dict: 13 | """The transform function.""" 14 | for key, dst_box_type in self.box_type_mapping.items(): 15 | assert key != 'gt_bboxes' 16 | gt_bboxes = results['gt_bboxes'] 17 | results[key] = gt_bboxes.convert_to(dst_box_type) 18 | return results 19 | 20 | def __repr__(self): 21 | repr_str = self.__class__.__name__ 22 | repr_str += f'(box_type_mapping={self.box_type_mapping})' 23 | return repr_str 24 |
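# Usage sketch (illustrative; mirrors naive_trainval_pipeline in data.py):
# placed right after `LoadAnnotations`, the transform leaves `gt_bboxes`
# untouched and stores a converted copy under a new key, here the horizontal
# boxes later used as SAM prompts:
#
#   dict(type='AddConvertedGTBox', box_type_mapping=dict(h_gt_bboxes='hbox'))
#
# The mapped key must differ from 'gt_bboxes' (enforced by the assert above),
# so the rotated ground truth survives for evaluation while SAM is prompted
# with plain HBoxes.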
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Plotting helpers adapted from the SAM demo code 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def show_mask(mask, ax, random_color=False): 7 | if random_color: 8 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 9 | else: 10 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) 11 | h, w = mask.shape[-2:] 12 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 13 | ax.imshow(mask_image) 14 | 15 | 16 | def show_points(coords, labels, ax, marker_size=375): 17 | pos_points = coords[labels == 1] 18 | neg_points = coords[labels == 0] 19 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', 20 | s=marker_size, edgecolor='white', linewidth=1.25) 21 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', 22 | s=marker_size, edgecolor='white', linewidth=1.25) 23 | 24 | 25 | def show_box(box, ax): 26 | x0, y0 = box[0], box[1] 27 | w, h = box[2] - box[0], box[3] - box[1] 28 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', 29 | facecolor=(0, 0, 0, 0), lw=2)) 30 | 31 | def show_r_box(box, ax): 32 | # box is (cx, cy, w, h, theta in radians); the original body duplicated 33 | # show_box and ignored the angle, so rotate the corners manually 34 | cx, cy, w, h, theta = box[:5] 35 | c, s = np.cos(theta), np.sin(theta) 36 | offsets = np.array([[-w / 2, -h / 2], [w / 2, -h / 2], 37 | [w / 2, h / 2], [-w / 2, h / 2]]) 38 | corners = offsets @ np.array([[c, s], [-s, c]]) + np.array([cx, cy]) 39 | ax.add_patch(plt.Polygon(corners, closed=True, edgecolor='green', 40 | facecolor=(0, 0, 0, 0), lw=2)) 41 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | 8 | from mmdet.structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon 9 | from mmdet.visualization import DetLocalVisualizer, jitter_color 10 | from mmdet.visualization.palette import _get_adaptive_scales 11 | from mmengine.structures import InstanceData 12 | 13 | from mmrotate.registry import VISUALIZERS 14 | from mmrotate.structures.bbox import QuadriBoxes, RotatedBoxes 15 | from mmrotate.visualization.palette import get_palette 16 | 17 | 18 | @VISUALIZERS.register_module() 19 | class RotLocalVisualizerMaskThenBox(DetLocalVisualizer): 20 | """MMRotate Local Visualizer. 21 | 22 | Args: 23 | name (str): Name of the instance. Defaults to 'visualizer'. 24 | image (np.ndarray, optional): the origin image to draw. The format 25 | should be RGB. Defaults to None. 26 | vis_backends (list, optional): Visual backend config list. 27 | Defaults to None. 28 | save_dir (str, optional): Save file dir for all storage backends. 29 | If it is None, the backend storage will not save any data. 30 | bbox_color (str, tuple(int), optional): Color of bbox lines. 31 | The tuple of color should be in BGR order. Defaults to None. 32 | text_color (str, tuple(int), optional): Color of texts. 33 | The tuple of color should be in BGR order. 34 | Defaults to (200, 200, 200). 35 | mask_color (str, tuple(int), optional): Color of masks. 36 | The tuple of color should be in BGR order. 37 | Defaults to None. 38 | line_width (int, float): The linewidth of lines. 39 | Defaults to 3. 40 | alpha (int, float): The transparency of bboxes or mask. 41 | Defaults to 0.8.
42 | """ 43 | 44 | def _draw_instances(self, image: np.ndarray, instances: InstanceData, 45 | classes: Optional[List[str]], 46 | palette: Optional[List[tuple]], 47 | box_alpha=None, mask_alpha=None) -> np.ndarray: 48 | """Draw instances of GT or prediction. 49 | 50 | Args: 51 | image (np.ndarray): The image to draw. 52 | instances (:obj:`InstanceData`): Data structure for 53 | instance-level annotations or predictions. 54 | classes (List[str], optional): Category information. 55 | palette (List[tuple], optional): Palette information 56 | corresponding to the category. ``box_alpha``/``mask_alpha`` default to ``self.alpha``. 57 | Returns: 58 | np.ndarray: the drawn image whose channels are in RGB order. 59 | """ 60 | if box_alpha is None: 61 | box_alpha = self.alpha 62 | if mask_alpha is None: 63 | mask_alpha = self.alpha 64 | 65 | self.set_image(image) 66 | 67 | if 'masks' in instances: 68 | labels = instances.labels 69 | masks = instances.masks 70 | if isinstance(masks, torch.Tensor): 71 | masks = masks.numpy() 72 | elif isinstance(masks, (PolygonMasks, BitmapMasks)): 73 | masks = masks.to_ndarray() 74 | 75 | masks = masks.astype(bool) 76 | 77 | max_label = int(max(labels) if len(labels) > 0 else 0) 78 | mask_color = palette if self.mask_color is None \ 79 | else self.mask_color 80 | mask_palette = get_palette(mask_color, max_label + 1) 81 | colors = [jitter_color(mask_palette[label]) for label in labels] 82 | text_palette = get_palette(self.text_color, max_label + 1) 83 | text_colors = [text_palette[label] for label in labels] 84 | 85 | polygons = [] 86 | for i, mask in enumerate(masks): 87 | contours, _ = bitmap_to_polygon(mask) 88 | polygons.extend(contours) 89 | self.draw_polygons(polygons, edge_colors='w', alpha=mask_alpha) 90 | self.draw_binary_masks(masks, colors=colors, alphas=mask_alpha) 91 | 92 | if 'bboxes' in instances: 93 | bboxes = instances.bboxes 94 | labels = instances.labels 95 | 96 | max_label = int(max(labels) if len(labels) > 0 else 0) 97 | text_palette = get_palette(self.text_color, max_label + 1) 98 | text_colors = [text_palette[label] for label in labels] 99 | 100 | bbox_color = palette if self.bbox_color is None \ 101 | else self.bbox_color 102 | bbox_palette = get_palette(bbox_color, max_label + 1) 103 | colors = [bbox_palette[label] for label in labels] 104 | 105 | if isinstance(bboxes, Tensor): 106 | if bboxes.size(-1) == 5: 107 | bboxes = RotatedBoxes(bboxes) 108 | elif bboxes.size(-1) == 8: 109 | bboxes = QuadriBoxes(bboxes) 110 | else: 111 | raise TypeError( 112 | 'Require the shape of `bboxes` to be (n, 5) ' 113 | 'or (n, 8), but get `bboxes` with shape being ' 114 | f'{bboxes.shape}.') 115 | 116 | bboxes = bboxes.cpu() 117 | polygons = bboxes.convert_to('qbox').tensor 118 | polygons = polygons.reshape(-1, 4, 2) 119 | polygons = list(polygons) 120 | self.draw_polygons( 121 | polygons, 122 | edge_colors=colors, 123 | alpha=box_alpha, 124 | line_widths=self.line_width) 125 | 126 | positions = bboxes.centers + self.line_width 127 | scales = _get_adaptive_scales(bboxes.areas) 128 | 129 | for i, (pos, label) in enumerate(zip(positions, labels)): 130 | label_text = classes[ 131 | label] if classes is not None else f'class {label}' 132 | if 'scores' in instances: 133 | score = round(float(instances.scores[i]) * 100, 1) 134 | label_text += f': {score}' 135 | 136 | self.draw_texts( 137 | label_text, 138 | pos, 139 | colors=text_colors[i], 140 | font_sizes=int(13 * scales[i]), 141 | bboxes=[{ 142 | 'facecolor': 'black', 143 | 'alpha': 0.8, 144 | 'pad': 0.7, 145 | 'edgecolor': 'none' 146 | }]) 147 | return self.get_image()
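# Usage sketch (illustrative; this mirrors show_results() in engine.py):
#
#   from data import build_visualizer
#   vis = build_visualizer()  # builds RotLocalVisualizerMaskThenBox
#   vis.dataset_meta = dataloader.dataset.metainfo
#   out_img = vis._draw_instances(
#       img, results,  # RGB image and an InstanceData with bboxes/masks
#       vis.dataset_meta['classes'], vis.dataset_meta['palette'],
#       box_alpha=0.9, mask_alpha=0.3)
#
# Masks are drawn first (white contours plus a translucent fill), then the
# rotated boxes and score labels are drawn on top, which is what the
# "MaskThenBox" in the class name refers to.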
148 | --------------------------------------------------------------------------------