├── .gitignore ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ ├── coco_detection_default.py │ │ ├── coco_detection_rgb.py │ │ ├── disl_coco_detection_default.py │ │ └── disl_coco_detection_rgb.py │ ├── models │ │ ├── faster_rcnn_swin_small.py │ │ ├── mixed_cascade_rcnn_swin_small.py │ │ ├── mixed_faster_rcnn_r50_caffe_fpn.py │ │ └── mixed_faster_rcnn_swin_small.py │ ├── runtimes │ │ └── default_runtime.py │ └── schedules │ │ ├── sgd_schedule_1x.py │ │ └── swin_adamw_schedule_1x.py ├── baseline │ ├── faster_rcnn_r50_caffe_fpn_coco_025x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_05x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_1x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_2x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_4x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_5x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_8x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_025x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_05x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_1x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_2x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_4x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_5x.py │ ├── faster_rcnn_r50_caffe_fpn_coco_strong_8x.py │ ├── faster_rcnn_swin_small_fpn_coco_1x.py │ ├── faster_rcnn_swin_small_fpn_coco_2x.py │ ├── faster_rcnn_swin_small_fpn_coco_3x.py │ └── faster_rcnn_swin_small_fpn_coco_4x.py ├── for_print.py └── ours │ ├── mixed_cascade_rcnn_swin_small_fpn_coco_4x.py │ ├── mixed_faster_rcnn_r50_caffe_fpn_coco_8x.py │ └── mixed_faster_rcnn_swin_small_fpn_coco_4x.py ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ └── geometric_transform.py ├── datasets │ ├── __init__.py │ ├── backends │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── mem_backends.py │ │ └── zip_backends.py │ ├── builder.py │ ├── dataset_wrappers.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── formating.py │ │ ├── immediate_transform.py │ │ ├── loading.py │ │ └── transform.py │ └── samplers │ │ ├── __init__.py │ │ ├── balance_sampler.py │ │ ├── distributed_sampler.py │ │ └── group_sampler.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ └── swing_transformer.py │ ├── detectors │ │ ├── __init__.py │ │ ├── semi_base.py │ │ ├── semi_two_stage.py │ │ ├── student_wrapper │ │ │ ├── __init__.py │ │ │ └── two_stage_student.py │ │ └── teacher_wrapper │ │ │ ├── __init__.py │ │ │ └── two_stage_teacher.py │ └── utils │ │ ├── __init__.py │ │ ├── cascade_wrappers.py │ │ ├── semi_wrapper.py │ │ └── standard_wrappers.py └── utils │ ├── __init__.py │ ├── config.py │ ├── debug_utils.py │ ├── file_utils.py │ ├── hooks.py │ ├── log_utils.py │ ├── structure_utils.py │ └── web_utils.py └── tools ├── dist_test.sh ├── dist_train.sh ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | #vs code 2 | .history/ 3 | .vscode 4 | .idea 5 | .history 6 | .DS_Store 7 | #python 8 | __pycache__/ 9 | */__pycache__ 10 | *.egg-info 11 | build 12 | #lib 13 | tests 14 | thirdparty 15 | .history 16 | #develop 17 | wandb 18 | data/ 19 | data 20 | *.pkl 21 | *.pkl.json 22 | *.log.json 23 | work_dirs/ 24 | figures 25 | 26 | # Pytorch 27 | *.pth 28 | *.py~ 29 | *.sh~ 30 | launch.py 31 | 32 | #nvidia 33 | *.qdrep 34 | *.sqlite -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixTraining 2 | Official codes for our 
NeurIPS 2021 paper "Bootstrap Your Object Detector via Mixed Training" ([paper](https://proceedings.neurips.cc/paper/2021/file/5e15fb59326e7a9c3d6558ca74621683-Paper.pdf)). 3 | ## Main Results 4 | 5 | |Model|mAP|AP50|AP75|APs|APm|APl|Link| 6 | |----|-------|-----|----|---|---|---|---| 7 | |[mixed_faster_rcnn_swin_small](configs/ours/mixed_faster_rcnn_swin_small_fpn_coco_4x.py)|0.503|0.716|0.552|0.347|0.540|0.659|[Google](https://drive.google.com/file/d/1dbxJybYigdL8VOm99q7vI7MKHNMDjLBo/view?usp=sharing)| 8 | |[mixed_cascade_rcnn_swin_small](configs/ours/mixed_cascade_rcnn_swin_small_fpn_coco_4x.py)|0.528|0.721|0.580|0.366|0.568|0.686|[Google](https://drive.google.com/file/d/14VVVml9EPqdA1g4vGBnpuI4sWpeqxX3U/view?usp=sharing)| 9 | # Implementation 10 | - ### Environment 11 | ``` 12 | torch==1.6.0 13 | torchvision==0.7.0 14 | wandb==0.10.26 15 | apex==0.1 16 | mmdet==2.11.0 17 | mmcv-full==1.3.0 18 | ``` 19 | Install the required packages with: 20 | ``` 21 | cd ${your_code_dir} 22 | mkdir -p thirdparty 23 | git clone https://github.com/open-mmlab/mmdetection.git thirdparty/mmdetection 24 | cd thirdparty/mmdetection && git checkout v2.11.0 && python -m pip install -e . && cd ../.. 25 | python -m pip install -e . 26 | mkdir -p data 27 | ln -s ${your_coco_path} data/coco 28 | ``` 29 | - ### For testing 30 | ```shell 31 | bash tools/dist_test.sh ${selected_config} 8 32 | ``` 33 | where `selected_config` is one of the provided configs under the `configs/baseline` or `configs/ours` folder. 34 | - ### For training 35 | ```shell 36 | bash tools/dist_train.sh ${selected_config} 8 37 | ``` 38 | where `selected_config` is one of the provided configs under the `configs/baseline` or `configs/ours` folder. For quick single-GPU runs, see the sketch below. 39 | - ### For other datasets 40 | We have not trained or tested on other datasets. If you would like to use this method on other data, please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/1_exist_data_model.md).
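- ### Single-GPU usage (sketch)
The distributed scripts above launch 8-GPU jobs. For quick runs on a single GPU, the plain entry points `tools/train.py` and `tools/test.py` from the repository tree can be invoked directly. This is a minimal sketch assuming they follow the standard mmdetection v2.11 entry-point interface (this repo ships its own copies, so flags may differ slightly); `${checkpoint}` is a placeholder for a trained model, e.g. one of the released checkpoints linked in the table above.
```shell
# Single-GPU training: takes only the config; in standard mmdetection tooling
# the work directory is derived from the config name.
python tools/train.py ${selected_config}

# Single-GPU evaluation: config, checkpoint to evaluate, and the COCO bbox metric.
python tools/test.py ${selected_config} ${checkpoint} --eval bbox
```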
41 | ## Citing us 42 | 43 | ``` 44 | @inproceedings{xu2021bootstrap, 45 | title={Bootstrap Your Object Detector via Mixed Training}, 46 | author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Lin, Yutong and Cao, Yue and Lin, Stephen and Hu, Han and Bai, Xiang}, 47 | journal={Advances in Neural Information Processing Systems}, 48 | volume={34}, 49 | year={2021} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /configs/_base_/datasets/coco_detection_default.py: -------------------------------------------------------------------------------- 1 | # def temp 2 | color_transform = dict( 3 | type="RandomApply", 4 | policies=[ 5 | dict(type="Identity"), 6 | dict(type="Jitter", contrast=1.0), 7 | dict(type="Jitter", brightness=1.0), 8 | dict(type="Jitter", hue=1.0), 9 | dict(type="Equalize"), 10 | dict(type="AutoContrast"), 11 | dict(type="PosterizeV1"), 12 | dict(type="RandomGrayScale"), 13 | dict(type="SolarizeV1"), 14 | ], 15 | ) 16 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 17 | scale_cfg = dict( 18 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 19 | ) 20 | data_root = "data/coco" 21 | sup_set = "train2017" 22 | val_set = "val2017" 23 | test_set = "val2017" 24 | # end def 25 | 26 | sup = [ 27 | dict(type="LoadImageFromFile"), 28 | dict(type="LoadAnnotations", with_bbox=True), 29 | color_transform, 30 | dict(type="Resize", **scale_cfg), 31 | dict(type="RandomFlip", flip_ratio=0.5), 32 | dict(type="Normalize", **img_norm_cfg), 33 | dict(type="Pad", size_divisor=32), 34 | dict(type="DefaultFormatBundle"), 35 | dict(type="ExtraAttrs", tag="sup"), 36 | dict( 37 | type="CollectV1", 38 | keys=["img", "gt_bboxes", "gt_labels"], 39 | extra_meta_keys=["tag"], 40 | ), 41 | ] 42 | test_pipeline = [ 43 | dict(type="LoadImageFromFile"), 44 | dict( 45 | type="MultiScaleFlipAug", 46 | img_scale=(1333, 800), 47 | flip=False, 48 | transforms=[ 49 | dict(type="Resize", keep_ratio=True), 50 | dict(type="RandomFlip"), 51 | dict(type="Normalize", **img_norm_cfg), 52 | dict(type="Pad", size_divisor=32), 53 | dict(type="ImageToTensor", keys=["img"]), 54 | dict(type="Collect", keys=["img"]), 55 | ], 56 | ), 57 | ] 58 | 59 | data = dict( 60 | samples_per_gpu=4, 61 | workers_per_gpu=4, 62 | train=dict( 63 | type="CocoDataset", 64 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 65 | img_prefix="{data_root}/{sup_set}/", 66 | pipeline=sup, 67 | ), 68 | val=dict( 69 | type="CocoDataset", 70 | ann_file="{data_root}/annotations/instances_{val_set}.json", 71 | img_prefix="{data_root}/{val_set}", 72 | pipeline=test_pipeline, 73 | ), 74 | test=dict( 75 | type="CocoDataset", 76 | ann_file="{data_root}/annotations/instances_{test_set}.json", 77 | img_prefix="{data_root}/{test_set}", 78 | pipeline=test_pipeline, 79 | ), 80 | ) 81 | 82 | evaluation = dict(gpu_collect=True, metric=["bbox"]) 83 | -------------------------------------------------------------------------------- /configs/_base_/datasets/coco_detection_rgb.py: -------------------------------------------------------------------------------- 1 | # def temp 2 | color_transform = dict( 3 | type="RandomApply", 4 | policies=[ 5 | dict(type="Identity"), 6 | dict(type="Jitter", contrast=1.0), 7 | dict(type="Jitter", brightness=1.0), 8 | dict(type="Jitter", hue=1.0), 9 | dict(type="Equalize"), 10 | dict(type="AutoContrast"), 11 | dict(type="PosterizeV1"), 12 | dict(type="RandomGrayScale"), 13 | dict(type="SolarizeV1"), 14 | ], 15 | 
) 16 | img_norm_cfg = dict( 17 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 18 | ) 19 | scale_cfg = dict( 20 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 21 | ) 22 | data_root = "data/coco" 23 | sup_set = "train2017" 24 | val_set = "val2017" 25 | test_set = "val2017" 26 | # end def 27 | 28 | sup = [ 29 | dict(type="LoadImageFromFile"), 30 | dict(type="LoadAnnotations", with_bbox=True), 31 | color_transform, 32 | dict(type="Resize", **scale_cfg), 33 | dict(type="RandomFlip", flip_ratio=0.5), 34 | dict(type="Normalize", **img_norm_cfg), 35 | dict(type="Pad", size_divisor=32), 36 | dict(type="DefaultFormatBundle"), 37 | dict(type="ExtraAttrs", tag="sup"), 38 | dict( 39 | type="CollectV1", 40 | keys=["img", "gt_bboxes", "gt_labels"], 41 | extra_meta_keys=["tag"], 42 | ), 43 | ] 44 | test_pipeline = [ 45 | dict(type="LoadImageFromFile"), 46 | dict( 47 | type="MultiScaleFlipAug", 48 | img_scale=(1333, 800), 49 | flip=False, 50 | transforms=[ 51 | dict(type="Resize", keep_ratio=True), 52 | dict(type="RandomFlip"), 53 | dict(type="Normalize", **img_norm_cfg), 54 | dict(type="Pad", size_divisor=32), 55 | dict(type="ImageToTensor", keys=["img"]), 56 | dict(type="Collect", keys=["img"]), 57 | ], 58 | ), 59 | ] 60 | 61 | data = dict( 62 | samples_per_gpu=4, 63 | workers_per_gpu=4, 64 | train=dict( 65 | type="CocoDataset", 66 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 67 | img_prefix="{data_root}/{sup_set}/", 68 | pipeline=sup, 69 | ), 70 | val=dict( 71 | type="CocoDataset", 72 | ann_file="{data_root}/annotations/instances_{val_set}.json", 73 | img_prefix="{data_root}/{val_set}", 74 | pipeline=test_pipeline, 75 | ), 76 | test=dict( 77 | type="CocoDataset", 78 | ann_file="{data_root}/annotations/instances_{test_set}.json", 79 | img_prefix="{data_root}/{test_set}", 80 | pipeline=test_pipeline, 81 | ), 82 | ) 83 | 84 | evaluation = dict(gpu_collect=True, metric=["bbox"]) 85 | -------------------------------------------------------------------------------- /configs/_base_/datasets/disl_coco_detection_default.py: -------------------------------------------------------------------------------- 1 | # def temp 2 | 3 | color_transform = dict( 4 | type="RandomApply", 5 | policies=[ 6 | dict(type="Identity"), 7 | dict(type="Jitter", contrast=1.0), 8 | dict(type="Jitter", brightness=1.0), 9 | dict(type="Jitter", hue=1.0), 10 | dict(type="Equalize"), 11 | dict(type="AutoContrast"), 12 | dict(type="PosterizeV1"), 13 | dict(type="RandomGrayScale"), 14 | dict(type="SolarizeV1"), 15 | ], 16 | ) 17 | geo_transform = dict( 18 | type="RandomApply", 19 | policies=[ 20 | dict(type="LazyGeoIdentity"), 21 | dict(type="LazyTranslate", max_translate_offset=0.2, direction="horizontal",), 22 | dict(type="LazyTranslate", max_translate_offset=0.2, direction="vertical",), 23 | dict(type="LazyRotate"), 24 | dict(type="LazyShear"), 25 | ], 26 | ) 27 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 28 | scale_cfg = dict( 29 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 30 | ) 31 | data_root = "data/coco" 32 | sup_set = "train2017" 33 | unsup_set = "train2017" 34 | val_set = "val2017" 35 | test_set = "test2017" 36 | # end def 37 | 38 | sup = [ 39 | dict(type="LoadImageFromFile"), 40 | dict(type="LoadAnnotations", with_bbox=True), 41 | color_transform, 42 | dict(type="Resize", **scale_cfg), 43 | dict(type="RandomFlip", flip_ratio=0.5), 44 | dict(type="Normalize",
**img_norm_cfg), 45 | dict(type="Pad", size_divisor=32), 46 | dict(type="DefaultFormatBundle"), 47 | dict(type="ExtraAttrs", tag="sup"), 48 | dict( 49 | type="CollectV1", 50 | keys=["img", "gt_bboxes", "gt_labels"], 51 | extra_meta_keys=["tag"], 52 | ), 53 | ] 54 | unsup_strong = [ 55 | color_transform, 56 | dict(type="LazyResize", **scale_cfg), 57 | dict(type="LazyRandomFlip"), 58 | geo_transform, 59 | dict(type="Normalize", **img_norm_cfg), 60 | dict(type="TransformImage"), 61 | dict( 62 | type="CutOut", 63 | n_holes=(1, 5), 64 | cutout_ratio=[ 65 | (0.05, 0.05), 66 | (0.75, 0.75), 67 | (0.1, 0.1), 68 | (0.125, 0.125), 69 | (0.15, 0.15), 70 | (0.175, 0.175), 71 | (0.2, 0.2), 72 | ], 73 | fill_in=(0, 0, 0), 74 | ), 75 | dict(type="Pad", size_divisor=32), 76 | dict(type="DefaultFormatBundle"), 77 | dict(type="ExtraAttrs", tag="unsup_student"), 78 | dict( 79 | type="CollectV1", 80 | keys=["img", "gt_bboxes", "gt_labels"], 81 | extra_meta_keys=["tag", "trans_matrix"], 82 | ), 83 | ] 84 | unsup_weak = [ 85 | dict(type="LazyResize", **scale_cfg), 86 | dict(type="LazyRandomFlip"), 87 | dict(type="Normalize", **img_norm_cfg), 88 | dict(type="TransformImage"), 89 | dict(type="Pad", size_divisor=32), 90 | dict(type="DefaultFormatBundle"), 91 | dict(type="ExtraAttrs", tag="unsup_teacher"), 92 | dict( 93 | type="CollectV1", 94 | keys=["img", "gt_bboxes", "gt_labels"], 95 | extra_meta_keys=["tag", "trans_matrix"], 96 | ), 97 | ] 98 | unsup = [ 99 | dict(type="LoadImageFromFile"), 100 | dict(type="LoadAnnotations", with_bbox=True), 101 | dict( 102 | type="MultiBranch", 103 | policies=[ 104 | dict(type="Compose", transforms=unsup_strong), 105 | dict(type="Compose", transforms=unsup_weak), 106 | ], 107 | ), 108 | ] 109 | test_pipeline = [ 110 | dict(type="LoadImageFromFile"), 111 | dict( 112 | type="MultiScaleFlipAug", 113 | img_scale=(1333, 800), 114 | flip=False, 115 | transforms=[ 116 | dict(type="Resize", keep_ratio=True), 117 | dict(type="RandomFlip"), 118 | dict(type="Normalize", **img_norm_cfg), 119 | dict(type="Pad", size_divisor=32), 120 | dict(type="ImageToTensor", keys=["img"]), 121 | dict(type="Collect", keys=["img"]), 122 | ], 123 | ), 124 | ] 125 | 126 | data = dict( 127 | samples_per_gpu=4, 128 | workers_per_gpu=4, 129 | train=dict( 130 | type="MultiSourceDataset", 131 | datasets=[ 132 | dict( 133 | type="CocoDataset", 134 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 135 | img_prefix="{data_root}/{sup_set}/", 136 | pipeline=sup, 137 | ), 138 | dict( 139 | type="CocoDataset", 140 | ann_file="{data_root}/annotations/instances_{unsup_set}.json", 141 | img_prefix="{data_root}/{unsup_set}/", 142 | filter_empty_gt=False, 143 | pipeline=unsup, 144 | ), 145 | ], 146 | sample_ratio=[0.5, 0.5], 147 | ), 148 | val=dict( 149 | type="CocoDataset", 150 | ann_file="{data_root}/annotations/instances_{val_set}.json", 151 | img_prefix="{data_root}/{val_set}", 152 | pipeline=test_pipeline, 153 | ), 154 | test=dict( 155 | type="CocoDataset", 156 | ann_file="{data_root}/annotations/instances_{test_set}.json", 157 | img_prefix="{data_root}/{test_set}", 158 | pipeline=test_pipeline, 159 | ), 160 | sampler=dict( 161 | train=dict( 162 | type="SemiBalanceSampler", 163 | epoch_length=7330, 164 | by_prob=True, 165 | at_least_one=True, 166 | ) 167 | ), 168 | loader=dict(train=None), 169 | ) 170 | 171 | evaluation = dict(gpu_collect=True, metric=["bbox"]) 172 | -------------------------------------------------------------------------------- /configs/_base_/datasets/disl_coco_detection_rgb.py: 
-------------------------------------------------------------------------------- 1 | # def temp 2 | color_transform = dict( 3 | type="RandomApply", 4 | policies=[ 5 | dict(type="Identity"), 6 | dict(type="Jitter", contrast=1.0), 7 | dict(type="Jitter", brightness=1.0), 8 | dict(type="Jitter", hue=1.0), 9 | dict(type="Equalize"), 10 | dict(type="AutoContrast"), 11 | dict(type="PosterizeV1"), 12 | dict(type="RandomGrayScale"), 13 | dict(type="SolarizeV1"), 14 | ], 15 | ) 16 | geo_transform = dict( 17 | type="RandomApply", 18 | policies=[ 19 | dict(type="LazyGeoIdentity"), 20 | dict(type="LazyTranslate", max_translate_offset=0.2, direction="horizontal",), 21 | dict(type="LazyTranslate", max_translate_offset=0.2, direction="vertical",), 22 | dict(type="LazyRotate"), 23 | dict(type="LazyShear"), 24 | ], 25 | ) 26 | img_norm_cfg = dict( 27 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 28 | ) 29 | scale_cfg = dict( 30 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 31 | ) 32 | data_root = "data/coco" 33 | sup_set = "train2017" 34 | unsup_set = "train2017" 35 | val_set = "val2017" 36 | test_set = "test2017" 37 | # end def 38 | 39 | sup = [ 40 | dict(type="LoadImageFromFile"), 41 | dict(type="LoadAnnotations", with_bbox=True), 42 | color_transform, 43 | dict(type="Resize", **scale_cfg), 44 | dict(type="RandomFlip", flip_ratio=0.5), 45 | dict(type="Normalize", **img_norm_cfg), 46 | dict(type="Pad", size_divisor=32), 47 | dict(type="DefaultFormatBundle"), 48 | dict(type="ExtraAttrs", tag="sup"), 49 | dict( 50 | type="CollectV1", 51 | keys=["img", "gt_bboxes", "gt_labels"], 52 | extra_meta_keys=["tag"], 53 | ), 54 | ] 55 | unsup_strong = [ 56 | color_transform, 57 | dict(type="LazyResize", **scale_cfg), 58 | dict(type="LazyRandomFlip"), 59 | geo_transform, 60 | dict(type="Normalize", **img_norm_cfg), 61 | dict(type="TransformImage"), 62 | dict( 63 | type="CutOut", 64 | n_holes=(1, 5), 65 | cutout_ratio=[ 66 | (0.05, 0.05), 67 | (0.75, 0.75), 68 | (0.1, 0.1), 69 | (0.125, 0.125), 70 | (0.15, 0.15), 71 | (0.175, 0.175), 72 | (0.2, 0.2), 73 | ], 74 | fill_in=(0, 0, 0), 75 | ), 76 | dict(type="Pad", size_divisor=32), 77 | dict(type="DefaultFormatBundle"), 78 | dict(type="ExtraAttrs", tag="unsup_student"), 79 | dict( 80 | type="CollectV1", 81 | keys=["img", "gt_bboxes", "gt_labels"], 82 | extra_meta_keys=["tag", "trans_matrix"], 83 | ), 84 | ] 85 | unsup_weak = [ 86 | dict(type="LazyResize", **scale_cfg), 87 | dict(type="LazyRandomFlip"), 88 | dict(type="Normalize", **img_norm_cfg), 89 | dict(type="TransformImage"), 90 | dict(type="Pad", size_divisor=32), 91 | dict(type="DefaultFormatBundle"), 92 | dict(type="ExtraAttrs", tag="unsup_teacher"), 93 | dict( 94 | type="CollectV1", 95 | keys=["img", "gt_bboxes", "gt_labels"], 96 | extra_meta_keys=["tag", "trans_matrix"], 97 | ), 98 | ] 99 | unsup = [ 100 | dict(type="LoadImageFromFile"), 101 | dict(type="LoadAnnotations", with_bbox=True), 102 | dict( 103 | type="MultiBranch", 104 | policies=[ 105 | dict(type="Compose", transforms=unsup_strong), 106 | dict(type="Compose", transforms=unsup_weak), 107 | ], 108 | ), 109 | ] 110 | test_pipeline = [ 111 | dict(type="LoadImageFromFile"), 112 | dict( 113 | type="MultiScaleFlipAug", 114 | img_scale=(1333, 800), 115 | flip=False, 116 | transforms=[ 117 | dict(type="Resize", keep_ratio=True), 118 | dict(type="RandomFlip"), 119 | dict(type="Normalize", **img_norm_cfg), 120 | dict(type="Pad", size_divisor=32), 121 | dict(type="ImageToTensor", keys=["img"]), 122
| dict(type="Collect", keys=["img"]), 123 | ], 124 | ), 125 | ] 126 | 127 | data = dict( 128 | samples_per_gpu=4, 129 | workers_per_gpu=4, 130 | train=dict( 131 | type="MultiSourceDataset", 132 | datasets=[ 133 | dict( 134 | type="CocoDataset", 135 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 136 | img_prefix="{data_root}/{sup_set}/", 137 | pipeline=sup, 138 | ), 139 | dict( 140 | type="CocoDataset", 141 | ann_file="{data_root}/annotations/instances_{unsup_set}.json", 142 | img_prefix="{data_root}/{unsup_set}/", 143 | filter_empty_gt=False, 144 | pipeline=unsup, 145 | ), 146 | ], 147 | sample_ratio=[0.5, 0.5], 148 | ), 149 | val=dict( 150 | type="CocoDataset", 151 | ann_file="{data_root}/annotations/instances_{val_set}.json", 152 | img_prefix="{data_root}/{val_set}", 153 | pipeline=test_pipeline, 154 | ), 155 | test=dict( 156 | type="CocoDataset", 157 | ann_file="{data_root}/annotations/instances_{test_set}.json", 158 | img_prefix="{data_root}/{test_set}", 159 | pipeline=test_pipeline, 160 | ), 161 | sampler=dict( 162 | train=dict( 163 | type="SemiBalanceSampler", 164 | epoch_length=7330, 165 | by_prob=True, 166 | at_least_one=True, 167 | ) 168 | ), 169 | loader=dict(train=None), 170 | ) 171 | 172 | evaluation = dict(gpu_collect=True, metric=["bbox"]) 173 | -------------------------------------------------------------------------------- /configs/_base_/models/faster_rcnn_swin_small.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type="FasterRCNN", 3 | pretrained="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", 4 | backbone=dict( 5 | type="SwinTransformer", 6 | embed_dim=96, 7 | depths=[2, 2, 18, 2], 8 | num_heads=[3, 6, 12, 24], 9 | window_size=7, 10 | mlp_ratio=4.0, 11 | qkv_bias=True, 12 | qk_scale=None, 13 | drop_rate=0.0, 14 | attn_drop_rate=0.0, 15 | drop_path_rate=0.2, 16 | ape=False, 17 | patch_norm=True, 18 | out_indices=(0, 1, 2, 3), 19 | use_checkpoint=False, 20 | ), 21 | neck=dict( 22 | type="FPN", in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5 23 | ), 24 | rpn_head=dict( 25 | type="RPNHead", 26 | in_channels=256, 27 | feat_channels=256, 28 | anchor_generator=dict( 29 | type="AnchorGenerator", 30 | scales=[8], 31 | ratios=[0.5, 1.0, 2.0], 32 | strides=[4, 8, 16, 32, 64], 33 | ), 34 | bbox_coder=dict( 35 | type="DeltaXYWHBBoxCoder", 36 | target_means=[0.0, 0.0, 0.0, 0.0], 37 | target_stds=[1.0, 1.0, 1.0, 1.0], 38 | ), 39 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 40 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 41 | ), 42 | roi_head=dict( 43 | type="StandardRoIHead", 44 | bbox_roi_extractor=dict( 45 | type="SingleRoIExtractor", 46 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 47 | out_channels=256, 48 | featmap_strides=[4, 8, 16, 32], 49 | ), 50 | bbox_head=dict( 51 | type="Shared2FCBBoxHead", 52 | in_channels=256, 53 | fc_out_channels=1024, 54 | roi_feat_size=7, 55 | num_classes=80, 56 | bbox_coder=dict( 57 | type="DeltaXYWHBBoxCoder", 58 | target_means=[0.0, 0.0, 0.0, 0.0], 59 | target_stds=[0.1, 0.1, 0.2, 0.2], 60 | ), 61 | reg_class_agnostic=False, 62 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 63 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 64 | ), 65 | ), 66 | train_cfg=dict( 67 | rpn=dict( 68 | assigner=dict( 69 | type="MaxIoUAssigner", 70 | pos_iou_thr=0.7, 71 | neg_iou_thr=0.3, 72 | min_pos_iou=0.3, 73 | match_low_quality=True, 
74 | ignore_iof_thr=-1, 75 | ), 76 | sampler=dict( 77 | type="RandomSampler", 78 | num=256, 79 | pos_fraction=0.5, 80 | neg_pos_ub=-1, 81 | add_gt_as_proposals=False, 82 | ), 83 | allowed_border=-1, 84 | pos_weight=-1, 85 | debug=False, 86 | ), 87 | rpn_proposal=dict( 88 | nms_pre=2000, 89 | max_per_img=1000, 90 | nms=dict(type="nms", iou_threshold=0.7), 91 | min_bbox_size=0, 92 | ), 93 | rcnn=dict( 94 | assigner=dict( 95 | type="MaxIoUAssigner", 96 | pos_iou_thr=0.5, 97 | neg_iou_thr=0.5, 98 | min_pos_iou=0.5, 99 | match_low_quality=False, 100 | ignore_iof_thr=-1, 101 | ), 102 | sampler=dict( 103 | type="RandomSampler", 104 | num=512, 105 | pos_fraction=0.25, 106 | neg_pos_ub=-1, 107 | add_gt_as_proposals=True, 108 | ), 109 | pos_weight=-1, 110 | debug=False, 111 | ), 112 | ), 113 | test_cfg=dict( 114 | rpn=dict( 115 | nms_pre=1000, 116 | max_per_img=1000, 117 | nms=dict(type="nms", iou_threshold=0.7), 118 | min_bbox_size=0, 119 | ), 120 | rcnn=dict( 121 | score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 122 | ), 123 | ), 124 | ) 125 | -------------------------------------------------------------------------------- /configs/_base_/models/mixed_cascade_rcnn_swin_small.py: -------------------------------------------------------------------------------- 1 | student_cfg = dict( 2 | type="CascadeRCNN", 3 | pretrained="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", 4 | backbone=dict( 5 | type="SwinTransformer", 6 | embed_dim=96, 7 | depths=[2, 2, 18, 2], 8 | num_heads=[3, 6, 12, 24], 9 | window_size=7, 10 | mlp_ratio=4.0, 11 | qkv_bias=True, 12 | qk_scale=None, 13 | drop_rate=0.0, 14 | attn_drop_rate=0.0, 15 | drop_path_rate=0.2, 16 | ape=False, 17 | patch_norm=True, 18 | out_indices=(0, 1, 2, 3), 19 | use_checkpoint=False, 20 | ), 21 | neck=dict( 22 | type="FPN", in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5 23 | ), 24 | rpn_head=dict( 25 | type="RPNHead", 26 | in_channels=256, 27 | feat_channels=256, 28 | anchor_generator=dict( 29 | type="AnchorGenerator", 30 | scales=[8], 31 | ratios=[0.5, 1.0, 2.0], 32 | strides=[4, 8, 16, 32, 64], 33 | ), 34 | bbox_coder=dict( 35 | type="DeltaXYWHBBoxCoder", 36 | target_means=[0.0, 0.0, 0.0, 0.0], 37 | target_stds=[1.0, 1.0, 1.0, 1.0], 38 | ), 39 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 40 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 41 | ), 42 | roi_head=dict( 43 | type="CascadeRoIHead", 44 | num_stages=3, 45 | stage_loss_weights=[1, 0.5, 0.25], 46 | bbox_roi_extractor=dict( 47 | type="SingleRoIExtractor", 48 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 49 | out_channels=256, 50 | featmap_strides=[4, 8, 16, 32], 51 | ), 52 | bbox_head=[ 53 | dict( 54 | type="Shared2FCBBoxHead", 55 | in_channels=256, 56 | fc_out_channels=1024, 57 | roi_feat_size=7, 58 | num_classes=80, 59 | bbox_coder=dict( 60 | type="DeltaXYWHBBoxCoder", 61 | target_means=[0.0, 0.0, 0.0, 0.0], 62 | target_stds=[0.1, 0.1, 0.2, 0.2], 63 | ), 64 | reg_class_agnostic=True, 65 | loss_cls=dict( 66 | type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0 67 | ), 68 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 69 | ), 70 | dict( 71 | type="Shared2FCBBoxHead", 72 | in_channels=256, 73 | fc_out_channels=1024, 74 | roi_feat_size=7, 75 | num_classes=80, 76 | bbox_coder=dict( 77 | type="DeltaXYWHBBoxCoder", 78 | target_means=[0.0, 0.0, 0.0, 0.0], 79 | target_stds=[0.05, 0.05, 0.1, 0.1], 80 | ), 81 | 
reg_class_agnostic=True, 82 | loss_cls=dict( 83 | type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0 84 | ), 85 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 86 | ), 87 | dict( 88 | type="Shared2FCBBoxHead", 89 | in_channels=256, 90 | fc_out_channels=1024, 91 | roi_feat_size=7, 92 | num_classes=80, 93 | bbox_coder=dict( 94 | type="DeltaXYWHBBoxCoder", 95 | target_means=[0.0, 0.0, 0.0, 0.0], 96 | target_stds=[0.033, 0.033, 0.067, 0.067], 97 | ), 98 | reg_class_agnostic=True, 99 | loss_cls=dict( 100 | type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0 101 | ), 102 | loss_bbox=dict(type="SmoothL1Loss", beta=1.0, loss_weight=1.0), 103 | ), 104 | ], 105 | ), 106 | # model training and testing settings 107 | train_cfg=dict( 108 | rpn=dict( 109 | assigner=dict( 110 | type="MaxIoUAssigner", 111 | pos_iou_thr=0.7, 112 | neg_iou_thr=0.3, 113 | min_pos_iou=0.3, 114 | match_low_quality=True, 115 | ignore_iof_thr=-1, 116 | ), 117 | sampler=dict( 118 | type="RandomSampler", 119 | num=256, 120 | pos_fraction=0.5, 121 | neg_pos_ub=-1, 122 | add_gt_as_proposals=False, 123 | ), 124 | allowed_border=0, 125 | pos_weight=-1, 126 | debug=False, 127 | ), 128 | rpn_proposal=dict( 129 | nms_pre=2000, 130 | max_per_img=2000, 131 | nms=dict(type="nms", iou_threshold=0.7), 132 | min_bbox_size=0, 133 | ), 134 | rcnn=[ 135 | dict( 136 | assigner=dict( 137 | type="MaxIoUAssigner", 138 | pos_iou_thr=0.5, 139 | neg_iou_thr=0.5, 140 | min_pos_iou=0.5, 141 | match_low_quality=False, 142 | ignore_iof_thr=-1, 143 | ), 144 | sampler=dict( 145 | type="RandomSampler", 146 | num=512, 147 | pos_fraction=0.25, 148 | neg_pos_ub=-1, 149 | add_gt_as_proposals=True, 150 | ), 151 | pos_weight=-1, 152 | debug=False, 153 | ), 154 | dict( 155 | assigner=dict( 156 | type="MaxIoUAssigner", 157 | pos_iou_thr=0.6, 158 | neg_iou_thr=0.6, 159 | min_pos_iou=0.6, 160 | match_low_quality=False, 161 | ignore_iof_thr=-1, 162 | ), 163 | sampler=dict( 164 | type="RandomSampler", 165 | num=512, 166 | pos_fraction=0.25, 167 | neg_pos_ub=-1, 168 | add_gt_as_proposals=True, 169 | ), 170 | pos_weight=-1, 171 | debug=False, 172 | ), 173 | dict( 174 | assigner=dict( 175 | type="MaxIoUAssigner", 176 | pos_iou_thr=0.7, 177 | neg_iou_thr=0.7, 178 | min_pos_iou=0.7, 179 | match_low_quality=False, 180 | ignore_iof_thr=-1, 181 | ), 182 | sampler=dict( 183 | type="RandomSampler", 184 | num=512, 185 | pos_fraction=0.25, 186 | neg_pos_ub=-1, 187 | add_gt_as_proposals=True, 188 | ), 189 | pos_weight=-1, 190 | debug=False, 191 | ), 192 | ], 193 | ), 194 | test_cfg=dict( 195 | rpn=dict( 196 | nms_pre=1000, 197 | max_per_img=1000, 198 | nms=dict(type="nms", iou_threshold=0.7), 199 | min_bbox_size=0, 200 | ), 201 | rcnn=dict( 202 | score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 203 | ), 204 | ), 205 | ) 206 | 207 | 208 | model = dict( 209 | type="SemiTwoStageDetector", 210 | student_cfg=student_cfg, 211 | train_cfg=dict( 212 | supervised_fields=["gt_bboxes", "gt_labels"], 213 | with_soft_teacher=True, 214 | score_thr=0.9, 215 | ), 216 | test_cfg=dict(inference_on="student"), 217 | ) 218 | -------------------------------------------------------------------------------- /configs/_base_/models/mixed_faster_rcnn_r50_caffe_fpn.py: -------------------------------------------------------------------------------- 1 | student_cfg = dict( 2 | type="FasterRCNN", 3 | pretrained="open-mmlab://detectron2/resnet50_caffe", 4 | backbone=dict( 5 | type="ResNet", 6 | depth=50, 7 | num_stages=4, 8 | out_indices=(0, 1, 2, 3), 9 | 
frozen_stages=1, 10 | norm_cfg=dict(type="BN", requires_grad=False), 11 | norm_eval=True, 12 | style="caffe", 13 | ), 14 | neck=dict( 15 | type="FPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5 16 | ), 17 | rpn_head=dict( 18 | type="RPNHead", 19 | in_channels=256, 20 | feat_channels=256, 21 | anchor_generator=dict( 22 | type="AnchorGenerator", 23 | scales=[8], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[4, 8, 16, 32, 64], 26 | ), 27 | bbox_coder=dict( 28 | type="DeltaXYWHBBoxCoder", 29 | target_means=[0.0, 0.0, 0.0, 0.0], 30 | target_stds=[1.0, 1.0, 1.0, 1.0], 31 | ), 32 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 33 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 34 | ), 35 | roi_head=dict( 36 | type="StandardRoIHead", 37 | bbox_roi_extractor=dict( 38 | type="SingleRoIExtractor", 39 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 40 | out_channels=256, 41 | featmap_strides=[4, 8, 16, 32], 42 | ), 43 | bbox_head=dict( 44 | type="Shared2FCBBoxHead", 45 | in_channels=256, 46 | fc_out_channels=1024, 47 | roi_feat_size=7, 48 | num_classes=80, 49 | bbox_coder=dict( 50 | type="DeltaXYWHBBoxCoder", 51 | target_means=[0.0, 0.0, 0.0, 0.0], 52 | target_stds=[0.1, 0.1, 0.2, 0.2], 53 | ), 54 | reg_class_agnostic=False, 55 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 56 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 57 | ), 58 | ), 59 | # model training and testing settings 60 | train_cfg=dict( 61 | rpn=dict( 62 | assigner=dict( 63 | type="MaxIoUAssigner", 64 | pos_iou_thr=0.7, 65 | neg_iou_thr=0.3, 66 | min_pos_iou=0.3, 67 | match_low_quality=True, 68 | ignore_iof_thr=-1, 69 | ), 70 | sampler=dict( 71 | type="RandomSampler", 72 | num=256, 73 | pos_fraction=0.5, 74 | neg_pos_ub=-1, 75 | add_gt_as_proposals=False, 76 | ), 77 | allowed_border=-1, 78 | pos_weight=-1, 79 | debug=False, 80 | ), 81 | rpn_proposal=dict( 82 | nms_pre=2000, 83 | max_per_img=1000, 84 | nms=dict(type="nms", iou_threshold=0.7), 85 | min_bbox_size=0, 86 | ), 87 | rcnn=dict( 88 | assigner=dict( 89 | type="MaxIoUAssigner", 90 | pos_iou_thr=0.5, 91 | neg_iou_thr=0.5, 92 | min_pos_iou=0.5, 93 | match_low_quality=False, 94 | ignore_iof_thr=-1, 95 | ), 96 | sampler=dict( 97 | type="RandomSampler", 98 | num=512, 99 | pos_fraction=0.25, 100 | neg_pos_ub=-1, 101 | add_gt_as_proposals=True, 102 | ), 103 | pos_weight=-1, 104 | debug=False, 105 | ), 106 | ), 107 | test_cfg=dict( 108 | rpn=dict( 109 | nms_pre=1000, 110 | max_per_img=1000, 111 | nms=dict(type="nms", iou_threshold=0.7), 112 | min_bbox_size=0, 113 | ), 114 | rcnn=dict( 115 | score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 116 | ) 117 | # soft-nms is also supported for rcnn testing 118 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) 119 | ), 120 | ) 121 | 122 | 123 | model = dict( 124 | type="SemiTwoStageDetector", 125 | student_cfg=student_cfg, 126 | train_cfg=dict( 127 | supervised_fields=["gt_bboxes", "gt_labels"], 128 | with_soft_teacher=True, 129 | score_thr=0.9, 130 | ), 131 | test_cfg=dict(inference_on="student"), 132 | ) 133 | -------------------------------------------------------------------------------- /configs/_base_/models/mixed_faster_rcnn_swin_small.py: -------------------------------------------------------------------------------- 1 | student_cfg = dict( 2 | type="FasterRCNN", 3 | pretrained="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", 4 | 
backbone=dict( 5 | type="SwinTransformer", 6 | embed_dim=96, 7 | depths=[2, 2, 18, 2], 8 | num_heads=[3, 6, 12, 24], 9 | window_size=7, 10 | mlp_ratio=4.0, 11 | qkv_bias=True, 12 | qk_scale=None, 13 | drop_rate=0.0, 14 | attn_drop_rate=0.0, 15 | drop_path_rate=0.2, 16 | ape=False, 17 | patch_norm=True, 18 | out_indices=(0, 1, 2, 3), 19 | use_checkpoint=False, 20 | ), 21 | neck=dict( 22 | type="FPN", in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5 23 | ), 24 | rpn_head=dict( 25 | type="RPNHead", 26 | in_channels=256, 27 | feat_channels=256, 28 | anchor_generator=dict( 29 | type="AnchorGenerator", 30 | scales=[8], 31 | ratios=[0.5, 1.0, 2.0], 32 | strides=[4, 8, 16, 32, 64], 33 | ), 34 | bbox_coder=dict( 35 | type="DeltaXYWHBBoxCoder", 36 | target_means=[0.0, 0.0, 0.0, 0.0], 37 | target_stds=[1.0, 1.0, 1.0, 1.0], 38 | ), 39 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0), 40 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 41 | ), 42 | roi_head=dict( 43 | type="StandardRoIHead", 44 | bbox_roi_extractor=dict( 45 | type="SingleRoIExtractor", 46 | roi_layer=dict(type="RoIAlign", output_size=7, sampling_ratio=0), 47 | out_channels=256, 48 | featmap_strides=[4, 8, 16, 32], 49 | ), 50 | bbox_head=dict( 51 | type="Shared2FCBBoxHead", 52 | in_channels=256, 53 | fc_out_channels=1024, 54 | roi_feat_size=7, 55 | num_classes=80, 56 | bbox_coder=dict( 57 | type="DeltaXYWHBBoxCoder", 58 | target_means=[0.0, 0.0, 0.0, 0.0], 59 | target_stds=[0.1, 0.1, 0.2, 0.2], 60 | ), 61 | reg_class_agnostic=False, 62 | loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), 63 | loss_bbox=dict(type="L1Loss", loss_weight=1.0), 64 | ), 65 | ), 66 | train_cfg=dict( 67 | rpn=dict( 68 | assigner=dict( 69 | type="MaxIoUAssigner", 70 | pos_iou_thr=0.7, 71 | neg_iou_thr=0.3, 72 | min_pos_iou=0.3, 73 | match_low_quality=True, 74 | ignore_iof_thr=-1, 75 | ), 76 | sampler=dict( 77 | type="RandomSampler", 78 | num=256, 79 | pos_fraction=0.5, 80 | neg_pos_ub=-1, 81 | add_gt_as_proposals=False, 82 | ), 83 | allowed_border=-1, 84 | pos_weight=-1, 85 | debug=False, 86 | ), 87 | rpn_proposal=dict( 88 | nms_pre=2000, 89 | max_per_img=1000, 90 | nms=dict(type="nms", iou_threshold=0.7), 91 | min_bbox_size=0, 92 | ), 93 | rcnn=dict( 94 | assigner=dict( 95 | type="MaxIoUAssigner", 96 | pos_iou_thr=0.5, 97 | neg_iou_thr=0.5, 98 | min_pos_iou=0.5, 99 | match_low_quality=False, 100 | ignore_iof_thr=-1, 101 | ), 102 | sampler=dict( 103 | type="RandomSampler", 104 | num=512, 105 | pos_fraction=0.25, 106 | neg_pos_ub=-1, 107 | add_gt_as_proposals=True, 108 | ), 109 | pos_weight=-1, 110 | debug=False, 111 | ), 112 | ), 113 | test_cfg=dict( 114 | rpn=dict( 115 | nms_pre=1000, 116 | max_per_img=1000, 117 | nms=dict(type="nms", iou_threshold=0.7), 118 | min_bbox_size=0, 119 | ), 120 | rcnn=dict( 121 | score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100 122 | ), 123 | ), 124 | ) 125 | 126 | 127 | model = dict( 128 | type="SemiTwoStageDetector", 129 | student_cfg=student_cfg, 130 | train_cfg=dict( 131 | supervised_fields=["gt_bboxes", "gt_labels"], 132 | with_soft_teacher=True, 133 | score_thr=0.9, 134 | ), 135 | test_cfg=dict(inference_on="student"), 136 | ) 137 | -------------------------------------------------------------------------------- /configs/_base_/runtimes/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1, create_symlink=False) 2 | # yapf:disable 3 | log_config = dict( 
4 | interval=50, 5 | hooks=[ 6 | dict(type="TextLoggerHook"), 7 | dict( 8 | type="GlobalWandbLoggerHook", 9 | init_kwargs=dict( 10 | project="self_distil", 11 | entity="sdl", 12 | name="{exp_name}", 13 | # log some important params 14 | config=dict(work_dir="{work_dir}"), 15 | notes="{note}", 16 | ), 17 | ), 18 | ], 19 | ) 20 | # yapf:enable 21 | # fp16 22 | fp16 = dict(loss_scale="dynamic") 23 | # momentum update 24 | custom_hooks = [dict(type="NumClassCheckHook"), dict(type="MeanTeacherHook")] 25 | # custom 26 | dist_params = dict(backend="nccl") 27 | log_level = "INFO" 28 | auto_resume = True 29 | load_from = None 30 | resume_from = None 31 | workflow = [("train", 1)] 32 | note = "" 33 | -------------------------------------------------------------------------------- /configs/_base_/schedules/sgd_schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy="step", warmup="linear", warmup_iters=500, warmup_ratio=0.001, step=[8, 11] 7 | ) 8 | runner = dict(type="EpochBasedRunner", max_epochs=12) 9 | -------------------------------------------------------------------------------- /configs/_base_/schedules/swin_adamw_schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type="AdamW", 4 | lr=0.0001, 5 | betas=(0.9, 0.999), 6 | weight_decay=0.05, 7 | paramwise_cfg=dict( 8 | custom_keys=dict( 9 | absolute_pos_embed=dict(decay_mult=0.0), 10 | relative_position_bias_table=dict(decay_mult=0.0), 11 | norm=dict(decay_mult=0.0), 12 | ) 13 | ), 14 | ) 15 | optimizer_config = dict(grad_clip=None) 16 | # learning policy 17 | lr_config = dict( 18 | policy="step", warmup="linear", warmup_iters=500, warmup_ratio=0.001, step=[8, 11] 19 | ) 20 | runner = dict(type="EpochBasedRunner", max_epochs=12) 21 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_025x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[8, 11]) 14 | runner = dict(max_epochs=12) 15 | custom_hooks = [dict(type="NumClassCheckHook")] -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_05x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[16, 22]) 14 | runner = dict(max_epochs=24) 15 | custom_hooks =
[dict(type="NumClassCheckHook")] 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[32, 44]) 14 | runner = dict(max_epochs=48) 15 | custom_hooks = [dict(type="NumClassCheckHook")] 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_2x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[64, 88]) 14 | runner = dict(max_epochs=96) 15 | custom_hooks = [dict(type="NumClassCheckHook")] 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[64 * 2, 88 * 2]) 14 | runner = dict(max_epochs=96 * 2) 15 | custom_hooks = [dict(type="NumClassCheckHook")] 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_5x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "https://raw.githubusercontent.com/open-mmlab/mmdetection/v2.11.0/configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | lr_config = dict(step=[16 * 10, 22 * 10]) 14 | runner = dict(max_epochs=24 * 10) 15 | custom_hooks = [dict(type="NumClassCheckHook")] 16 | resume_from = "work_dirs/supervised_v2/faster_rcnn_r50_caffe_fpn_coco_4x/epoch_128.pth" 17 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_8x.py:
-------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 12 | 13 | 14 | lr_config = dict(step=[64 * 4, 88 * 4]) 15 | runner = dict(max_epochs=96 * 4) 16 | custom_hooks = [dict(type="NumClassCheckHook")] 17 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_025x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 
102 | ) 103 | 104 | lr_config = dict(step=[8, 11]) 105 | runner = dict(max_epochs=12) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_05x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[16, 22]) 105 | runner = dict(max_epochs=24) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | 
"../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[16 * 2, 22 * 2]) 105 | runner = dict(max_epochs=24 * 2) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_2x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 
| dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[16 * 4, 22 * 4]) 105 | runner = dict(max_epochs=24 * 4) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "https://raw.githubusercontent.com/open-mmlab/mmdetection/v2.11.0/configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | 
max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[16 * 8, 22 * 8]) 105 | runner = dict(max_epochs=24 * 8) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_5x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "https://raw.githubusercontent.com/open-mmlab/mmdetection/v2.11.0/configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = 
"val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[16 * 10, 22 * 10]) 105 | runner = dict(max_epochs=24 * 10) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | 109 | resume_from = ( 110 | "work_dirs/supervised_v2/faster_rcnn_r50_caffe_fpn_coco_strong_4x/epoch_128.pth" 111 | ) 112 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_r50_caffe_fpn_coco_strong_8x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_default.py", 3 | "mmdet:configs/_base_/models/faster_rcnn_r50_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | model = dict( 8 | pretrained="open-mmlab://detectron2/resnet50_caffe", 9 | backbone=dict(norm_cfg=dict(requires_grad=False), norm_eval=True, style="caffe"), 10 | ) 11 | 12 | 13 | color_transform = dict( 14 | type="RandomApply", 15 | policies=[ 16 | dict(type="Identity"), 17 | dict(type="Jitter", contrast=1.0), 18 | dict(type="Jitter", brightness=1.0), 19 | dict(type="Jitter", hue=1.0), 20 | dict(type="Equalize"), 21 | dict(type="AutoContrast"), 22 | dict(type="PosterizeV1"), 23 | dict(type="RandomGrayScale"), 24 | dict(type="SolarizeV1"), 25 | ], 26 | ) 27 | 28 | geo_transform = dict( 29 | type="RandomApply", 30 | policies=[ 31 | dict(type="Identity"), 32 | dict( 33 | type="ImTranslate", 34 | level=5, 35 | prob=1.0, 36 | max_translate_offset=100, 37 | direction="horizontal", 38 | ), 39 | dict( 40 | type="ImTranslate", 41 | level=5, 42 | prob=1.0, 43 | max_translate_offset=100, 44 | direction="vertical", 45 | ), 46 | dict(type="ImRotate", level=5, prob=1.0), 47 | dict(type="ImShear", level=5, prob=1.0), 48 | ], 49 | ) 50 | 51 | img_norm_cfg = dict(mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 52 | scale_cfg = dict( 53 | img_scale=[(1333, 400), (1333, 1200)], multiscale_mode="range", keep_ratio=True 54 | ) 55 | data_root = "data/coco" 56 | sup_set = "train2017" 57 | val_set = "val2017" 58 | test_set = "val2017" 59 | # end def 60 | 61 | sup = [ 62 | dict(type="LoadImageFromFile", file_client_args=dict(backend="zip")), 63 | dict(type="LoadAnnotations", with_bbox=True), 64 | color_transform, 65 | dict(type="Resize", **scale_cfg), 66 | dict(type="RandomFlip", flip_ratio=0.5), 67 | dict(type="Normalize", **img_norm_cfg), 68 | geo_transform, 69 | dict( 70 | 
type="CutOut", 71 | n_holes=(1, 5), 72 | cutout_ratio=[ 73 | (0.05, 0.05), 74 | (0.75, 0.75), 75 | (0.1, 0.1), 76 | (0.125, 0.125), 77 | (0.15, 0.15), 78 | (0.175, 0.175), 79 | (0.2, 0.2), 80 | ], 81 | fill_in=(0, 0, 0), 82 | ), 83 | dict(type="Pad", size_divisor=32), 84 | dict(type="DefaultFormatBundle"), 85 | dict(type="ExtraAttrs", tag="sup"), 86 | dict( 87 | type="CollectV1", 88 | keys=["img", "gt_bboxes", "gt_labels"], 89 | extra_meta_keys=["tag"], 90 | ), 91 | ] 92 | 93 | data = dict( 94 | samples_per_gpu=2, 95 | workers_per_gpu=2, 96 | train=dict( 97 | type="CocoDataset", 98 | ann_file="{data_root}/annotations/instances_{sup_set}.json", 99 | img_prefix="{data_root}/{sup_set}/", 100 | pipeline=sup, 101 | ), 102 | ) 103 | 104 | lr_config = dict(step=[64 * 4, 88 * 4]) 105 | runner = dict(max_epochs=96 * 4) 106 | 107 | custom_hooks = [dict(type="NumClassCheckHook")] 108 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_swin_small_fpn_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_rgb.py", 3 | "../_base_/models/faster_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 9 | 10 | lr_config = dict(step=[16, 22]) 11 | runner = dict(max_epochs=24) 12 | 13 | custom_hooks = [dict(type="NumClassCheckHook")] 14 | 15 | resume_from = "work_dirs/supervised/faster_rcnn_swin_small_fpn_coco_4x/epoch_16.pth" 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_swin_small_fpn_coco_2x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_rgb.py", 3 | "../_base_/models/faster_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 9 | 10 | lr_config = dict(step=[32, 44]) 11 | runner = dict(max_epochs=48) 12 | 13 | custom_hooks = [dict(type="NumClassCheckHook")] 14 | 15 | resume_from = "work_dirs/supervised/faster_rcnn_swin_small_fpn_coco_4x/epoch_32.pth" 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_swin_small_fpn_coco_3x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_rgb.py", 3 | "../_base_/models/faster_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2, workers_per_gpu=2,) 9 | 10 | lr_config = dict(step=[48, 66]) 11 | runner = dict(max_epochs=72) 12 | 13 | custom_hooks = [dict(type="NumClassCheckHook")] 14 | 15 | resume_from = "work_dirs/supervised/faster_rcnn_swin_small_fpn_coco_4x/epoch_48.pth" 16 | -------------------------------------------------------------------------------- /configs/baseline/faster_rcnn_swin_small_fpn_coco_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/coco_detection_rgb.py", 3 | "../_base_/models/faster_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | 8 | data = dict(samples_per_gpu=2, 
workers_per_gpu=2,) 9 | 10 | lr_config = dict(step=[64, 88]) 11 | runner = dict(max_epochs=96) 12 | 13 | custom_hooks = [dict(type="NumClassCheckHook")] 14 | -------------------------------------------------------------------------------- /configs/for_print.py: -------------------------------------------------------------------------------- 1 | _base_ = "https://raw.githubusercontent.com/SwinTransformer/Swin-Transformer-Object-Detection/master/configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py" 2 | -------------------------------------------------------------------------------- /configs/ours/mixed_cascade_rcnn_swin_small_fpn_coco_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/disl_coco_detection_rgb.py", 3 | "../_base_/models/mixed_cascade_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | lr_config = dict(step=[32, 44]) 8 | runner = dict(max_epochs=48) 9 | -------------------------------------------------------------------------------- /configs/ours/mixed_faster_rcnn_r50_caffe_fpn_coco_8x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/disl_coco_detection_default.py", 3 | "../_base_/models/mixed_faster_rcnn_r50_caffe_fpn.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/sgd_schedule_1x.py", 6 | ] 7 | 8 | lr_config = dict(step=[32 * 2, 44 * 2]) 9 | runner = dict(max_epochs=48 * 2) 10 | -------------------------------------------------------------------------------- /configs/ours/mixed_faster_rcnn_swin_small_fpn_coco_4x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | "../_base_/datasets/disl_coco_detection_rgb.py", 3 | "../_base_/models/mixed_faster_rcnn_swin_small.py", 4 | "../_base_/runtimes/default_runtime.py", 5 | "../_base_/schedules/swin_adamw_schedule_1x.py", 6 | ] 7 | lr_config = dict(step=[32, 44]) 8 | runner = dict(max_epochs=48) 9 | data = dict(samples_per_gpu=2, workers_per_gpu=2) 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | #git+https://github.com/open-mmlab/mmdetection.git 3 | #--extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 4 | #nvidia-dali-cuda100 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from setuptools import find_packages, setup 4 | 5 | import torch 6 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 7 | 8 | 9 | def parse_requirements(fname="requirements.txt", with_version=True): 10 | """Parse the package dependencies listed in a requirements file but strips 11 | specific versioning information. 
12 | 13 | Args: 14 | fname (str): path to requirements file 15 | with_version (bool, default=False): if True include version specs 16 | 17 | Returns: 18 | List[str]: list of requirements items 19 | 20 | CommandLine: 21 | python -c "import setup; print(setup.parse_requirements())" 22 | """ 23 | import sys 24 | from os.path import exists 25 | import re 26 | 27 | require_fpath = fname 28 | 29 | def parse_line(line): 30 | """Parse information from a line in a requirements text file.""" 31 | if line.startswith("-r "): 32 | # Allow specifying requirements in other files 33 | target = line.split(" ")[1] 34 | for info in parse_require_file(target): 35 | yield info 36 | else: 37 | info = {"line": line} 38 | if line.startswith("-e "): 39 | info["package"] = line.split("#egg=")[1] 40 | elif "@git+" in line: 41 | info["package"] = line 42 | else: 43 | # Remove versioning from the package 44 | pat = "(" + "|".join([">=", "==", ">"]) + ")" 45 | parts = re.split(pat, line, maxsplit=1) 46 | parts = [p.strip() for p in parts] 47 | 48 | info["package"] = parts[0] 49 | if len(parts) > 1: 50 | op, rest = parts[1:] 51 | if ";" in rest: 52 | # Handle platform specific dependencies 53 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 54 | version, platform_deps = map(str.strip, rest.split(";")) 55 | info["platform_deps"] = platform_deps 56 | else: 57 | version = rest # NOQA 58 | info["version"] = (op, version) 59 | yield info 60 | 61 | def parse_require_file(fpath): 62 | with open(fpath, "r") as f: 63 | for line in f.readlines(): 64 | line = line.strip() 65 | if line and not line.startswith("#"): 66 | for info in parse_line(line): 67 | yield info 68 | 69 | def gen_packages_items(): 70 | if exists(require_fpath): 71 | for info in parse_require_file(require_fpath): 72 | parts = [info["package"]] 73 | if with_version and "version" in info: 74 | parts.extend(info["version"]) 75 | if not sys.version.startswith("3.4"): 76 | # apparently package_deps are broken in 3.4 77 | platform_deps = info.get("platform_deps") 78 | if platform_deps is not None: 79 | parts.append(";" + platform_deps) 80 | item = "".join(parts) 81 | yield item 82 | 83 | packages = list(gen_packages_items()) 84 | return packages 85 | 86 | 87 | if __name__ == "__main__": 88 | setup( 89 | name="src", 90 | version="0.0.1", 91 | description="OpenMMLab Detection Toolbox and Benchmark", 92 | author="xxx", 93 | author_email="", 94 | # install_requires=parse_requirements('requirements.txt'), 95 | packages=find_packages(exclude=("configs", "tools", "demo")), 96 | ext_modules=[], 97 | zip_safe=False, 98 | ) 99 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | -------------------------------------------------------------------------------- /src/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import ( 2 | async_inference_detector, 3 | inference_detector, 4 | init_detector, 5 | show_result_pyplot, 6 | ) 7 | from .test import multi_gpu_test, single_gpu_test 8 | from .train import get_root_logger, set_random_seed, train_detector 9 | 10 | __all__ = [ 11 | "get_root_logger", 12 | "set_random_seed", 13 | "train_detector", 14 | "init_detector", 15 | "async_inference_detector", 16 | "inference_detector", 17 | "show_result_pyplot", 18 | "multi_gpu_test", 19 | "single_gpu_test", 20 | ] 21 | 
-------------------------------------------------------------------------------- /src/apis/inference.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | from mmcv.ops import RoIPool 7 | from mmcv.parallel import collate, scatter 8 | from mmcv.runner import load_checkpoint 9 | 10 | from mmdet.core import get_classes 11 | from mmdet.datasets import replace_ImageToTensor 12 | from mmdet.datasets.pipelines import Compose 13 | from mmdet.models import build_detector 14 | 15 | 16 | def init_detector(config, checkpoint=None, device="cuda:0", cfg_options=None): 17 | """Initialize a detector from config file. 18 | 19 | Args: 20 | config (str or :obj:`mmcv.Config`): Config file path or the config 21 | object. 22 | checkpoint (str, optional): Checkpoint path. If left as None, the model 23 | will not load any weights. 24 | cfg_options (dict): Options to override some settings in the used 25 | config. 26 | 27 | Returns: 28 | nn.Module: The constructed detector. 29 | """ 30 | if isinstance(config, str): 31 | config = mmcv.Config.fromfile(config) 32 | elif not isinstance(config, mmcv.Config): 33 | raise TypeError( 34 | "config must be a filename or Config object, " f"but got {type(config)}" 35 | ) 36 | if cfg_options is not None: 37 | config.merge_from_dict(cfg_options) 38 | config.model.pretrained = None 39 | config.model.train_cfg = None 40 | model = build_detector(config.model, test_cfg=config.get("test_cfg")) 41 | if checkpoint is not None: 42 | map_loc = "cpu" if device == "cpu" else None 43 | checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) 44 | if "CLASSES" in checkpoint.get("meta", {}): 45 | model.CLASSES = checkpoint["meta"]["CLASSES"] 46 | else: 47 | warnings.simplefilter("once") 48 | warnings.warn( 49 | "Class names are not saved in the checkpoint's " 50 | "meta data, use COCO classes by default." 51 | ) 52 | model.CLASSES = get_classes("coco") 53 | model.cfg = config # save the config in the model for convenience 54 | model.to(device) 55 | model.eval() 56 | return model 57 | 58 | 59 | class LoadImage(object): 60 | """A simple pipeline to load image.""" 61 | 62 | def __call__(self, results): 63 | """Call function to load images into results. 64 | 65 | Args: 66 | results (dict): A result dict contains the file name 67 | of the image to be read. 68 | 69 | Returns: 70 | dict: ``results`` will be returned containing loaded image. 71 | """ 72 | if isinstance(results["img"], str): 73 | results["filename"] = results["img"] 74 | results["ori_filename"] = results["img"] 75 | else: 76 | results["filename"] = None 77 | results["ori_filename"] = None 78 | img = mmcv.imread(results["img"]) 79 | results["img"] = img 80 | results["img_fields"] = ["img"] 81 | results["img_shape"] = img.shape 82 | results["ori_shape"] = img.shape 83 | return results 84 | 85 | 86 | def inference_detector(model, imgs): 87 | """Inference image(s) with the detector. 88 | 89 | Args: 90 | model (nn.Module): The loaded detector. 91 | imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]): 92 | Either image files or loaded images. 93 | 94 | Returns: 95 | If imgs is a list or tuple, the same length list type results 96 | will be returned, otherwise return the detection results directly. 
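    Example (illustrative paths):
        >>> model = init_detector('cfg.py', 'ckpt.pth', device='cuda:0')
        >>> inference_detector(model, 'demo.jpg')          # single result
        >>> inference_detector(model, ['a.jpg', 'b.jpg'])  # list of results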
97 | """ 98 | 99 | if isinstance(imgs, (list, tuple)): 100 | is_batch = True 101 | else: 102 | imgs = [imgs] 103 | is_batch = False 104 | 105 | cfg = model.cfg 106 | device = next(model.parameters()).device # model device 107 | 108 | if isinstance(imgs[0], np.ndarray): 109 | cfg = cfg.copy() 110 | # set loading pipeline type 111 | cfg.data.test.pipeline[0].type = "LoadImageFromWebcam" 112 | 113 | cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) 114 | test_pipeline = Compose(cfg.data.test.pipeline) 115 | 116 | datas = [] 117 | for img in imgs: 118 | # prepare data 119 | if isinstance(img, np.ndarray): 120 | # directly add img 121 | data = dict(img=img) 122 | else: 123 | # add information into dict 124 | data = dict(img_info=dict(filename=img), img_prefix=None) 125 | # build the data pipeline 126 | data = test_pipeline(data) 127 | datas.append(data) 128 | 129 | data = collate(datas, samples_per_gpu=len(imgs)) 130 | # just get the actual data from DataContainer 131 | data["img_metas"] = [img_metas.data[0] for img_metas in data["img_metas"]] 132 | data["img"] = [img.data[0] for img in data["img"]] 133 | if next(model.parameters()).is_cuda: 134 | # scatter to specified GPU 135 | data = scatter(data, [device])[0] 136 | else: 137 | for m in model.modules(): 138 | assert not isinstance( 139 | m, RoIPool 140 | ), "CPU inference with RoIPool is not supported currently." 141 | 142 | # forward the model 143 | with torch.no_grad(): 144 | results = model(return_loss=False, rescale=True, **data) 145 | 146 | if not is_batch: 147 | return results[0] 148 | else: 149 | return results 150 | 151 | 152 | async def async_inference_detector(model, img): 153 | """Async inference image(s) with the detector. 154 | 155 | Args: 156 | model (nn.Module): The loaded detector. 157 | img (str | ndarray): Either image files or loaded images. 158 | 159 | Returns: 160 | Awaitable detection results. 161 | """ 162 | cfg = model.cfg 163 | device = next(model.parameters()).device # model device 164 | # prepare data 165 | if isinstance(img, np.ndarray): 166 | # directly add img 167 | data = dict(img=img) 168 | cfg = cfg.copy() 169 | # set loading pipeline type 170 | cfg.data.test.pipeline[0].type = "LoadImageFromWebcam" 171 | else: 172 | # add information into dict 173 | data = dict(img_info=dict(filename=img), img_prefix=None) 174 | # build the data pipeline 175 | test_pipeline = Compose(cfg.data.test.pipeline) 176 | data = test_pipeline(data) 177 | data = scatter(collate([data], samples_per_gpu=1), [device])[0] 178 | 179 | # We don't restore `torch.is_grad_enabled()` value during concurrent 180 | # inference since execution can overlap 181 | torch.set_grad_enabled(False) 182 | result = await model.aforward_test(rescale=True, **data) 183 | return result 184 | 185 | 186 | def show_result_pyplot(model, img, result, score_thr=0.3, title="result", wait_time=0): 187 | """Visualize the detection results on the image. 188 | 189 | Args: 190 | model (nn.Module): The loaded detector. 191 | img (str or np.ndarray): Image filename or loaded image. 192 | result (tuple[list] or list): The detection result, can be either 193 | (bbox, segm) or just bbox. 194 | score_thr (float): The threshold to visualize the bboxes and masks. 195 | title (str): Title of the pyplot figure. 196 | wait_time (float): Value of waitKey param. 197 | Default: 0. 
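    Example (illustrative):
        >>> result = inference_detector(model, 'demo.jpg')
        >>> show_result_pyplot(model, 'demo.jpg', result, score_thr=0.5)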
198 | """ 199 | if hasattr(model, "module"): 200 | model = model.module 201 | model.show_result( 202 | img, 203 | result, 204 | score_thr=score_thr, 205 | show=True, 206 | wait_time=wait_time, 207 | win_name=title, 208 | bbox_color=(72, 101, 241), 209 | text_color=(72, 101, 241), 210 | ) 211 | -------------------------------------------------------------------------------- /src/apis/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import shutil 4 | import tempfile 5 | import time 6 | 7 | import mmcv 8 | import torch 9 | import torch.distributed as dist 10 | from mmcv.image import tensor2imgs 11 | from mmcv.runner import get_dist_info 12 | 13 | from mmdet.core import encode_mask_results 14 | 15 | 16 | def single_gpu_test(model, data_loader, show=False, out_dir=None, show_score_thr=0.3): 17 | model.eval() 18 | results = [] 19 | dataset = data_loader.dataset 20 | prog_bar = mmcv.ProgressBar(len(dataset)) 21 | for i, data in enumerate(data_loader): 22 | with torch.no_grad(): 23 | result = model(return_loss=False, rescale=True, **data) 24 | 25 | batch_size = len(result) 26 | if show or out_dir: 27 | if batch_size == 1 and isinstance(data["img"][0], torch.Tensor): 28 | img_tensor = data["img"][0] 29 | else: 30 | img_tensor = data["img"][0].data[0] 31 | img_metas = data["img_metas"][0].data[0] 32 | imgs = tensor2imgs(img_tensor, **img_metas[0]["img_norm_cfg"]) 33 | assert len(imgs) == len(img_metas) 34 | 35 | for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): 36 | h, w, _ = img_meta["img_shape"] 37 | img_show = img[:h, :w, :] 38 | 39 | ori_h, ori_w = img_meta["ori_shape"][:-1] 40 | img_show = mmcv.imresize(img_show, (ori_w, ori_h)) 41 | 42 | if out_dir: 43 | out_file = osp.join(out_dir, img_meta["ori_filename"]) 44 | else: 45 | out_file = None 46 | 47 | model.module.show_result( 48 | img_show, 49 | result[i], 50 | show=show, 51 | out_file=out_file, 52 | score_thr=show_score_thr, 53 | ) 54 | 55 | # encode mask results 56 | if isinstance(result[0], tuple): 57 | result = [ 58 | (bbox_results, encode_mask_results(mask_results)) 59 | for bbox_results, mask_results in result 60 | ] 61 | results.extend(result) 62 | 63 | for _ in range(batch_size): 64 | prog_bar.update() 65 | return results 66 | 67 | 68 | def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): 69 | """Test model with multiple gpus. 70 | 71 | This method tests model with multiple gpus and collects the results 72 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 73 | it encodes results to gpu tensors and use gpu communication for results 74 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 75 | and collects them by the rank 0 worker. 76 | 77 | Args: 78 | model (nn.Module): Model to be tested. 79 | data_loader (nn.Dataloader): Pytorch data loader. 80 | tmpdir (str): Path of directory to save the temporary results from 81 | different gpus under cpu mode. 82 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 83 | 84 | Returns: 85 | list: The prediction results. 86 | """ 87 | model.eval() 88 | results = [] 89 | dataset = data_loader.dataset 90 | rank, world_size = get_dist_info() 91 | if rank == 0: 92 | prog_bar = mmcv.ProgressBar(len(dataset)) 93 | time.sleep(2) # This line can prevent deadlock problem in some cases. 
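    # Each rank below runs inference on its own shard of the dataset. Only
    # rank 0 owns the progress bar, so it advances batch_size * world_size
    # ticks per step to account for the other workers; per-rank results are
    # merged afterwards by collect_results_gpu / collect_results_cpu.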
94 | for i, data in enumerate(data_loader): 95 | with torch.no_grad(): 96 | result = model(return_loss=False, rescale=True, **data) 97 | # encode mask results 98 | if isinstance(result[0], tuple): 99 | result = [ 100 | (bbox_results, encode_mask_results(mask_results)) 101 | for bbox_results, mask_results in result 102 | ] 103 | results.extend(result) 104 | 105 | if rank == 0: 106 | batch_size = len(result) 107 | for _ in range(batch_size * world_size): 108 | prog_bar.update() 109 | 110 | # collect results from all ranks 111 | if gpu_collect: 112 | results = collect_results_gpu(results, len(dataset)) 113 | else: 114 | results = collect_results_cpu(results, len(dataset), tmpdir) 115 | return results 116 | 117 | 118 | def collect_results_cpu(result_part, size, tmpdir=None): 119 | rank, world_size = get_dist_info() 120 | # create a tmp dir if it is not specified 121 | if tmpdir is None: 122 | MAX_LEN = 512 123 | # 32 is whitespace 124 | dir_tensor = torch.full((MAX_LEN,), 32, dtype=torch.uint8, device="cuda") 125 | if rank == 0: 126 | mmcv.mkdir_or_exist(".dist_test") 127 | tmpdir = tempfile.mkdtemp(dir=".dist_test") 128 | tmpdir = torch.tensor( 129 | bytearray(tmpdir.encode()), dtype=torch.uint8, device="cuda" 130 | ) 131 | dir_tensor[: len(tmpdir)] = tmpdir 132 | dist.broadcast(dir_tensor, 0) 133 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 134 | else: 135 | mmcv.mkdir_or_exist(tmpdir) 136 | # dump the part result to the dir 137 | mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl")) 138 | dist.barrier() 139 | # collect all parts 140 | if rank != 0: 141 | return None 142 | else: 143 | # load results of all parts from tmp dir 144 | part_list = [] 145 | for i in range(world_size): 146 | part_file = osp.join(tmpdir, f"part_{i}.pkl") 147 | part_list.append(mmcv.load(part_file)) 148 | # sort the results 149 | ordered_results = [] 150 | for res in zip(*part_list): 151 | ordered_results.extend(list(res)) 152 | # the dataloader may pad some samples 153 | ordered_results = ordered_results[:size] 154 | # remove tmp dir 155 | shutil.rmtree(tmpdir) 156 | return ordered_results 157 | 158 | 159 | def collect_results_gpu(result_part, size): 160 | rank, world_size = get_dist_info() 161 | # dump result part to tensor with pickle 162 | part_tensor = torch.tensor( 163 | bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device="cuda" 164 | ) 165 | # gather all result part tensor shape 166 | shape_tensor = torch.tensor(part_tensor.shape, device="cuda") 167 | shape_list = [shape_tensor.clone() for _ in range(world_size)] 168 | dist.all_gather(shape_list, shape_tensor) 169 | # padding result part tensor to max length 170 | shape_max = torch.tensor(shape_list).max() 171 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device="cuda") 172 | part_send[: shape_tensor[0]] = part_tensor 173 | part_recv_list = [part_tensor.new_zeros(shape_max) for _ in range(world_size)] 174 | # gather all result part 175 | dist.all_gather(part_recv_list, part_send) 176 | 177 | if rank == 0: 178 | part_list = [] 179 | for recv, shape in zip(part_recv_list, shape_list): 180 | part_list.append(pickle.loads(recv[: shape[0]].cpu().numpy().tobytes())) 181 | # sort the results 182 | ordered_results = [] 183 | for res in zip(*part_list): 184 | ordered_results.extend(list(res)) 185 | # the dataloader may pad some samples 186 | ordered_results = ordered_results[:size] 187 | return ordered_results 188 | -------------------------------------------------------------------------------- /src/apis/train.py: 
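A note on the trainer below: mixed precision goes through NVIDIA Apex rather than mmcv's native FP16 hook. When the config defines an `fp16` section, `train_detector` forwards its keys to `apex.amp.initialize` (opt level `O1`) and wraps the optimizer step in `ApexFP16OptimizerHook`. A minimal config sketch; the `loss_scale` value is an assumption (any keyword accepted by `apex.amp.initialize` works):

```python
# In a training config: enables O1 mixed precision in train_detector.
fp16 = dict(loss_scale="dynamic")
optimizer_config = dict(grad_clip=None)
```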
-------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | import warnings 4 | 5 | import apex 6 | import numpy as np 7 | import torch 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import ( 10 | HOOKS, 11 | EpochBasedRunner, 12 | Fp16OptimizerHook, 13 | OptimizerHook, 14 | build_optimizer, 15 | build_runner, 16 | ) 17 | from mmcv.utils import build_from_cfg 18 | from mmdet.core import DistEvalHook, EvalHook 19 | from mmdet.datasets import build_dataset, replace_ImageToTensor 20 | from mmdet.utils import get_root_logger 21 | from src.datasets import build_dataloader 22 | from src.utils import ApexFP16OptimizerHook, DistDaliSamplerSeedHook, load_checkpoint 23 | 24 | 25 | def set_random_seed(seed, deterministic=False): 26 | """Set random seed. 27 | 28 | Args: 29 | seed (int): Seed to be used. 30 | deterministic (bool): Whether to set the deterministic option for 31 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 32 | to True and `torch.backends.cudnn.benchmark` to False. 33 | Default: False. 34 | """ 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | if deterministic: 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | 44 | def train_detector( 45 | model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None 46 | ): 47 | logger = get_root_logger(cfg.log_level) 48 | 49 | # prepare data loaders 50 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 51 | if "imgs_per_gpu" in cfg.data: 52 | logger.warning( 53 | '"imgs_per_gpu" is deprecated in MMDet V2.0. ' 54 | 'Please use "samples_per_gpu" instead' 55 | ) 56 | if "samples_per_gpu" in cfg.data: 57 | logger.warning( 58 | f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' 59 | f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' 60 | f"={cfg.data.imgs_per_gpu} is used in this experiments" 61 | ) 62 | else: 63 | logger.warning( 64 | 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' 65 | f"{cfg.data.imgs_per_gpu} in this experiments" 66 | ) 67 | cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu 68 | 69 | data_loaders = [ 70 | build_dataloader( 71 | ds, 72 | cfg.data.samples_per_gpu, 73 | cfg.data.workers_per_gpu, 74 | # cfg.gpus will be ignored if distributed 75 | len(cfg.gpu_ids), 76 | dist=distributed, 77 | seed=cfg.seed, 78 | loader_cfg=cfg.data.get("loader", {}).get("train", None), 79 | sampler_cfg=cfg.data.get("sampler", {}).get("train", None), 80 | ) 81 | for ds in dataset 82 | ] 83 | 84 | # build runner 85 | 86 | # fp16 setting 87 | fp16_cfg = cfg.get("fp16", None) 88 | if fp16_cfg is not None: 89 | optimizer_config = ApexFP16OptimizerHook( 90 | **cfg.optimizer_config, distributed=distributed 91 | ) 92 | optimizer = build_optimizer(model, cfg.optimizer) 93 | model, optimizer = apex.amp.initialize( 94 | model.cuda(), optimizer, opt_level="O1", **fp16_cfg, 95 | ) 96 | # put model on gpus 97 | if distributed: 98 | find_unused_parameters = cfg.get("find_unused_parameters", False) 99 | # Sets the `find_unused_parameters` parameter in 100 | # torch.nn.parallel.DistributedDataParallel 101 | model = MMDistributedDataParallel( 102 | model.cuda(), 103 | device_ids=[torch.cuda.current_device()], 104 | broadcast_buffers=False, 105 | find_unused_parameters=find_unused_parameters, 106 | ) 107 | else: 108 | model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), 
device_ids=cfg.gpu_ids) 109 | else: 110 | # put model on gpus 111 | if distributed: 112 | find_unused_parameters = cfg.get("find_unused_parameters", False) 113 | # Sets the `find_unused_parameters` parameter in 114 | # torch.nn.parallel.DistributedDataParallel 115 | model = MMDistributedDataParallel( 116 | model.cuda(), 117 | device_ids=[torch.cuda.current_device()], 118 | broadcast_buffers=False, 119 | find_unused_parameters=find_unused_parameters, 120 | ) 121 | else: 122 | model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 123 | optimizer = build_optimizer(model, cfg.optimizer) 124 | if distributed and "type" not in cfg.optimizer_config: 125 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 126 | else: 127 | optimizer_config = cfg.optimizer_config 128 | 129 | if "runner" not in cfg: 130 | cfg.runner = {"type": "EpochBasedRunner", "max_epochs": cfg.total_epochs} 131 | warnings.warn( 132 | "config is now expected to have a `runner` section, " 133 | "please set `runner` in your config.", 134 | UserWarning, 135 | ) 136 | else: 137 | if "total_epochs" in cfg: 138 | assert cfg.total_epochs == cfg.runner.max_epochs 139 | 140 | runner = build_runner( 141 | cfg.runner, 142 | default_args=dict( 143 | model=model, 144 | optimizer=optimizer, 145 | work_dir=cfg.work_dir, 146 | logger=logger, 147 | meta=meta, 148 | ), 149 | ) 150 | 151 | # an ugly workaround to make .log and .log.json filenames the same 152 | runner.timestamp = timestamp 153 | # register hooks 154 | runner.register_training_hooks( 155 | cfg.lr_config, 156 | optimizer_config, 157 | cfg.checkpoint_config, 158 | cfg.log_config, 159 | cfg.get("momentum_config", None), 160 | ) 161 | if distributed: 162 | if isinstance(runner, EpochBasedRunner): 163 | runner.register_hook(DistDaliSamplerSeedHook()) 164 | 165 | # register eval hooks 166 | if validate: 167 | # Support batch_size > 1 in validation 168 | val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1) 169 | if val_samples_per_gpu > 1: 170 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 171 | cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline) 172 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 173 | val_dataloader = build_dataloader( 174 | val_dataset, 175 | samples_per_gpu=val_samples_per_gpu, 176 | workers_per_gpu=cfg.data.workers_per_gpu, 177 | dist=distributed, 178 | shuffle=False, 179 | ) 180 | eval_cfg = cfg.get("evaluation", {}) 181 | eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner" 182 | eval_hook = DistEvalHook if distributed else EvalHook 183 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 184 | 185 | # user-defined hooks 186 | if cfg.get("custom_hooks", None): 187 | custom_hooks = cfg.custom_hooks 188 | assert isinstance( 189 | custom_hooks, list 190 | ), f"custom_hooks expect list type, but got {type(custom_hooks)}" 191 | for hook_cfg in cfg.custom_hooks: 192 | assert isinstance(hook_cfg, dict), ( 193 | "Each item in custom_hooks expects dict type, but got " 194 | f"{type(hook_cfg)}" 195 | ) 196 | hook_cfg = hook_cfg.copy() 197 | priority = hook_cfg.pop("priority", "NORMAL") 198 | hook = build_from_cfg(hook_cfg, HOOKS) 199 | runner.register_hook(hook, priority=priority) 200 | 201 | if cfg.resume_from: 202 | runner.resume(cfg.resume_from) 203 | elif cfg.load_from: 204 | if isinstance(cfg.load_from, dict): 205 | if cfg.load_from.type == "partial_load": 206 | load_checkpoint( 207 | runner.model, cfg.load_from.path, prefix=cfg.load_from.prefix 208 | ) 209 | else: 210 | raise 
NotImplementedError() 211 | else: 212 | runner.load_checkpoint(cfg.load_from) 213 | runner.run(data_loaders, cfg.workflow) 214 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .geometric_transform import Transform2D, filter_invalid, recover_mask 2 | -------------------------------------------------------------------------------- /src/core/geometric_transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections.abc import Sequence 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from mmdet.core.mask.structures import BitmapMasks 8 | from torch.nn import functional as F 9 | 10 | 11 | def bbox2points(box): 12 | min_x, min_y, max_x, max_y = torch.split(box[:, :4], [1, 1, 1, 1], dim=1) 13 | cx = min_x * 0.5 + max_x * 0.5 14 | cy = min_y * 0.5 + max_y * 0.5 15 | return torch.cat([cx, min_y, max_x, cy, cx, max_y, min_x, cy], dim=1).reshape( 16 | -1, 2 17 | ) # n*4,2 18 | 19 | 20 | def points2bbox(point, max_w, max_h): 21 | point = point.reshape(-1, 4, 2) 22 | if point.size()[0] > 0: 23 | min_xy = point.min(dim=1)[0] 24 | max_xy = point.max(dim=1)[0] 25 | xmin = min_xy[:, 0].clamp(min=0, max=max_w) 26 | ymin = min_xy[:, 1].clamp(min=0, max=max_h) 27 | xmax = max_xy[:, 0].clamp(min=0, max=max_w) 28 | ymax = max_xy[:, 1].clamp(min=0, max=max_h) 29 | min_xy = torch.stack([xmin, ymin], dim=1) 30 | max_xy = torch.stack([xmax, ymax], dim=1) 31 | return torch.cat([min_xy, max_xy], dim=1) # n,4 32 | else: 33 | return point.new_zeros(0, 4) 34 | 35 | 36 | def check_is_tensor(obj): 37 | """Checks whether the supplied object is a tensor.""" 38 | if not isinstance(obj, torch.Tensor): 39 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(obj))) 40 | 41 | 42 | def normal_transform_pixel( 43 | height: int, 44 | width: int, 45 | eps: float = 1e-14, 46 | device: Optional[torch.device] = None, 47 | dtype: Optional[torch.dtype] = None, 48 | ) -> torch.Tensor: 49 | tr_mat = torch.tensor( 50 | [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]], 51 | device=device, 52 | dtype=dtype, 53 | ) # 3x3 54 | 55 | # prevent divide by zero bugs 56 | width_denom: float = eps if width == 1 else width - 1.0 57 | height_denom: float = eps if height == 1 else height - 1.0 58 | 59 | tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / width_denom 60 | tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / height_denom 61 | 62 | return tr_mat.unsqueeze(0) # 1x3x3 63 | 64 | 65 | def normalize_homography( 66 | dst_pix_trans_src_pix: torch.Tensor, 67 | dsize_src: Tuple[int, int], 68 | dsize_dst: Tuple[int, int], 69 | ) -> torch.Tensor: 70 | check_is_tensor(dst_pix_trans_src_pix) 71 | 72 | if not ( 73 | len(dst_pix_trans_src_pix.shape) == 3 74 | or dst_pix_trans_src_pix.shape[-2:] == (3, 3) 75 | ): 76 | raise ValueError( 77 | "Input dst_pix_trans_src_pix must be a Bx3x3 tensor. 
Got {}".format( 78 | dst_pix_trans_src_pix.shape 79 | ) 80 | ) 81 | 82 | # source and destination sizes 83 | src_h, src_w = dsize_src 84 | dst_h, dst_w = dsize_dst 85 | 86 | # compute the transformation pixel/norm for src/dst 87 | src_norm_trans_src_pix: torch.Tensor = normal_transform_pixel(src_h, src_w).to( 88 | dst_pix_trans_src_pix 89 | ) 90 | src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix.float()).to( 91 | src_norm_trans_src_pix.dtype 92 | ) 93 | dst_norm_trans_dst_pix: torch.Tensor = normal_transform_pixel(dst_h, dst_w).to( 94 | dst_pix_trans_src_pix 95 | ) 96 | 97 | # compute chain transformations 98 | dst_norm_trans_src_norm: torch.Tensor = dst_norm_trans_dst_pix @ ( 99 | dst_pix_trans_src_pix @ src_pix_trans_src_norm 100 | ) 101 | return dst_norm_trans_src_norm 102 | 103 | 104 | def warp_affine( 105 | src: torch.Tensor, 106 | M: torch.Tensor, 107 | dsize: Tuple[int, int], 108 | mode: str = "bilinear", 109 | padding_mode: str = "zeros", 110 | align_corners: Optional[bool] = None, 111 | ) -> torch.Tensor: 112 | if not isinstance(src, torch.Tensor): 113 | raise TypeError( 114 | "Input src type is not a torch.Tensor. Got {}".format(type(src)) 115 | ) 116 | 117 | if not isinstance(M, torch.Tensor): 118 | raise TypeError("Input M type is not a torch.Tensor. Got {}".format(type(M))) 119 | 120 | if not len(src.shape) == 4: 121 | raise ValueError("Input src must be a BxCxHxW tensor. Got {}".format(src.shape)) 122 | 123 | if not (len(M.shape) == 3 or M.shape[-2:] == (2, 3)): 124 | raise ValueError("Input M must be a Bx2x3 tensor. Got {}".format(M.shape)) 125 | 126 | # TODO: remove the statement below in kornia v0.6 127 | if align_corners is None: 128 | message: str = ( 129 | "The align_corners default value has been changed. By default now is set True " 130 | "in order to match cv2.warpAffine." 
131 | ) 132 | warnings.warn(message) 133 | # set default value for align corners 134 | align_corners = True 135 | 136 | B, C, H, W = src.size() 137 | 138 | # we generate a 3x3 transformation matrix from 2x3 affine 139 | 140 | dst_norm_trans_src_norm: torch.Tensor = normalize_homography(M, (H, W), dsize) 141 | 142 | src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm.float()).to( 143 | dst_norm_trans_src_norm.dtype 144 | ) 145 | 146 | grid = F.affine_grid( 147 | src_norm_trans_dst_norm[:, :2, :], 148 | [B, C, dsize[0], dsize[1]], 149 | align_corners=align_corners, 150 | ) 151 | 152 | return F.grid_sample( 153 | src, grid, align_corners=align_corners, mode=mode, padding_mode=padding_mode 154 | ) 155 | 156 | 157 | class Transform2D: 158 | @staticmethod 159 | def transform_bboxes(bbox, M, out_shape): 160 | if isinstance(bbox, Sequence): 161 | assert len(bbox) == len(M) 162 | return [ 163 | Transform2D.transform_bboxes(b, m, o) 164 | for b, m, o in zip(bbox, M, out_shape) 165 | ] 166 | else: 167 | if bbox.shape[0] == 0: 168 | return bbox 169 | score = None 170 | if bbox.shape[1] > 4: 171 | score = bbox[:, 4:] 172 | points = bbox2points(bbox[:, :4]) 173 | points = torch.cat( 174 | [points, points.new_ones(points.shape[0], 1)], dim=1 175 | ) # n,3 176 | points = torch.matmul(M, points.t()).t() 177 | points = points[:, :2] / points[:, 2:3] 178 | bbox = points2bbox(points, out_shape[1], out_shape[0]) 179 | if score is not None: 180 | return torch.cat([bbox, score], dim=1) 181 | return bbox 182 | 183 | @staticmethod 184 | def transform_masks( 185 | mask: Union[BitmapMasks, List[BitmapMasks]], 186 | M: Union[torch.Tensor, List[torch.Tensor]], 187 | out_shape: Union[list, List[list]], 188 | ): 189 | if isinstance(mask, Sequence): 190 | assert len(mask) == len(M) 191 | return [ 192 | Transform2D.transform_masks(b, m, o) 193 | for b, m, o in zip(mask, M, out_shape) 194 | ] 195 | else: 196 | if mask.masks.shape[0] == 0: 197 | return BitmapMasks(np.zeros((0, *out_shape)), *out_shape) 198 | mask_tensor = ( 199 | torch.from_numpy(mask.masks[:, None, ...]).to(M.device).to(M.dtype) 200 | ) 201 | return BitmapMasks( 202 | warp_affine( 203 | mask_tensor, 204 | M[None, ...].expand(mask.masks.shape[0], -1, -1), 205 | out_shape, 206 | ) 207 | .squeeze(1) 208 | .cpu() 209 | .numpy(), 210 | out_shape[0], 211 | out_shape[1], 212 | ) 213 | 214 | @staticmethod 215 | def transform_image(img, M, out_shape, mode="nearest"): 216 | if isinstance(img, Sequence): 217 | assert len(img) == len(M) 218 | return [ 219 | Transform2D.transform_image(b, m, shape, mode) 220 | for b, m, shape in zip(img, M, out_shape) 221 | ] 222 | else: 223 | if img.dim() == 2: 224 | img = img[None, None, ...] 225 | elif img.dim() == 3: 226 | img = img[None, ...] 
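            # img is now 4-D (1, C, H, W); warp it with the single affine
            # matrix M (given a batch dim) and squeeze the added dims out.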
227 | 228 | return warp_affine( 229 | img.float(), M[None, ...], out_shape, mode=mode 230 | ).squeeze() 231 | 232 | 233 | def filter_invalid(bbox, label, score=None, mask=None, thr=0.0, min_size=0): 234 | if (score is not None) and (thr > 0): 235 | valid = score > thr 236 | bbox = bbox[valid] 237 | label = label[valid] 238 | if mask is not None: 239 | mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) 240 | if min_size is not None: 241 | bw = bbox[:, 2] - bbox[:, 0] 242 | bh = bbox[:, 3] - bbox[:, 1] 243 | valid = (bw > min_size) & (bh > min_size) 244 | bbox = bbox[valid] 245 | label = label[valid] 246 | if mask is not None: 247 | mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) 248 | return bbox, label, mask 249 | 250 | 251 | def recover_mask(bitmask, meta, scale=1.0): 252 | pad_mask = F.interpolate(bitmask[None, None, ...].float(), scale_factor=scale) 253 | ih, iw = meta["img_shape"][:2] 254 | mask = pad_mask[:, :, :ih, :iw] 255 | return F.interpolate(mask, meta["ori_shape"][:2])[0, 0].long(), None 256 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .backends import * 2 | from .builder import build_dataloader, build_dataset 3 | from .dataset_wrappers import MultiSourceDataset 4 | from .pipeline import * 5 | from .samplers import * 6 | -------------------------------------------------------------------------------- /src/datasets/backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .zip_backends import ZipBackend 2 | from .mem_backends import MemCachedV2Backend 3 | -------------------------------------------------------------------------------- /src/datasets/backends/_utils.py: -------------------------------------------------------------------------------- 1 | from mmcv.parallel import DataContainer 2 | from torch.utils.data._utils.pin_memory import pin_memory 3 | 4 | 5 | class TensorlikeDataContainer(DataContainer): 6 | def pin_memory(self): 7 | self._data = pin_memory(self._data) 8 | -------------------------------------------------------------------------------- /src/datasets/backends/mem_backends.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mmcv import FileClient, BaseStorageBackend 3 | 4 | try: 5 | import memcache 6 | except ImportError: 7 | memcache = None 8 | 9 | 10 | class MemCachedV2Backend(BaseStorageBackend): 11 | """ 12 | Only single image directory is supported 13 | """ 14 | 15 | def __init__(self, server): 16 | self.client = memcache.Client([server], debug=True) 17 | 18 | def get(self, filepath): 19 | value_buf = self.client.get(filepath) 20 | if value_buf is None: 21 | raise ValueError(f"{filepath} does not exist in memory") 22 | return value_buf 23 | 24 | def get_text(self, filepath): 25 | raise NotImplementedError 26 | 27 | 28 | FileClient.register_backend("memcached_v2", MemCachedV2Backend) 29 | -------------------------------------------------------------------------------- /src/datasets/backends/zip_backends.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mmcv import FileClient, BaseStorageBackend 3 | from zipfile import ZipFile 4 | 5 | 6 | class ZipBackend(BaseStorageBackend): 7 | """ 8 | Only single image directory is supported 9 | """ 10 | 11 | def __init__(self, zip_file_name=None): 12 | if zip_file_name is 
not None: 13 | self.zip_file = ZipFile(zip_file_name, mode="r") 14 | self.root_prefix = self.zip_file.namelist()[0] 15 | else: 16 | self.zip_file = None 17 | self.root_prefix = None 18 | print("Use Zip Backends") 19 | 20 | def get(self, filepath): 21 | file_name = None 22 | zip_name = None 23 | if ".zip" in filepath: 24 | zip_name, file_name = filepath.split(".zip/") 25 | zip_name = zip_name + ".zip" 26 | if self.zip_file is None: 27 | if zip_name is None: 28 | zip_name = os.path.dirname(filepath) + ".zip" 29 | if not os.path.exists(zip_name): 30 | raise FileNotFoundError(f"There is no zip file in {zip_name}") 31 | else: 32 | print(f"Load Zip File{zip_name}") 33 | self.zip_file = ZipFile(zip_name, mode="r") 34 | # print(self.zip_file.namelist()[:10]) 35 | self.root_prefix = self.zip_file.namelist()[0] 36 | else: 37 | assert isinstance(self.zip_file, ZipFile), "Error: no such zip file." 38 | if file_name is None: 39 | file_name = self.root_prefix + os.path.basename(filepath) 40 | value_buf = self.zip_file.read(file_name) 41 | return value_buf 42 | 43 | def get_text(self, filepath): 44 | raise NotImplementedError 45 | 46 | 47 | FileClient.register_backend("zip", ZipBackend) 48 | -------------------------------------------------------------------------------- /src/datasets/builder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | from functools import partial 3 | 4 | import torch 5 | from mmcv.parallel import DataContainer 6 | from mmcv.runner import get_dist_info 7 | from mmcv.utils import Registry, build_from_cfg 8 | from mmdet.datasets import build_dataset as _base_dataset 9 | from mmdet.datasets.builder import worker_init_fn 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | try: 15 | from .dali_loader import DaliDataLoader 16 | except: 17 | DaliDataLoader = None 18 | SAMPLERS = Registry("sampler") 19 | 20 | 21 | def build_sampler(cfg, dist=False, group=False, default_args=None): 22 | if cfg and ("type" in cfg): 23 | sampler_type = cfg.get("type") 24 | else: 25 | sampler_type = default_args.get("type") 26 | if group: 27 | sampler_type = "Group" + sampler_type 28 | if dist: 29 | sampler_type = "Distributed" + sampler_type 30 | 31 | if cfg: 32 | cfg.update(type=sampler_type) 33 | else: 34 | cfg = dict(type=sampler_type) 35 | 36 | return build_from_cfg(cfg, SAMPLERS, default_args) 37 | 38 | 39 | def build_dataset(cfg, default_args=None): 40 | extension = cfg.pop("dali_extension", {}) 41 | dataset = _base_dataset(cfg, default_args=default_args) 42 | for name, ext in extension.items(): 43 | setattr(dataset, name, ext) 44 | return dataset 45 | 46 | 47 | def build_dataloader( 48 | dataset, 49 | samples_per_gpu, 50 | workers_per_gpu, 51 | num_gpus=1, 52 | dist=True, 53 | shuffle=True, 54 | seed=None, 55 | loader_cfg=None, 56 | sampler_cfg=None, 57 | **kwargs, 58 | ): 59 | rank, world_size = get_dist_info() 60 | default_sampler_cfg = dict(type="Sampler", dataset=dataset,) 61 | if shuffle: 62 | default_sampler_cfg.update(samples_per_gpu=samples_per_gpu) 63 | else: 64 | default_sampler_cfg.update(shuffle=False) 65 | if dist: 66 | default_sampler_cfg.update(num_replicas=world_size, rank=rank) 67 | sampler = build_sampler(sampler_cfg, dist, shuffle, default_sampler_cfg) 68 | 69 | batch_size = samples_per_gpu 70 | num_workers = workers_per_gpu 71 | else: 72 | sampler = build_sampler(sampler_cfg, 
dist=False, group=True, default_args=default_sampler_cfg) if shuffle else None 73 | batch_size = num_gpus * samples_per_gpu 74 | num_workers = num_gpus * workers_per_gpu 75 | 76 | if loader_cfg is None: 77 | loader_cfg = dict() 78 | loader_mode = loader_cfg.get("mode", "normal") 79 | if loader_mode == "normal": 80 | init_fn = ( 81 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) 82 | if seed is not None 83 | else None 84 | ) 85 | 86 | data_loader = DataLoader( 87 | dataset, 88 | batch_size=batch_size, 89 | sampler=sampler, 90 | num_workers=num_workers, 91 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu, flatten=True), 92 | pin_memory=False, 93 | worker_init_fn=init_fn, 94 | **kwargs, 95 | ) 96 | elif loader_mode == "dali": 97 | 98 | data_loader = DaliDataLoader( 99 | dataset, 100 | batch_size=batch_size, 101 | sampler=sampler, 102 | num_workers=num_workers, 103 | pipeline_cfg=dataset.pipeline_cfg, 104 | **kwargs, 105 | ) 106 | return data_loader 107 | 108 | 109 | def collate(batch, samples_per_gpu=1, flatten=False): 110 | """Puts each data field into a tensor/DataContainer with outer dimension 111 | batch size. 112 | 113 | Extend default_collate to add support for 114 | :type:`~mmcv.parallel.DataContainer`. There are 3 cases. 115 | 116 | 1. cpu_only = True, e.g., meta data 117 | 2. cpu_only = False, stack = True, e.g., images tensors 118 | 3. cpu_only = False, stack = False, e.g., gt bboxes 119 | """ 120 | if not isinstance(batch, Sequence): 121 | raise TypeError(f"{batch.dtype} is not supported.") 122 | 123 | if isinstance(batch[0], DataContainer): 124 | stacked = [] 125 | if batch[0].cpu_only: 126 | for i in range(0, len(batch), samples_per_gpu): 127 | stacked.append( 128 | [sample.data for sample in batch[i : i + samples_per_gpu]] 129 | ) 130 | return DataContainer( 131 | stacked, batch[0].stack, batch[0].padding_value, cpu_only=True 132 | ) 133 | elif batch[0].stack: 134 | for i in range(0, len(batch), samples_per_gpu): 135 | assert isinstance(batch[i].data, torch.Tensor) 136 | 137 | if batch[i].pad_dims is not None: 138 | ndim = batch[i].dim() 139 | assert ndim > batch[i].pad_dims 140 | max_shape = [0 for _ in range(batch[i].pad_dims)] 141 | for dim in range(1, batch[i].pad_dims + 1): 142 | max_shape[dim - 1] = batch[i].size(-dim) 143 | for sample in batch[i : i + samples_per_gpu]: 144 | for dim in range(0, ndim - batch[i].pad_dims): 145 | assert batch[i].size(dim) == sample.size(dim) 146 | for dim in range(1, batch[i].pad_dims + 1): 147 | max_shape[dim - 1] = max( 148 | max_shape[dim - 1], sample.size(-dim) 149 | ) 150 | padded_samples = [] 151 | for sample in batch[i : i + samples_per_gpu]: 152 | pad = [0 for _ in range(batch[i].pad_dims * 2)] 153 | for dim in range(1, batch[i].pad_dims + 1): 154 | pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim) 155 | padded_samples.append( 156 | F.pad(sample.data, pad, value=sample.padding_value) 157 | ) 158 | stacked.append(default_collate(padded_samples)) 159 | elif batch[i].pad_dims is None: 160 | stacked.append( 161 | default_collate( 162 | [sample.data for sample in batch[i : i + samples_per_gpu]] 163 | ) 164 | ) 165 | else: 166 | raise ValueError("pad_dims should be either None or integers (1-3)") 167 | 168 | else: 169 | for i in range(0, len(batch), samples_per_gpu): 170 | stacked.append( 171 | [sample.data for sample in batch[i : i + samples_per_gpu]] 172 | ) 173 | return DataContainer(stacked, batch[0].stack, batch[0].padding_value) 174 | elif any([isinstance(b, Sequence) for b in batch]): 175 | if flatten: 176 | flattened = [] 
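            # A dataset item may itself be a list of samples (e.g. the
            # multi-source training pipelines in this repo); merge one
            # nesting level so the whole batch is re-collated together.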
177 | for b in batch: 178 | if isinstance(b, Sequence): 179 | flattened.extend(b) 180 | else: 181 | flattened.extend([b]) 182 | return collate(flattened, len(flattened)) 183 | else: 184 | transposed = zip(*batch) 185 | return [collate(samples, samples_per_gpu) for samples in transposed] 186 | elif isinstance(batch[0], Mapping): 187 | return { 188 | key: collate([d[key] for d in batch], samples_per_gpu) for key in batch[0] 189 | } 190 | else: 191 | return default_collate(batch) 192 | -------------------------------------------------------------------------------- /src/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from mmdet.datasets import DATASETS, build_dataset 2 | from mmdet.datasets.dataset_wrappers import ConcatDataset 3 | 4 | 5 | @DATASETS.register_module() 6 | class MultiSourceDataset(ConcatDataset): 7 | def __init__(self, datasets, sample_ratio): 8 | if not isinstance(datasets, list): 9 | datasets = [datasets] 10 | if isinstance(datasets[0], dict): 11 | datasets = [build_dataset(d) for d in datasets] 12 | super().__init__(datasets) 13 | self.sample_ratio = sample_ratio 14 | 15 | def expand_index(self, index): 16 | pass 17 | -------------------------------------------------------------------------------- /src/datasets/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .loading import LoadByteFromFile 2 | from .formating import PlainCollect 3 | from .transform import * 4 | from .immediate_transform import * 5 | -------------------------------------------------------------------------------- /src/datasets/pipeline/formating.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmdet.datasets import PIPELINES 3 | from mmdet.datasets.pipelines.formating import Collect 4 | import numpy as np 5 | 6 | 7 | @PIPELINES.register_module() 8 | class ExtraAttrs(object): 9 | def __init__(self, **attrs): 10 | self.attrs = attrs 11 | 12 | def __call__(self, results): 13 | for k, v in self.attrs.items(): 14 | assert k not in results 15 | results[k] = v 16 | return results 17 | 18 | 19 | @PIPELINES.register_module() 20 | class PlainCollect(object): 21 | """Collect data from the loader relevant to the specific task. 22 | 23 | This is usually the last stage of the data loader pipeline. Typically keys 24 | is set to some subset of "img", "proposals", "gt_bboxes", 25 | "gt_bboxes_ignore", "gt_labels", and/or "gt_masks". 26 | Args: 27 | keys (Sequence[str]): Keys of results to be collected in ``data``. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | keys=[ 33 | "img", 34 | "gt_bboxes", 35 | "gt_labels", 36 | "filename", 37 | "trans_matrix", 38 | "ori_filename", 39 | "flip", 40 | "flip_direction", 41 | "img_shape", 42 | "scale_factor", 43 | ], 44 | extra_keys=[], 45 | ): 46 | self.keys = keys + extra_keys 47 | 48 | def __call__(self, results): 49 | """Call function to collect keys in results. The keys in ``meta_keys`` 50 | will be converted to :obj:mmcv.DataContainer. 51 | 52 | Args: 53 | results (dict): Result dict contains the data to collect. 
54 | 
55 |         Returns:
56 |             dict: The result dict contains the following keys
57 | 
58 |                 - keys in ``self.keys``
59 |                 - ``img_metas``
60 |         """
61 |         data = {}
62 |         for key in self.keys:
63 |             data[key] = results[key]
64 |         return data
65 | 
66 |     def __repr__(self):
67 |         return self.__class__.__name__ + f"(keys={self.keys})"
68 | 
69 | 
70 | @PIPELINES.register_module()
71 | class CollectV1(Collect):
72 |     def __init__(self, *args, extra_meta_keys=[], **kwargs):
73 |         super().__init__(*args, **kwargs)
74 |         self.meta_keys = self.meta_keys + tuple(extra_meta_keys)
75 | 
76 | 
77 | @PIPELINES.register_module()
78 | class PseudoSamples(object):
79 |     def __init__(self, with_bbox=False, with_mask=False, with_seg=False, override=True):
80 |         self.with_bbox = with_bbox
81 |         self.with_mask = with_mask
82 |         self.with_seg = with_seg
83 |         self.override = override
84 | 
85 |     def __call__(self, results):
86 |         if self.with_bbox:
87 |             if self.override and "gt_bboxes" in results:
88 |                 results.pop("gt_bboxes")
89 |             if "gt_bboxes" not in results:
90 |                 results["gt_bboxes"] = np.zeros((0, 4))
91 |                 results["gt_labels"] = np.zeros((0,))
92 |             if "bbox_fields" not in results:
93 |                 results["bbox_fields"] = []
94 |             if "gt_bboxes" not in results["bbox_fields"]:
95 |                 results["bbox_fields"].append("gt_bboxes")
96 |         if self.with_mask:
97 |             if self.override and "gt_masks" in results:
98 |                 results.pop("gt_masks")
99 |             if "gt_masks" not in results:
100 |                 # TODO: keep consistent with the original pipeline and use BitmapMasks
101 |                 results["gt_masks"] = np.zeros((0, 1, 1))
102 | 
103 |             if "mask_fields" not in results:
104 |                 results["mask_fields"] = []
105 |             if "gt_masks" not in results["mask_fields"]:
106 |                 results["mask_fields"].append("gt_masks")
107 |         if self.with_seg:
108 |             if self.override and "gt_semantic_seg" in results:
109 |                 results.pop("gt_semantic_seg")
110 |             if "gt_semantic_seg" not in results:
111 |                 results["gt_semantic_seg"] = 255 * np.ones(
112 |                     results["img"].shape[:2], dtype=np.uint8
113 |                 )
114 |             if "seg_fields" not in results:
115 |                 results["seg_fields"] = []
116 |             if "gt_semantic_seg" not in results["seg_fields"]:
117 |                 results["seg_fields"].append("gt_semantic_seg")
118 |         return results
119 | 
-------------------------------------------------------------------------------- /src/datasets/pipeline/loading.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import mmcv
3 | import numpy as np
4 | from mmdet.datasets import PIPELINES
5 | 
6 | 
7 | @PIPELINES.register_module()
8 | class LoadByteFromFile(object):
9 |     """Load raw image bytes from a file without decoding them.
10 | 
11 |     Required keys are "img_prefix" and "img_info" (a dict that must contain the
12 |     key "filename"). Added or updated keys are "filename", "ori_filename" and
13 |     "img" (a flat uint8 byte buffer); when "img_info" carries "height"/"width",
14 |     "ori_shape", "img_shape" and "pad_shape" are set from them as well.
15 | 
16 |     Unlike ``LoadImageFromFile``, this transform does not decode the bytes into
17 |     an image array, so the usual ``to_float32`` and ``color_type`` options do
18 |     not apply here; decoding happens later in the pipeline.
19 | 
20 |     Args:
21 |         file_client_args (dict): Arguments to instantiate a FileClient.
22 |             See :class:`mmcv.fileio.FileClient` for details.
23 |             Defaults to ``dict(backend='disk')``.
24 | 
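    Example (an illustrative sketch added for this write-up, not part of the
    original source; the file name and image size below are hypothetical):

        >>> loader = LoadByteFromFile(file_client_args=dict(backend="disk"))
        >>> results = dict(
        ...     img_prefix="data/coco/val2017",
        ...     img_info=dict(filename="000000000139.jpg", height=426, width=640),
        ... )
        >>> results = loader(results)
        >>> results["img"].dtype  # still-encoded bytes exposed as a uint8 buffer
        dtype('uint8')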
25 | """ 26 | 27 | def __init__(self, file_client_args=dict(backend="disk")): 28 | self.file_client_args = file_client_args.copy() 29 | self.file_client = None 30 | 31 | def __call__(self, results): 32 | """Call functions to load image and get image meta information. 33 | 34 | Args: 35 | results (dict): Result dict from :obj:`mmdet.CustomDataset`. 36 | 37 | Returns: 38 | dict: The dict contains loaded image and meta information. 39 | """ 40 | if self.file_client is None: 41 | self.file_client = mmcv.FileClient(**self.file_client_args) 42 | if results["img_prefix"] is not None: 43 | filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) 44 | else: 45 | filename = results["img_info"]["filename"] 46 | 47 | img_bytes = self.file_client.get(filename) 48 | results["img"] = np.frombuffer(img_bytes, dtype=np.uint8) 49 | if "height" in results["img_info"] or "width" in results["img_info"]: 50 | results["ori_shape"] = np.array( 51 | [results["img_info"]["height"], results["img_info"]["width"]], 52 | dtype=np.int32, 53 | ) 54 | results["img_shape"] = results["ori_shape"] 55 | results["pad_shape"] = results["img_shape"] 56 | results["ori_filename"] = results["img_info"]["filename"] 57 | results["filename"] = filename 58 | return results 59 | 60 | def __repr__(self): 61 | repr_str = ( 62 | f"{self.__class__.__name__}(" 63 | f"to_float32={self.to_float32}, " 64 | f"color_type='{self.color_type}', " 65 | f"file_client_args={self.file_client_args})" 66 | ) 67 | return repr_str 68 | -------------------------------------------------------------------------------- /src/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_sampler import DistributedSampler 2 | from .group_sampler import DistributedGroupSampler, GroupSampler 3 | from .balance_sampler import DistributedGroupSemiBalanceSampler 4 | 5 | __all__ = [ 6 | "DistributedSampler", 7 | "DistributedGroupSampler", 8 | "GroupSampler", 9 | "DistributedGroupSemiBalanceSampler", 10 | ] 11 | -------------------------------------------------------------------------------- /src/datasets/samplers/balance_sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import torch 4 | from mmcv.runner import get_dist_info 5 | from torch.utils.data import Sampler, WeightedRandomSampler 6 | from ..builder import SAMPLERS 7 | import pdb 8 | 9 | 10 | def repeat_choice(seq, size): 11 | repeat_factor = int(size // len(seq)) 12 | extra_num = size % len(seq) 13 | selected = [seq for _ in range(repeat_factor)] 14 | selected.append(seq[:extra_num]) 15 | selected = np.concatenate(selected) 16 | return selected 17 | 18 | 19 | @SAMPLERS.register_module() 20 | class DistributedGroupSemiBalanceSampler(Sampler): 21 | def __init__( 22 | self, 23 | dataset, 24 | by_prob=False, 25 | at_least_one=True, 26 | epoch_length=7330, 27 | samples_per_gpu=1, 28 | num_replicas=None, 29 | rank=None, 30 | sample_ratio=1, 31 | ): 32 | self.by_prob = by_prob 33 | self.at_least_one = at_least_one 34 | _rank, _num_replicas = get_dist_info() 35 | if num_replicas is None: 36 | num_replicas = _num_replicas 37 | if rank is None: 38 | rank = _rank 39 | self.dataset = dataset 40 | self.samples_per_gpu = samples_per_gpu 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | assert hasattr(self.dataset, "flag") 45 | self.flag = self.dataset.flag 46 | self.group_sizes = 
47 |         self.num_samples = 0
48 |         self.cumulative_sizes = dataset.cumulative_sizes
49 |         # decide how frequently each dataset should be sampled
50 |         if not isinstance(sample_ratio, list):
51 |             sample_ratio = [sample_ratio] * len(self.cumulative_sizes)
52 |         self.sample_ratio = sample_ratio
53 |         self.sample_ratio = [
54 |             int(sr / min(self.sample_ratio)) for sr in self.sample_ratio
55 |         ]
56 |         self.size_of_dataset = []
57 |         cumulative_sizes = [0] + self.cumulative_sizes
58 |         print(cumulative_sizes)
59 |         for i, _ in enumerate(self.group_sizes):
60 |             size_of_dataset = 0
61 |             cur_group_inds = np.where(self.flag == i)[0]
62 |             for j in range(len(self.cumulative_sizes)):
63 |                 cur_group_cur_dataset = np.where(
64 |                     np.logical_and(
65 |                         cur_group_inds >= cumulative_sizes[j],
66 |                         cur_group_inds < cumulative_sizes[j + 1],
67 |                     )
68 |                 )[0]
69 |                 size_per_dataset = len(cur_group_cur_dataset)
70 |                 size_of_dataset = max(
71 |                     size_of_dataset, np.ceil(size_per_dataset / self.sample_ratio[j])
72 |                 )
73 | 
74 |             self.size_of_dataset.append(
75 |                 int(np.ceil(size_of_dataset / self.samples_per_gpu / self.num_replicas))
76 |                 * self.samples_per_gpu
77 |             )
78 |             for j in range(len(self.cumulative_sizes)):
79 |                 self.num_samples += self.size_of_dataset[-1] * self.sample_ratio[j]
80 | 
81 |         self.total_size = self.num_samples * self.num_replicas
82 |         group_factor = [g / sum(self.group_sizes) for g in self.group_sizes]
83 |         self.epoch_length = [int(np.round(gf * epoch_length)) for gf in group_factor]
84 |         self.epoch_length[-1] = epoch_length - sum(self.epoch_length[:-1])
85 |         print(self.group_sizes, self.epoch_length)
86 | 
87 |     def __iter__(self):
88 |         # deterministically shuffle based on epoch
89 |         g = torch.Generator()
90 |         g.manual_seed(self.epoch)
91 |         indices = []
92 |         cumulative_sizes = [0] + self.cumulative_sizes
93 |         for i, size in enumerate(self.group_sizes):
94 |             if size > 0:
95 |                 indice = np.where(self.flag == i)[0]
96 |                 assert len(indice) == size
97 |                 indice_per_dataset = []
98 | 
99 |                 for j in range(len(self.cumulative_sizes)):
100 |                     indice_per_dataset.append(
101 |                         indice[
102 |                             np.where(
103 |                                 np.logical_and(
104 |                                     indice >= cumulative_sizes[j],
105 |                                     indice < cumulative_sizes[j + 1],
106 |                                 )
107 |                             )[0]
108 |                         ]
109 |                     )
110 | 
111 |                 shuffled_indice_per_dataset = [
112 |                     s[list(torch.randperm(int(s.shape[0]), generator=g).numpy())]
113 |                     for s in indice_per_dataset
114 |                 ]
115 |                 # split each batch across datasets according to sample_ratio
116 |                 total_indice = []
117 |                 batch_idx = 0
118 | 
119 |                 while batch_idx < self.epoch_length[i] * self.num_replicas:
120 |                     ratio = [x / sum(self.sample_ratio) for x in self.sample_ratio]
121 |                     if self.by_prob:
122 |                         indicator = list(
123 |                             WeightedRandomSampler(
124 |                                 ratio,
125 |                                 self.samples_per_gpu,
126 |                                 replacement=True,
127 |                                 generator=g,
128 |                             )
129 |                         )
130 |                         unique, counts = np.unique(indicator, return_counts=True)
131 |                         ratio = [0] * len(shuffled_indice_per_dataset)
132 |                         for u, c in zip(unique, counts):
133 |                             ratio[u] = c
134 |                         assert len(ratio) == 2, "Only two datasets are supported"
135 |                         if self.at_least_one:
136 |                             if ratio[0] == 0:
137 |                                 ratio[0] = 1
138 |                                 ratio[1] -= 1
139 |                             elif ratio[1] == 0:
140 |                                 ratio[1] = 1
141 |                                 ratio[0] -= 1
142 | 
143 |                         ratio = [r / sum(ratio) for r in ratio]
144 | 
145 |                     # number of samples to draw from each dataset
146 |                     ratio = [int(r * self.samples_per_gpu) for r in ratio]
147 | 
148 |                     ratio[-1] = self.samples_per_gpu - sum(ratio[:-1])
149 |                     selected = []
150 |                     # print(ratio)
151 |                     for j in range(len(shuffled_indice_per_dataset)):
152 |                         if len(shuffled_indice_per_dataset[j]) < ratio[j]:
153 |                             shuffled_indice_per_dataset[j] = np.concatenate(
154 |                                 (
155 |                                     shuffled_indice_per_dataset[j],
156 |                                     indice_per_dataset[j][
157 |                                         list(
158 |                                             torch.randperm(
159 |                                                 int(indice_per_dataset[j].shape[0]),
160 |                                                 generator=g,
161 |                                             ).numpy()
162 |                                         )
163 |                                     ],
164 |                                 )
165 |                             )
166 | 
167 |                         selected.append(shuffled_indice_per_dataset[j][: ratio[j]])
168 |                         shuffled_indice_per_dataset[j] = shuffled_indice_per_dataset[j][
169 |                             ratio[j] :
170 |                         ]
171 |                     selected = np.concatenate(selected)
172 |                     # real_names = []
173 |                     # for m,r in enumerate(ratio):
174 |                     #     real_names.extend([self.dataset.keys[m]]*r)
175 | 
176 |                     # for n in range(len(selected)):
177 |                     #     name = self.dataset.get_dataset_name(selected[n])
178 |                     #     assert real_names[n]==name,"{}:{}".format(selected,[self.dataset.get_dataset_name(s) for s in selected])
179 | 
180 |                     total_indice.append(selected)
181 |                     batch_idx += 1
182 |                 # print(self.size_of_dataset)
183 |                 indice = np.concatenate(total_indice)
184 |                 indices.append(indice)
185 |         indices = np.concatenate(indices)
186 |         indices = [
187 |             indices[j]
188 |             for i in list(
189 |                 torch.randperm(len(indices) // self.samples_per_gpu, generator=g)
190 |             )
191 |             for j in range(i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu)
192 |         ]
193 | 
194 |         offset = len(self) * self.rank
195 |         indices = indices[offset : offset + len(self)]
196 |         assert len(indices) == len(self)
197 |         return iter(indices)
198 | 
199 |     def __len__(self):
200 |         return sum(self.epoch_length) * self.samples_per_gpu
201 | 
202 |     def update_sample_ratio(self, iteration):
203 |         step = self.epoch * sum(self.epoch_length) + iteration
204 |         if getattr(self, "dynamic", None) is not None:
205 |             self.sample_ratio = [d(step) for d in self.dynamic]
206 | 
207 |     def set_epoch(self, epoch):
208 |         self.epoch = epoch
209 | 
-------------------------------------------------------------------------------- /src/datasets/samplers/distributed_sampler.py: --------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | from torch.utils.data import DistributedSampler as _DistributedSampler
5 | from ..builder import SAMPLERS
6 | 
7 | 
8 | @SAMPLERS.register_module()
9 | class DistributedSampler(_DistributedSampler):
10 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
11 |         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
12 |         self.shuffle = shuffle
13 | 
14 |     def __iter__(self):
15 |         # deterministically shuffle based on epoch
16 |         if self.shuffle:
17 |             g = torch.Generator()
18 |             g.manual_seed(self.epoch)
19 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
20 |         else:
21 |             indices = torch.arange(len(self.dataset)).tolist()
22 | 
23 |         # add extra samples to make it evenly divisible,
24 |         # even when indices is shorter than half of total_size
25 |         indices = (indices * math.ceil(self.total_size / len(indices)))[
26 |             : self.total_size
27 |         ]
28 |         assert len(indices) == self.total_size
29 | 
30 |         # subsample
31 |         indices = indices[self.rank : self.total_size : self.num_replicas]
32 |         assert len(indices) == self.num_samples
33 | 
34 |         return iter(indices)
35 | 
-------------------------------------------------------------------------------- /src/datasets/samplers/group_sampler.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | import math
3 | 
4 | import numpy as np
5 | import torch
6 | from mmcv.runner import get_dist_info
7 | from torch.utils.data import Sampler
8 | from ..builder 
import SAMPLERS 9 | 10 | 11 | @SAMPLERS.register_module() 12 | class GroupSampler(Sampler): 13 | def __init__(self, dataset, samples_per_gpu=1): 14 | assert hasattr(dataset, "flag") 15 | self.dataset = dataset 16 | self.samples_per_gpu = samples_per_gpu 17 | self.flag = dataset.flag.astype(np.int64) 18 | self.group_sizes = np.bincount(self.flag) 19 | self.num_samples = 0 20 | for i, size in enumerate(self.group_sizes): 21 | self.num_samples += ( 22 | int(np.ceil(size / self.samples_per_gpu)) * self.samples_per_gpu 23 | ) 24 | 25 | def __iter__(self): 26 | indices = [] 27 | for i, size in enumerate(self.group_sizes): 28 | if size == 0: 29 | continue 30 | indice = np.where(self.flag == i)[0] 31 | assert len(indice) == size 32 | np.random.shuffle(indice) 33 | num_extra = int( 34 | np.ceil(size / self.samples_per_gpu) 35 | ) * self.samples_per_gpu - len(indice) 36 | indice = np.concatenate([indice, np.random.choice(indice, num_extra)]) 37 | indices.append(indice) 38 | indices = np.concatenate(indices) 39 | indices = [ 40 | indices[i * self.samples_per_gpu : (i + 1) * self.samples_per_gpu] 41 | for i in np.random.permutation(range(len(indices) // self.samples_per_gpu)) 42 | ] 43 | indices = np.concatenate(indices) 44 | indices = indices.astype(np.int64).tolist() 45 | assert len(indices) == self.num_samples 46 | return iter(indices) 47 | 48 | def __len__(self): 49 | return self.num_samples 50 | 51 | 52 | @SAMPLERS.register_module() 53 | class DistributedGroupSampler(Sampler): 54 | """Sampler that restricts data loading to a subset of the dataset. 55 | 56 | It is especially useful in conjunction with 57 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 58 | process can pass a DistributedSampler instance as a DataLoader sampler, 59 | and load a subset of the original dataset that is exclusive to it. 60 | 61 | .. note:: 62 | Dataset is assumed to be of constant size. 63 | 64 | Arguments: 65 | dataset: Dataset used for sampling. 66 | num_replicas (optional): Number of processes participating in 67 | distributed training. 68 | rank (optional): Rank of the current process within num_replicas. 69 | """ 70 | 71 | def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None): 72 | _rank, _num_replicas = get_dist_info() 73 | if num_replicas is None: 74 | num_replicas = _num_replicas 75 | if rank is None: 76 | rank = _rank 77 | self.dataset = dataset 78 | self.samples_per_gpu = samples_per_gpu 79 | self.num_replicas = num_replicas 80 | self.rank = rank 81 | self.epoch = 0 82 | 83 | assert hasattr(self.dataset, "flag") 84 | self.flag = self.dataset.flag 85 | self.group_sizes = np.bincount(self.flag) 86 | 87 | self.num_samples = 0 88 | for i, j in enumerate(self.group_sizes): 89 | self.num_samples += ( 90 | int( 91 | math.ceil( 92 | self.group_sizes[i] 93 | * 1.0 94 | / self.samples_per_gpu 95 | / self.num_replicas 96 | ) 97 | ) 98 | * self.samples_per_gpu 99 | ) 100 | self.total_size = self.num_samples * self.num_replicas 101 | 102 | def __iter__(self): 103 | # deterministically shuffle based on epoch 104 | g = torch.Generator() 105 | g.manual_seed(self.epoch) 106 | 107 | indices = [] 108 | for i, size in enumerate(self.group_sizes): 109 | if size > 0: 110 | indice = np.where(self.flag == i)[0] 111 | assert len(indice) == size 112 | # add .numpy() to avoid bug when selecting indice in parrots. 113 | # TODO: check whether torch.randperm() can be replaced by 114 | # numpy.random.permutation(). 
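# Worked example (added commentary, not in the original source): for a group
# of size 10 with samples_per_gpu=4 and num_replicas=2, the padding below
# computes extra = ceil(10 / (4 * 2)) * 4 * 2 - 10 = 6, so six of the
# shuffled indices are repeated to round the group up to 16 samples, which
# divides evenly across both replicas in batches of 4.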
115 | indice = indice[ 116 | list(torch.randperm(int(size), generator=g).numpy()) 117 | ].tolist() 118 | extra = int( 119 | math.ceil(size * 1.0 / self.samples_per_gpu / self.num_replicas) 120 | ) * self.samples_per_gpu * self.num_replicas - len(indice) 121 | # pad indice 122 | tmp = indice.copy() 123 | for _ in range(extra // size): 124 | indice.extend(tmp) 125 | indice.extend(tmp[: extra % size]) 126 | indices.extend(indice) 127 | 128 | assert len(indices) == self.total_size 129 | 130 | indices = [ 131 | indices[j] 132 | for i in list( 133 | torch.randperm(len(indices) // self.samples_per_gpu, generator=g) 134 | ) 135 | for j in range(i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu) 136 | ] 137 | 138 | # subsample 139 | offset = self.num_samples * self.rank 140 | indices = indices[offset : offset + self.num_samples] 141 | assert len(indices) == self.num_samples 142 | 143 | return iter(indices) 144 | 145 | def __len__(self): 146 | return self.num_samples 147 | 148 | def set_epoch(self, epoch): 149 | self.epoch = epoch 150 | 151 | @property 152 | def sample_size(self): 153 | return self.dataset.sample_size 154 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | from .detectors import * 3 | -------------------------------------------------------------------------------- /src/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .swing_transformer import SwinTransformer 2 | 3 | __all__ = ["SwinTransformer"] 4 | -------------------------------------------------------------------------------- /src/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .semi_two_stage import SemiTwoStageDetector 2 | -------------------------------------------------------------------------------- /src/models/detectors/semi_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from mmdet.core import multi_apply 5 | from src.core import Transform2D, filter_invalid 6 | from src.utils.structure_utils import pad_stack 7 | 8 | 9 | class Teacher(nn.Module): 10 | def __init__(self, detector, requires_grad=False, eval=True): 11 | super().__init__() 12 | self._fields = [] 13 | self.detector = detector 14 | # freeze detector 15 | if not requires_grad: 16 | for param in self.detector.parameters(): 17 | param.requires_grad = False 18 | # deal with detector with dropout and bn 19 | if eval: 20 | self.detector.eval() 21 | 22 | @torch.no_grad() 23 | def read(self, img, **kwargs): 24 | self.img_metas = kwargs.get("img_metas") 25 | feat = self.detector.extract_feats(img) 26 | # store data in own 27 | self.backbone_feat = feat 28 | 29 | @torch.no_grad() 30 | def learn(self, student, momentum=None): 31 | if momentum is not None: 32 | for src_parm, tgt_parm in zip( 33 | student.detector.parameters(), self.detector.parameters() 34 | ): 35 | ori_type = tgt_parm.data.dtype 36 | tgt_parm.data = ( 37 | tgt_parm.data.float() * momentum 38 | + src_parm.data.float() * (1.0 - momentum) 39 | ).to(ori_type) 40 | 41 | @torch.no_grad() 42 | def deliver(self, student_input, student_infos): 43 | # connect student to each teacher 44 | pseudo_gt = {k: [] for k in self._fields} 45 | for sinfo in student_infos: 46 | student_id 
= sinfo["ori_filename"] 47 | teacher_idx = self.query_teacher_by_id(student_id) 48 | trans_matrix = [ 49 | torch.matmul( 50 | torch.from_numpy(sinfo["trans_matrix"]).to( 51 | self.backbone_feat[0].device 52 | ), 53 | torch.from_numpy(self.img_metas[idx]["trans_matrix"]) 54 | .to(self.backbone_feat[0].device) 55 | .inverse(), 56 | ) 57 | for idx in teacher_idx 58 | ] 59 | output_shape = [sinfo["img_shape"][:2] for _ in teacher_idx] 60 | for name in self._fields: 61 | res = [getattr(self, name)[idx] for idx in teacher_idx] 62 | if any([r is None for r in res]): 63 | continue 64 | if name == "gt_bboxes": 65 | res = Transform2D.transform_bboxes( 66 | [r.detach() for r in res], trans_matrix, output_shape 67 | ) 68 | elif name == "gt_masks": 69 | res = Transform2D.transform_masks(res, trans_matrix, output_shape) 70 | elif name == "gt_semantic_seg": 71 | res = Transform2D.transform_image( 72 | [r.detach() + 1 for r in res], trans_matrix, output_shape 73 | ) 74 | res = [r - 1 for r in res] 75 | 76 | else: 77 | res = [r.detach() for r in res] 78 | 79 | if len(res) > 1: 80 | raise NotImplementedError() 81 | # res = fuse(res) 82 | else: 83 | res = res[0] 84 | if name == "gt_semantic_seg": 85 | res[res == -1] = 255 86 | pseudo_gt[name].append(res) 87 | 88 | ( 89 | pseudo_gt["gt_bboxes"], 90 | pseudo_gt["gt_labels"], 91 | pseudo_gt["gt_masks"], 92 | ) = multi_apply( 93 | filter_invalid, 94 | pseudo_gt["gt_bboxes"], 95 | pseudo_gt["gt_labels"], 96 | [None for _ in student_infos], 97 | pseudo_gt.get("gt_masks", [None for _ in student_infos]), 98 | ) 99 | for key in list(pseudo_gt.keys()): 100 | if key not in self._fields: 101 | pseudo_gt.pop(key) 102 | if "gt_semantic_seg" in pseudo_gt: 103 | pseudo_gt["gt_semantic_seg"] = pad_stack( 104 | pseudo_gt["gt_semantic_seg"], student_input.shape[-2:] 105 | ) 106 | pseudo_gt["gt_semantic_seg"] = F.interpolate( 107 | pseudo_gt["gt_semantic_seg"].unsqueeze(1).float(), scale_factor=0.125 108 | ).long() 109 | if "gt_masks" in pseudo_gt: 110 | pseudo_gt["gt_masks"] = [ 111 | mask.pad(student_info["pad_shape"][:2]) 112 | for mask, student_info in zip(pseudo_gt["gt_masks"], student_infos) 113 | ] 114 | return pseudo_gt 115 | 116 | def rate(self): 117 | pass 118 | 119 | def query_teacher_by_id(self, file_id): 120 | teacher_ids = [meta["ori_filename"] for meta in self.img_metas] 121 | return [teacher_ids.index(file_id)] 122 | 123 | 124 | class Student(nn.Module): 125 | def __init__(self, detector, train_cfg=None): 126 | super().__init__() 127 | self.detector = detector 128 | self.train_cfg = train_cfg 129 | 130 | def learn(self, problem_data, teacher): 131 | pass 132 | 133 | def self_learn(self, reference_data): 134 | pass 135 | 136 | def parallel_learn(self, labeled, unlabeled, teacher): 137 | pass 138 | -------------------------------------------------------------------------------- /src/models/detectors/semi_two_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmdet.models import DETECTORS, build_detector 3 | from mmdet.models.detectors import BaseDetector 4 | from src.utils import GlobalWandbLoggerHook 5 | from src.utils.debug_utils import Timer 6 | from src.utils.structure_utils import check_equal, dict_concat, dict_split, zero_like 7 | 8 | from .student_wrapper import TwoStageStudent 9 | from .teacher_wrapper import TwoStageTeacher 10 | 11 | 12 | @DETECTORS.register_module() 13 | class SemiTwoStageDetector(BaseDetector): 14 | def __init__( 15 | self, 16 | student_cfg, 17 | teacher_cfg=None, 18 | 
train_cfg=None, 19 | test_cfg=None, 20 | base_momentum=0.999, 21 | ): 22 | super().__init__() 23 | if teacher_cfg is None: 24 | teacher_cfg = student_cfg 25 | teacher_detector = build_detector(teacher_cfg) 26 | self.teacher = TwoStageTeacher(teacher_detector, train_cfg=train_cfg) 27 | student_detector = build_detector(student_cfg) 28 | self.student = TwoStageStudent(student_detector, train_cfg=train_cfg) 29 | # self.student.register_teacher_supervision(self.teacher) 30 | 31 | self.train_cfg = train_cfg 32 | self.test_cfg = test_cfg 33 | if self.train_cfg is None: 34 | self.train_cfg = {} 35 | if self.test_cfg is None: 36 | self.test_cfg = {} 37 | 38 | self.base_momentum = base_momentum 39 | self.momentum = self.base_momentum 40 | if self.base_momentum < 1: 41 | check_equal(self.teacher.detector, self.student.detector) 42 | self._momentum_update(0.0) 43 | 44 | @torch.no_grad() 45 | def _momentum_update(self, momentum): 46 | """Momentum update of the target network.""" 47 | self.teacher.learn(self.student, momentum) 48 | 49 | @torch.no_grad() 50 | def momentum_update(self): 51 | self._momentum_update(self.momentum) 52 | 53 | def forward_dummy(self, img): 54 | """Used for computing network flops. 55 | 56 | See `mmdetection/tools/get_flops.py` 57 | """ 58 | return self.student.detector.forward_dummy(img) 59 | 60 | def forward_train(self, img, img_metas, **kwargs): 61 | 62 | with Timer("split data"): 63 | if not hasattr(self.teacher, "CLASSES"): 64 | self.teacher.CLASSES = self.CLASSES 65 | if not hasattr(self.student, "CLASSES"): 66 | self.student.CLASSES = self.CLASSES 67 | unsup_tag = self.train_cfg.get( 68 | "unsup_tag", ["unsup_teacher", "unsup_student"] 69 | ) 70 | sup_tag = self.train_cfg.get("sup_tag", ["sup"]) 71 | kwargs.update({"img": img}) 72 | kwargs.update({"img_metas": img_metas}) 73 | kwargs.update({"tag": [meta["tag"] for meta in img_metas]}) 74 | data_groups = dict_split(kwargs, "tag") 75 | sample_num = {} 76 | for tag in unsup_tag: 77 | sample_num[tag] = 0 78 | for tag in sup_tag: 79 | sample_num[tag] = 0 80 | 81 | for k, v in data_groups.items(): 82 | sample_num[k] = len(v["img"]) 83 | GlobalWandbLoggerHook.add_scalars(sample_num) 84 | 85 | if any([s in data_groups for s in sup_tag]): 86 | # compute supervised loss 87 | labeled_data_group = dict_concat( 88 | [data_groups[s] for s in sup_tag if s in data_groups] 89 | ) 90 | labeled_data_group.pop("tag") 91 | 92 | else: 93 | labeled_data_group = None 94 | 95 | if unsup_tag[0] in data_groups: 96 | teacher_tag = unsup_tag[0] 97 | student_tag = unsup_tag[1] 98 | data_groups[teacher_tag].pop("tag") 99 | data_groups[student_tag].pop("tag") 100 | 101 | unlabeled_data = data_groups[student_tag] 102 | else: 103 | unlabeled_data = None 104 | with Timer("techer prediction"): 105 | if unlabeled_data is not None: 106 | self.teacher.read(data_groups[teacher_tag]) 107 | loss = self.student.parallel_learn( 108 | labeled_data_group, unlabeled_data, self.teacher 109 | ) 110 | return loss 111 | 112 | async def async_simple_test(self, img, img_metas, **kwargs): 113 | return self.inference_detector.async_simple_test(img, img_metas, **kwargs) 114 | 115 | def simple_test(self, img, img_metas, **kwargs): 116 | return self.inference_detector.simple_test(img, img_metas, **kwargs) 117 | 118 | def aug_test(self, imgs, img_metas, **kwargs): 119 | """Test function with test time augmentation.""" 120 | return self.inference_detector.aug_test(imgs, img_metas, **kwargs) 121 | 122 | @property 123 | def inference_detector(self): 124 | if 
self.test_cfg.get("inference_on", "student") == "teacher": 125 | detector = self.teacher.detector 126 | else: 127 | detector = self.student.detector 128 | return detector 129 | 130 | def extract_feat(self, x): 131 | return self.student.detector.extract_feat(x) 132 | 133 | def extract_feats(self, imgs): 134 | return self.student.detector.extract_feats(imgs) 135 | 136 | def _load_from_state_dict( 137 | self, 138 | state_dict, 139 | prefix, 140 | local_metadata, 141 | strict, 142 | missing_keys, 143 | unexpected_keys, 144 | error_msgs, 145 | ): 146 | if not any(["student" in key or "teacher" in key for key in state_dict.keys()]): 147 | keys = list(state_dict.keys()) 148 | state_dict.update({"teacher.detector." + k: state_dict[k] for k in keys}) 149 | state_dict.update({"student.detector." + k: state_dict[k] for k in keys}) 150 | for k in keys: 151 | state_dict.pop(k) 152 | 153 | return super()._load_from_state_dict( 154 | state_dict, 155 | prefix, 156 | local_metadata, 157 | strict, 158 | missing_keys, 159 | unexpected_keys, 160 | error_msgs, 161 | ) 162 | -------------------------------------------------------------------------------- /src/models/detectors/student_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .two_stage_student import TwoStageStudent 2 | 3 | __all__ = ["TwoStageStudent"] 4 | -------------------------------------------------------------------------------- /src/models/detectors/student_wrapper/two_stage_student.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmdet.core import multi_apply 3 | from mmdet.core import bbox_overlaps 4 | from src.utils.debug_utils import Timer 5 | from src.utils.structure_utils import dict_concat, dict_sum, weighted_loss 6 | 7 | 8 | from src.models.utils.semi_wrapper import ( 9 | get_roi_loss, 10 | get_roi_prediction, 11 | split_roi_prediction, 12 | split_rpn_output, 13 | ) 14 | from ..semi_base import Student 15 | 16 | 17 | class TwoStageStudent(Student): 18 | def __init__(self, detector, train_cfg=None): 19 | super().__init__(detector, train_cfg) 20 | 21 | def get_prediction( 22 | self, 23 | img, 24 | img_metas, 25 | gt_bboxes, 26 | gt_labels, 27 | proposals=None, 28 | gt_masks=None, 29 | gt_bboxes_ignore=None, 30 | **kwargs 31 | ): 32 | x = self.detector.extract_feat(img) 33 | if self.detector.with_rpn: 34 | proposal_cfg = self.detector.train_cfg.get( 35 | "rpn_proposal", self.detector.test_cfg.rpn 36 | ) 37 | rpn_output = self.detector.rpn_head(x) 38 | proposals = self.detector.rpn_head.get_bboxes( 39 | *rpn_output, img_metas, proposal_cfg 40 | ) 41 | else: 42 | rpn_output = None 43 | 44 | roi_output = get_roi_prediction( 45 | self.detector.roi_head, 46 | x, 47 | proposals, 48 | img_metas, 49 | gt_bboxes, 50 | gt_labels, 51 | gt_masks, 52 | gt_bboxes_ignore, 53 | **kwargs 54 | ) 55 | 56 | return rpn_output, roi_output 57 | 58 | def get_rpn_loss(self, rpn_output, gt_bboxes, img_metas): 59 | if self.detector.with_rpn: 60 | loss_inputs = tuple(rpn_output) + (gt_bboxes, img_metas) 61 | return self.detector.rpn_head.loss(*loss_inputs) 62 | return {} 63 | 64 | def parallel_learn(self, labeled, unlabeled, teacher): 65 | if unlabeled is not None: 66 | with Timer("techer prediction transform"): 67 | pseudo_gt = teacher.deliver(unlabeled["img"], unlabeled["img_metas"]) 68 | gt_proposals = { 69 | "bboxes": unlabeled["gt_bboxes"], 70 | "img_metas": unlabeled["img_metas"], 71 | } 72 | gt_scores = teacher.rate(**gt_proposals) 73 | 
bboxes, labels = multi_apply( 74 | self.combine, 75 | unlabeled["gt_bboxes"], 76 | unlabeled["gt_labels"], 77 | gt_scores, 78 | [bbox[:, :4] for bbox in pseudo_gt["gt_bboxes"]], 79 | pseudo_gt["gt_labels"], 80 | ) 81 | unlabeled.update(gt_bboxes=bboxes, gt_labels=labels) 82 | loss = {} 83 | if labeled is None: 84 | with Timer("Get student prediction"): 85 | rpn_pred, roi_pred = self.get_prediction(**unlabeled) 86 | unlabeled_rpn_loss = self.get_rpn_loss( 87 | rpn_pred, unlabeled["gt_bboxes"], unlabeled["img_metas"] 88 | ) 89 | unlabeled_roi_loss = get_roi_loss( 90 | self.detector.roi_head, 91 | roi_pred, 92 | unlabeled["img_metas"], 93 | teacher=teacher, 94 | ) 95 | loss.update( 96 | weighted_loss( 97 | unlabeled_rpn_loss, 98 | self.train_cfg.get("unsup_weight", 2.0), 99 | teacher.ignore_branch, 100 | ) 101 | ) 102 | loss.update( 103 | weighted_loss( 104 | unlabeled_roi_loss, 105 | self.train_cfg.get("unsup_weight", 2.0), 106 | teacher.ignore_branch, 107 | ) 108 | ) 109 | elif unlabeled is None: 110 | with Timer("Get student prediction"): 111 | rpn_pred, roi_pred = self.get_prediction(**labeled) 112 | labeled_rpn_loss = self.get_rpn_loss( 113 | rpn_pred, labeled["gt_bboxes"], labeled["img_metas"] 114 | ) 115 | labeled_roi_loss = get_roi_loss( 116 | self.detector.roi_head, roi_pred, labeled["img_metas"], 117 | ) 118 | loss.update(labeled_rpn_loss) 119 | loss.update(labeled_roi_loss) 120 | else: 121 | labeled_sample_num, unlabeled_sample_num = ( 122 | len(labeled["img"]), 123 | len(unlabeled["img"]), 124 | ) 125 | with Timer("Get student prediction"): 126 | rpn_pred, roi_pred = self.get_prediction( 127 | **dict_concat([labeled, unlabeled]) 128 | ) 129 | 130 | with Timer("Get student rpn loss"): 131 | labeled_rpn_pred, unlabeled_rpn_pred = split_rpn_output( 132 | rpn_pred, [labeled_sample_num, unlabeled_sample_num] 133 | ) 134 | 135 | labeled_rpn_loss = self.get_rpn_loss( 136 | labeled_rpn_pred, labeled["gt_bboxes"], labeled["img_metas"] 137 | ) 138 | unlabeled_rpn_loss = self.get_rpn_loss( 139 | unlabeled_rpn_pred, unlabeled["gt_bboxes"], unlabeled["img_metas"] 140 | ) 141 | with Timer("Get student rcnn loss"): 142 | labeled_roi_pred, unlabeled_roi_pred = split_roi_prediction( 143 | self.detector.roi_head, 144 | roi_pred, 145 | [labeled_sample_num, unlabeled_sample_num], 146 | ) 147 | 148 | labeled_roi_loss = get_roi_loss( 149 | self.detector.roi_head, labeled_roi_pred, labeled["img_metas"] 150 | ) 151 | unlabeled_roi_loss = get_roi_loss( 152 | self.detector.roi_head, 153 | unlabeled_roi_pred, 154 | unlabeled["img_metas"], 155 | teacher=teacher, 156 | ) 157 | 158 | loss.update(labeled_rpn_loss) 159 | loss.update(labeled_roi_loss) 160 | 161 | unlabeled_loss = {} 162 | unlabeled_loss.update( 163 | weighted_loss( 164 | unlabeled_rpn_loss, 165 | self.train_cfg.get("unsup_weight", 2.0), 166 | teacher.ignore_branch, 167 | ) 168 | ) 169 | unlabeled_loss.update( 170 | weighted_loss( 171 | unlabeled_roi_loss, 172 | self.train_cfg.get("unsup_weight", 2.0), 173 | teacher.ignore_branch, 174 | ) 175 | ) 176 | loss = dict_sum(loss, unlabeled_loss) 177 | return loss 178 | 179 | def combine(self, bbox_gt, label_gt, score_gt, bbox_noise, label_noise): 180 | score_thr = self.train_cfg.get("gt_score_thr", 0.9) 181 | flags = score_gt[torch.arange(len(score_gt)), label_gt] > score_thr 182 | bbox_gt = bbox_gt[flags] 183 | label_gt = label_gt[flags] 184 | 185 | iou = bbox_overlaps(bbox_gt, bbox_noise) 186 | if iou.numel() > 0: 187 | matched = iou.float().max(dim=1)[0] > self.train_cfg.get("iou_thr", 0.5) 188 | 
else:
189 |             matched = torch.zeros_like(label_gt, dtype=torch.bool)
190 | 
191 |         bbox_select = torch.cat([bbox_noise, bbox_gt[~matched]])
192 |         label_select = torch.cat([label_noise, label_gt[~matched]])
193 | 
194 |         return bbox_select, label_select
195 | 
-------------------------------------------------------------------------------- /src/models/detectors/teacher_wrapper/__init__.py: --------------------------------------------------------------------------------
1 | from .two_stage_teacher import TwoStageTeacher
2 | 
3 | __all__ = ["TwoStageTeacher"]
4 | 
-------------------------------------------------------------------------------- /src/models/detectors/teacher_wrapper/two_stage_teacher.py: --------------------------------------------------------------------------------
1 | import torch
2 | from mmdet.core import multi_apply
3 | from mmdet.models.roi_heads import HybridTaskCascadeRoIHead
4 | from src.core import Transform2D, filter_invalid, recover_mask
5 | from src.models.utils.semi_wrapper import simple_test_bboxes
6 | from src.utils.debug_utils import Timer
7 | from src.utils.structure_utils import result2bbox, result2mask
8 | 
9 | from ..semi_base import Teacher
10 | 
11 | 
12 | class TwoStageTeacher(Teacher):
13 |     def __init__(self, detector, train_cfg):
14 |         super().__init__(detector)
15 |         self.train_cfg = train_cfg
16 |         if self.train_cfg is not None:
17 |             self._fields = self.train_cfg.get(
18 |                 "supervised_fields",
19 |                 ["gt_bboxes", "gt_labels", "gt_masks", "gt_semantic_seg"],
20 |             )
21 |             self.ignore_branch = []
22 |             if "gt_bboxes" not in self._fields:
23 |                 self.ignore_branch.append("bbox")
24 |             if "gt_labels" not in self._fields:
25 |                 self.ignore_branch.append("cls")
26 |                 self.ignore_branch.append("mask")
27 |                 self.ignore_branch.append("seg")
28 |         else:
29 |             self._fields = None
30 | 
31 |     @torch.no_grad()
32 |     def read_feat(self, data):
33 |         feat = self.detector.extract_feat(data["img"])
34 |         self.backbone_feat = feat
35 |         self.img_metas = data["img_metas"]
36 | 
37 |     @torch.no_grad()
38 |     def read(self, data):
39 |         with Timer("teacher rpn"):
40 |             feat = self.detector.extract_feat(data["img"])
41 |             # store the extracted features on the teacher itself
42 |             self.backbone_feat = feat
43 |             if self.detector.with_rpn:
44 |                 proposal_list = self.detector.rpn_head.simple_test_rpn(
45 |                     feat, data["img_metas"]
46 |                 )
47 |             else:
48 |                 proposal_list = data["proposals"]
49 | 
50 |         self.img_metas = data["img_metas"]
51 |         self.proposal_list = proposal_list
52 |         # roi prediction
53 |         # Note: there is a bug in the original cascade mask rcnn simple-test function:
54 |         # the mask should not be flipped inside the head, so we have to create fake meta information here.
55 | # https://github.com/open-mmlab/mmdetection/issues/1466 56 | with Timer("teacher rcnn"): 57 | fake_meta = [ 58 | { 59 | "img_shape": meta["img_shape"], 60 | "ori_shape": meta["ori_shape"], 61 | "scale_factor": meta["scale_factor"], 62 | "flip": False, 63 | "flip_direction": "Bug", 64 | } 65 | for meta in data["img_metas"] 66 | ] 67 | results = self.detector.roi_head.simple_test( 68 | feat, proposal_list, fake_meta, rescale=False 69 | ) 70 | with Timer("teacher filter bbox"): 71 | self._prepare_instance_label(results, data["img"].device) 72 | with Timer("teacher semantic"): 73 | if ( 74 | "gt_semantic_seg" in self._fields 75 | and hasattr(self.detector.roi_head, "with_semantic") 76 | and self.detector.roi_head.with_semantic 77 | ): 78 | semantic_pred, semantic_feat = self.detector.roi_head.semantic_head( 79 | self.backbone_feat 80 | ) 81 | self._prepare_semantic_label( 82 | semantic_pred, 83 | data["img_metas"], 84 | scale=data["img"].shape[-1] / semantic_pred.shape[-1], 85 | ) 86 | self.semantic_feat = semantic_feat 87 | 88 | def _prepare_semantic_label(self, semantic_pred, img_metas, scale=1.0): 89 | score, label = semantic_pred.softmax(dim=1).max(dim=1) 90 | label[score < self.train_cfg.get("semantic_seg_thr", 0.9)] = 255 91 | self.gt_semantic_seg, _ = multi_apply( 92 | recover_mask, label, img_metas, scale=scale 93 | ) 94 | 95 | def _prepare_instance_label(self, instance_pred, device): 96 | with Timer("result2instance"): 97 | if len(instance_pred[0]) == 2: 98 | bbox_result = [instance_pred[i][0] for i in range(len(instance_pred))] 99 | mask_result = [instance_pred[i][1] for i in range(len(instance_pred))] 100 | else: 101 | bbox_result = instance_pred 102 | mask_result = None 103 | bboxes, labels = multi_apply(result2bbox, bbox_result) 104 | if "gt_masks" in self._fields and (mask_result is not None): 105 | masks, _ = multi_apply(result2mask, mask_result) 106 | else: 107 | masks = [None for _ in range(len(bboxes))] 108 | with Timer("filter"): 109 | bboxes = [torch.from_numpy(bbox).to(device) for bbox in bboxes] 110 | labels = [torch.from_numpy(label).to(device) for label in labels] 111 | bboxes, labels, masks = multi_apply( 112 | filter_invalid, 113 | bboxes, 114 | labels, 115 | [bbox[:, 4] for bbox in bboxes], 116 | masks, 117 | thr=self.train_cfg.get("score_thr", 0.5), 118 | ) 119 | 120 | self.gt_bboxes = bboxes 121 | self.gt_labels = labels 122 | self.gt_masks = masks 123 | 124 | @torch.no_grad() 125 | def rate(self, bboxes, img_metas=None, **kwargs): 126 | # mapping bboxes to teacher image space 127 | if img_metas is not None: 128 | results = [] 129 | for i, sinfo in enumerate(img_metas): 130 | student_id = sinfo["ori_filename"] 131 | teacher_idx = self.query_teacher_by_id(student_id)[0] 132 | source2teacher = torch.from_numpy( 133 | self.img_metas[teacher_idx]["trans_matrix"] 134 | ).to(self.backbone_feat[0].device) 135 | student2source = ( 136 | torch.from_numpy(sinfo["trans_matrix"]) 137 | .to(self.backbone_feat[0].device) 138 | .inverse() 139 | ) 140 | trans_matrix = torch.matmul(source2teacher, student2source) 141 | output_shape = sinfo["img_shape"][:2] 142 | res = Transform2D.transform_bboxes( 143 | bboxes[i], trans_matrix, output_shape 144 | ).float() 145 | results.append(torch.cat([res, res.new_ones(res.shape[0], 1)], dim=1)) 146 | bboxes = results 147 | if isinstance(self.detector.roi_head, HybridTaskCascadeRoIHead): 148 | kwargs.update({"semantic_feat": getattr(self, "semantic_feat", None)}) 149 | score_pred = simple_test_bboxes( 150 | self.detector.roi_head, 
self.backbone_feat, self.img_metas, bboxes, **kwargs 151 | ) 152 | 153 | return score_pred 154 | -------------------------------------------------------------------------------- /src/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/MixTraining/ca97b3888a660265508b4aa8ba2b84c19f298f44/src/models/utils/__init__.py -------------------------------------------------------------------------------- /src/models/utils/semi_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmdet.core import multi_apply 3 | from mmdet.models.roi_heads import CascadeRoIHead, StandardRoIHead 4 | from . import standard_wrappers, cascade_wrappers 5 | 6 | 7 | def simple_test_bboxes(self, x, img_metas, proposal_list, **kwargs): 8 | if isinstance(self, CascadeRoIHead): 9 | return cascade_wrappers.get_bbox_confidence( 10 | self, x, img_metas, proposal_list, **kwargs 11 | ) 12 | elif isinstance(self, StandardRoIHead): 13 | return standard_wrappers.get_bbox_confidence( 14 | self, x, img_metas, proposal_list, **kwargs 15 | ) 16 | else: 17 | raise NotImplementedError( 18 | f"confidence estimation method for {type(self)} is not implemented yet." 19 | ) 20 | 21 | 22 | def get_roi_prediction( 23 | self, 24 | x, 25 | proposal_list, 26 | img_metas, 27 | gt_bboxes, 28 | gt_labels, 29 | gt_masks, 30 | gt_bboxes_ignore=None, 31 | **kwargs, 32 | ): 33 | if isinstance(self, CascadeRoIHead): 34 | return cascade_wrappers.get_roi_prediction( 35 | self, 36 | x, 37 | proposal_list, 38 | img_metas, 39 | gt_bboxes, 40 | gt_labels, 41 | gt_masks, 42 | gt_bboxes_ignore, 43 | **kwargs, 44 | ) 45 | elif isinstance(self, StandardRoIHead): 46 | return standard_wrappers.get_roi_prediction( 47 | self, 48 | x, 49 | proposal_list, 50 | img_metas, 51 | gt_bboxes, 52 | gt_labels, 53 | gt_masks, 54 | gt_bboxes_ignore, 55 | **kwargs, 56 | ) 57 | else: 58 | raise NotImplementedError( 59 | f"get_prediction method for {type(self)} is not implemented yet." 60 | ) 61 | 62 | 63 | def get_roi_loss(self, pred, img_metas, **kwargs): 64 | if isinstance(self, CascadeRoIHead): 65 | return cascade_wrappers.get_roi_loss(self, pred, img_metas, **kwargs) 66 | elif isinstance(self, StandardRoIHead): 67 | return standard_wrappers.get_roi_loss(self, pred, img_metas, **kwargs) 68 | else: 69 | raise NotImplementedError() 70 | 71 | 72 | def _split(*input, splits=None): 73 | if splits is None: 74 | return input 75 | else: 76 | return tuple([torch.split(i, splits) for i in input]) 77 | 78 | 79 | def split_rpn_output(rpn_output, splits): 80 | if rpn_output is None: 81 | return [None for _ in splits] 82 | else: 83 | tmp = multi_apply(_split, *rpn_output, splits=splits) 84 | return [[[tt[j] for tt in t] for t in tmp] for j in range(len(splits))] 85 | 86 | 87 | def split_roi_prediction(self, roi_output, splits): 88 | if isinstance(self, CascadeRoIHead): 89 | return cascade_wrappers.split_roi_prediction(self, roi_output, splits) 90 | elif isinstance(self, StandardRoIHead): 91 | return standard_wrappers.split_roi_prediction(self, roi_output, splits) 92 | else: 93 | raise NotImplementedError( 94 | f"confidence estimation method for {type(self)} is not implemented yet." 
95 | ) 96 | -------------------------------------------------------------------------------- /src/models/utils/standard_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | from mmdet.core import bbox2roi, bbox_overlaps, roi2bbox 5 | from src.utils.structure_utils import dict_concat 6 | 7 | 8 | def get_bbox_confidence(self, x, img_metas, proposal_list, **kwargs): 9 | assert self.with_bbox, "Bbox head must be implemented." 10 | num_proposals_per_img = tuple(len(proposals) for proposals in proposal_list) 11 | rois = bbox2roi(proposal_list) 12 | bbox_results = self._bbox_forward(x, rois) 13 | cls_score = bbox_results["cls_score"].softmax(dim=-1) 14 | cls_score = cls_score.split(num_proposals_per_img, 0) 15 | return cls_score 16 | 17 | 18 | def get_roi_prediction( 19 | self, 20 | x, 21 | proposal_list, 22 | img_metas, 23 | gt_bboxes, 24 | gt_labels, 25 | gt_masks, 26 | gt_bboxes_ignore=None, 27 | **kwargs, 28 | ): 29 | if self.with_bbox or self.with_mask: 30 | num_imgs = len(img_metas) 31 | if gt_bboxes_ignore is None: 32 | gt_bboxes_ignore = [None for _ in range(num_imgs)] 33 | sampling_results = [] 34 | for i in range(num_imgs): 35 | assign_result = self.bbox_assigner.assign( 36 | proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], gt_labels[i] 37 | ) 38 | sampling_result = self.bbox_sampler.sample( 39 | assign_result, 40 | proposal_list[i], 41 | gt_bboxes[i], 42 | gt_labels[i], 43 | feats=[lvl_feat[i][None] for lvl_feat in x], 44 | ) 45 | sampling_results.append(sampling_result) 46 | 47 | prediction = {} 48 | 49 | if self.with_bbox: 50 | rois = bbox2roi([res.bboxes for res in sampling_results]) 51 | if kwargs.get("save_assign_gt_inds", False): 52 | pos_inds = [res.pos_inds for res in sampling_results] 53 | prediction["pos_inds"] = pos_inds 54 | pos_assigned_gt_inds = [ 55 | res.pos_assigned_gt_inds for res in sampling_results 56 | ] 57 | prediction["pos_assigned_gt_inds"] = pos_assigned_gt_inds 58 | bbox_results = self._bbox_forward(x, rois) 59 | bbox_targets = self.bbox_head.get_targets( 60 | sampling_results, gt_bboxes, gt_labels, self.train_cfg 61 | ) 62 | 63 | prediction["rois"] = rois 64 | prediction["bbox_results"] = bbox_results 65 | prediction["bbox_targets"] = bbox_targets 66 | if self.with_mask: 67 | if not self.share_roi_extractor: 68 | pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) 69 | mask_results = self._mask_forward(x, pos_rois) 70 | else: 71 | pos_inds = [] 72 | device = bbox_results["bbox_feats"].device 73 | for res in sampling_results: 74 | pos_inds.append( 75 | torch.ones( 76 | res.pos_bboxes.shape[0], device=device, dtype=torch.uint8, 77 | ) 78 | ) 79 | pos_inds.append( 80 | torch.zeros( 81 | res.neg_bboxes.shape[0], device=device, dtype=torch.uint8, 82 | ) 83 | ) 84 | pos_inds = torch.cat(pos_inds) 85 | 86 | mask_results = self._mask_forward( 87 | x, pos_inds=pos_inds, bbox_feats=bbox_results["bbox_feats"] 88 | ) 89 | mask_targets = self.mask_head.get_targets( 90 | sampling_results, gt_masks, self.train_cfg 91 | ) 92 | pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) 93 | prediction["mask_num_per_img"] = [ 94 | len(res.pos_bboxes) for res in sampling_results 95 | ] 96 | prediction["mask_targets"] = mask_targets 97 | prediction["pos_labels"] = pos_labels 98 | prediction["mask_results"] = mask_results 99 | return prediction 100 | else: 101 | return None 102 | 103 | 104 | def split_roi_prediction(self, 
prediction, splits): 105 | 106 | interval = [0] + np.cumsum(splits).tolist() 107 | if self.with_bbox: 108 | proposal_list = roi2bbox(prediction["rois"]) 109 | # split rois 110 | proposal_lists = [ 111 | proposal_list[interval[i] : interval[i + 1]] for i in range(len(splits)) 112 | ] 113 | rois_list = [bbox2roi(p) for p in proposal_lists] 114 | chunk_sizes = [len(rois) for rois in rois_list] 115 | 116 | bbox_results = { 117 | k: torch.split(v, chunk_sizes) 118 | for k, v in prediction["bbox_results"].items() 119 | } 120 | bbox_result_list = [ 121 | {k: v[i] for k, v in bbox_results.items()} for i in range(len(splits)) 122 | ] 123 | 124 | bbox_targets = [torch.split(v, chunk_sizes) for v in prediction["bbox_targets"]] 125 | bbox_target_list = [[bt[i] for bt in bbox_targets] for i in range(len(splits))] 126 | prediction_chunks = [ 127 | dict(rois=r, bbox_results=br, bbox_targets=bt) 128 | for r, br, bt in zip(rois_list, bbox_result_list, bbox_target_list) 129 | ] 130 | if self.with_mask: 131 | mask_num_per_img = prediction["mask_num_per_img"] 132 | mask_nums = [ 133 | mask_num_per_img[interval[i] : interval[i + 1]] for i in range(len(splits)) 134 | ] 135 | chunk_sizes = [sum(m) for m in mask_nums] 136 | mask_targets = torch.split(prediction["mask_targets"], chunk_sizes) 137 | pos_labels = torch.split(prediction["pos_labels"], chunk_sizes) 138 | mask_results = { 139 | k: torch.split(v, chunk_sizes) 140 | for k, v in prediction["mask_results"].items() 141 | } 142 | mask_results = [ 143 | {k: v[i] for k, v in mask_results.items()} for i in range(len(splits)) 144 | ] 145 | 146 | for i, p in enumerate(prediction_chunks): 147 | p.update( 148 | dict( 149 | mask_results=mask_results[i], 150 | mask_targets=mask_targets[i], 151 | pos_labels=pos_labels[i], 152 | ) 153 | ) 154 | 155 | return prediction_chunks 156 | 157 | 158 | def get_roi_loss(self, pred, img_metas, teacher=None, student=None, **kwargs): 159 | loss = dict() 160 | if self.with_bbox: 161 | bbox_results = pred["bbox_results"] 162 | bbox_targets = list(pred["bbox_targets"]) 163 | rois = pred["rois"] 164 | with torch.no_grad(): 165 | if (teacher is not None) and teacher.train_cfg.get( 166 | "with_soft_teacher", True 167 | ): 168 | unsup_flag = ["unsup" in meta["tag"] for meta in img_metas] 169 | if sum(unsup_flag) > 0: 170 | label_weights = bbox_targets[1].detach().clone() 171 | proposal_list = roi2bbox(rois) 172 | label_weights = list( 173 | torch.split( 174 | label_weights, [len(p) for p in proposal_list], dim=0 175 | ) 176 | ) 177 | label_list = torch.split( 178 | bbox_targets[0], [len(p) for p in proposal_list], dim=0 179 | ) 180 | unsup_proposals = [ 181 | {"bboxes": [proposal], "img_metas": [meta]} 182 | for proposal, flag, meta in zip( 183 | proposal_list, unsup_flag, img_metas, 184 | ) 185 | if flag 186 | ] 187 | unsup_proposals = dict_concat(unsup_proposals) 188 | rated_weights = teacher.rate(**unsup_proposals) 189 | unsup_inds = 0 190 | for i, flag in enumerate(unsup_flag): 191 | if flag: 192 | if ( 193 | teacher.train_cfg.get("rate_method", "background") 194 | == "background" 195 | ): 196 | neg_inds = label_list[i] == self.bbox_head.num_classes 197 | label_weights[i][neg_inds] = ( 198 | label_weights[i][neg_inds] 199 | * rated_weights[unsup_inds][:, -1][neg_inds] 200 | ) 201 | elif ( 202 | teacher.train_cfg.get("rate_method", "background") 203 | == "per_class" 204 | ): 205 | label_weights[i] = ( 206 | label_weights[i] 207 | * rated_weights[unsup_inds][ 208 | torch.arange(len(rated_weights[unsup_inds])), 209 | label_list[i], 210 | ] 
211 | ) 212 | else: 213 | raise NotImplementedError() 214 | unsup_inds += 1 215 | label_weights = torch.cat(label_weights) 216 | bbox_targets[1] = ( 217 | label_weights.shape[0] 218 | * label_weights 219 | / max(label_weights.sum(), 1) 220 | ) 221 | loss.update( 222 | self.bbox_head.loss( 223 | bbox_results["cls_score"], 224 | bbox_results["bbox_pred"], 225 | rois, 226 | *bbox_targets, 227 | ) 228 | ) 229 | return loss 230 | 231 | 232 | def compute_kl_loss(logits, target_prob, bg_weight=1.0, fg_weight=1.0, weight=1.0): 233 | C = logits.shape[-1] 234 | class_weight = logits.new_ones(1, C) 235 | class_weight[:, -1] = class_weight[:, -1] * bg_weight 236 | class_weight[:, :-1] = class_weight[:, :-1] * fg_weight 237 | 238 | loss = -1 * class_weight * target_prob * F.log_softmax(logits, dim=-1) 239 | loss = loss.sum(dim=-1).mean() 240 | return weight * loss 241 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | from .hooks import ( 3 | DistDaliSamplerSeedHook, 4 | ApexFP16OptimizerHook, 5 | MeanTeacherHook, 6 | GlobalWandbLoggerHook, 7 | ) 8 | from .log_utils import collect_model_info 9 | from .file_utils import load_checkpoint, find_latest_checkpoint 10 | 11 | __all__ = [ 12 | "Config", 13 | "DistDaliSamplerSeedHook", 14 | "ApexFP16OptimizerHook", 15 | "MeanTeacherHook", 16 | "GlobalWandbLoggerHook", 17 | "collect_model_info", 18 | "load_checkpoint", 19 | "find_latest_checkpoint", 20 | ] 21 | -------------------------------------------------------------------------------- /src/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | 5 | DEBUG = False 6 | if "DEBUG" in os.environ: 7 | if os.environ["DEBUG"] == "1": 8 | DEBUG = True 9 | 10 | 11 | class Timer: 12 | def __init__(self, name="script"): 13 | self.name = name 14 | 15 | def __enter__(self): 16 | if DEBUG: 17 | torch.cuda.synchronize() 18 | self.start = time.time() 19 | 20 | def __exit__(self, *args, **kwargs): 21 | if DEBUG: 22 | torch.cuda.synchronize() 23 | print(self.name, time.time() - self.start) 24 | -------------------------------------------------------------------------------- /src/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | 3 | 4 | def collect_model_info(model, rich_text=False): 5 | def bool2str(input): 6 | if input: 7 | return "Y" 8 | else: 9 | return "N" 10 | 11 | def shape_str(size): 12 | size = [str(s) for s in size] 13 | return "X".join(size) 14 | 15 | def min_max_str(input): 16 | return "Min:{:.3f} Max:{:.3f}".format(input.min(), input.max()) 17 | 18 | def param_size(size, dtype="float32"): 19 | if dtype == "float32": 20 | size = size * 4 21 | elif dtype == "float16": 22 | size = size * 2 23 | else: 24 | raise NotImplementedError() 25 | return size / 1024.0 / 1024.0 26 | 27 | if not rich_text: 28 | table = PrettyTable(["Parameter Name", "Requires Grad", "Shape", "Value Scale"]) 29 | total_size = 0 30 | for name, param in model.named_parameters(): 31 | total_size += param.numel() 32 | table.add_row( 33 | [ 34 | name, 35 | bool2str(param.requires_grad), 36 | shape_str(param.size()), 37 | min_max_str(param), 38 | ] 39 | ) 40 | table.add_row( 41 | ["Total Number Of Params", "fp32", "%.4f" % param_size(total_size), "MB"] 42 | ) 43 | return "\n" + 
table.get_string(title="Model Information")
44 |     else:
45 |         raise NotImplementedError("rich_text output is not implemented yet")
-------------------------------------------------------------------------------- /src/utils/structure_utils.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from collections.abc import Mapping, Sequence
3 | from numbers import Number
4 | from typing import Dict, List
5 | 
6 | import numpy as np
7 | import torch
8 | from mmdet.core.mask.structures import BitmapMasks
9 | from torch.nn import functional as F
10 | 
11 | 
12 | def list_concat(data_list: List[list]):
13 |     if isinstance(data_list[0], torch.Tensor):
14 |         return torch.cat(data_list)
15 |     else:
16 |         endpoint = [d for d in data_list[0]]
17 | 
18 |         for i in range(1, len(data_list)):
19 |             endpoint.extend(data_list[i])
20 |         return endpoint
21 | 
22 | 
23 | def sequence_concat(a, b):
24 |     if isinstance(a, Sequence) and isinstance(b, Sequence):
25 |         return a + b
26 |     else:
27 |         return None
28 | 
29 | 
30 | def dict_concat(dicts: List[Dict[str, list]]):
31 |     return {k: list_concat([d[k] for d in dicts]) for k in dicts[0].keys()}
32 | 
33 | 
34 | def dict_fuse(obj_list, reference_obj):
35 |     if isinstance(reference_obj, torch.Tensor):
36 |         return torch.stack(obj_list)
37 |     return obj_list
38 | 
39 | 
40 | def dict_select(dict1: Dict[str, list], key: str, value: str):
41 |     flag = [v == value for v in dict1[key]]
42 |     return {
43 |         k: dict_fuse([vv for vv, ff in zip(v, flag) if ff], v) for k, v in dict1.items()
44 |     }
45 | 
46 | 
47 | def dict_split(dict1, key):
48 |     group_names = list(set(dict1[key]))
49 |     dict_groups = {k: dict_select(dict1, key, k) for k in group_names}
50 | 
51 |     return dict_groups
52 | 
53 | 
54 | def dict_sum(a, b):
55 |     if isinstance(a, dict):
56 |         assert isinstance(b, dict)
57 |         return {k: dict_sum(v, b[k]) for k, v in a.items()}
58 |     elif isinstance(a, list):
59 |         assert len(a) == len(b)
60 |         return [dict_sum(aa, bb) for aa, bb in zip(a, b)]
61 |     else:
62 |         return a + b
63 | 
64 | 
65 | def zero_like(tensor_pack, prefix=""):
66 |     if isinstance(tensor_pack, Sequence):
67 |         return [zero_like(t) for t in tensor_pack]
68 |     elif isinstance(tensor_pack, Mapping):
69 |         return {prefix + k: zero_like(v) for k, v in tensor_pack.items()}
70 |     elif isinstance(tensor_pack, torch.Tensor):
71 |         return tensor_pack.new_zeros(tensor_pack.shape)
72 |     elif isinstance(tensor_pack, np.ndarray):
73 |         return np.zeros_like(tensor_pack)
74 |     else:
75 |         warnings.warn("Unexpected data type {}".format(type(tensor_pack)))
76 |         return 0
77 | 
78 | 
79 | def pad_stack(tensors, shape, pad_value=255):
80 |     tensors = torch.stack(
81 |         [
82 |             F.pad(
83 |                 tensor,
84 |                 pad=[0, shape[1] - tensor.shape[1], 0, shape[0] - tensor.shape[0]],
85 |                 value=pad_value,
86 |             )
87 |             for tensor in tensors
88 |         ]
89 |     )
90 |     return tensors
91 | 
92 | 
93 | def result2bbox(result):
94 |     num_class = len(result)
95 | 
96 |     bbox = np.concatenate(result)
97 |     if bbox.shape[0] == 0:
98 |         label = np.zeros(0, dtype=np.uint8)
99 |     else:
100 |         label = np.concatenate(
101 |             [[i] * len(result[i]) for i in range(num_class) if len(result[i]) > 0]
102 |         ).reshape((-1,))
103 |     return bbox, label
104 | 
105 | 
106 | def result2mask(result):
107 |     num_class = len(result)
108 |     mask = [np.stack(result[i]) for i in range(num_class) if len(result[i]) > 0]
109 |     if len(mask) > 0:
110 |         mask = np.concatenate(mask)
111 |     else:
112 |         mask = np.zeros((0, 1, 1))
113 |     return BitmapMasks(mask, mask.shape[1], mask.shape[2]), None
114 | 
115 | 
116 | def sequence_mul(obj, multiplier):
117 |     if isinstance(obj, Sequence):
isinstance(obj, Sequence): 118 | return [o * multiplier for o in obj] 119 | else: 120 | return obj * multiplier 121 | 122 | 123 | def is_match(word, word_list): 124 | for keyword in word_list: 125 | if keyword in word: 126 | return True 127 | return False 128 | 129 | 130 | def weighted_loss(loss: dict, weight, ignore_keys=[]): 131 | if isinstance(weight, Mapping): 132 | for k, v in weight.items(): 133 | for name, loss_item in loss.items(): 134 | if (k in name) and ("loss" in name): 135 | loss[name] = sequence_mul(loss[name], v) 136 | elif isinstance(weight, Number): 137 | for name, loss_item in loss.items(): 138 | if ("loss" in name) and (not is_match(name, ignore_keys)): 139 | loss[name] = sequence_mul(loss[name], weight) 140 | else: 141 | return loss 142 | return loss 143 | 144 | 145 | def check_equal(a, b): 146 | a_sdict = a.state_dict() 147 | b_sdict = b.state_dict() 148 | for k, v in a_sdict.items(): 149 | if k not in b_sdict: 150 | return False 151 | if v.numel() != b_sdict[k].numel(): 152 | return False 153 | return True 154 | 155 | 156 | def check_nan(obj): 157 | if isinstance(obj, Mapping): 158 | return {k: check_nan(v) for k, v in obj.items()} 159 | elif isinstance(obj, Sequence): 160 | return [check_nan(o) for o in obj] 161 | else: 162 | return torch.any(torch.isnan(obj)).item() 163 | -------------------------------------------------------------------------------- /src/utils/web_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | 4 | shortname = {"mmdet": "https://raw.githubusercontent.com/open-mmlab/mmdetection/master"} 5 | 6 | 7 | def load_text_from_web(url): 8 | header, path = url.split(":", 1) 9 | if header in shortname: 10 | header = shortname[header] 11 | url = os.path.join(header, path) 12 | r = requests.get(url, stream=True) 13 | return r.content.decode(r.encoding) 14 | 15 | 16 | def check_url_exist(url): 17 | header, path = url.split(":", 1) 18 | if header in shortname: 19 | header = shortname[header] 20 | url = os.path.join(header, path) 21 | 22 | r = requests.get(url, stream=True) 23 | if r.status_code != 200: 24 | raise FileNotFoundError(f"{url} does not exist") 25 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import DictAction 8 | from mmcv.cnn import fuse_conv_bn 9 | from mmcv.parallel import MMDataParallel, 
MMDistributedDataParallel 10 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model 11 | from mmdet.apis import multi_gpu_test, single_gpu_test 12 | from mmdet.datasets import replace_ImageToTensor 13 | from mmdet.models import TwoStageDetector, SingleStageDetector, build_detector 14 | from src.datasets import build_dataloader, build_dataset 15 | from src.models import AsymTwoStageDetector, AsymSingleStageDetector 16 | from src.utils import Config 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description="MMDet test (and eval) a model") 21 | parser.add_argument("config", help="test config file path") 22 | parser.add_argument("checkpoint", help="checkpoint file") 23 | parser.add_argument("--out", help="output result file in pickle format") 24 | parser.add_argument( 25 | "--fuse-conv-bn", 26 | action="store_true", 27 | help="Whether to fuse conv and bn, this will slightly increase " 28 | "the inference speed", 29 | ) 30 | parser.add_argument( 31 | "--format-only", 32 | action="store_true", 33 | help="Format the output results without performing evaluation. It is " 34 | "useful when you want to format the result to a specific format and " 35 | "submit it to the test server", 36 | ) 37 | parser.add_argument( 38 | "--eval", 39 | type=str, 40 | nargs="+", 41 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 42 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC', 43 | ) 44 | parser.add_argument("--show", action="store_true", help="show results") 45 | parser.add_argument( 46 | "--show-dir", help="directory where painted images will be saved" 47 | ) 48 | parser.add_argument( 49 | "--show-score-thr", 50 | type=float, 51 | default=0.3, 52 | help="score threshold (default: 0.3)", 53 | ) 54 | parser.add_argument( 55 | "--gpu-collect", 56 | action="store_true", 57 | help="whether to use gpu to collect results.", 58 | ) 59 | parser.add_argument( 60 | "--tmpdir", 61 | help="tmp directory used for collecting results from multiple " 62 | "workers, available when gpu-collect is not specified", 63 | ) 64 | parser.add_argument( 65 | "--cfg-options", 66 | nargs="+", 67 | action=DictAction, 68 | help="override some settings in the used config, the key-value pair " 69 | "in xxx=yyy format will be merged into config file. If the value to " 70 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 71 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 72 | "Note that the quotation marks are necessary and that no white space " 73 | "is allowed.", 74 | ) 75 | parser.add_argument( 76 | "--options", 77 | nargs="+", 78 | action=DictAction, 79 | help="custom options for evaluation, the key-value pair in xxx=yyy " 80 | "format will be kwargs for dataset.evaluate() function (deprecate), " 81 | "change to --eval-options instead.", 82 | ) 83 | parser.add_argument( 84 | "--eval-options", 85 | nargs="+", 86 | action=DictAction, 87 | help="custom options for evaluation, the key-value pair in xxx=yyy " 88 | "format will be kwargs for dataset.evaluate() function", 89 | ) 90 | parser.add_argument( 91 | "--launcher", 92 | choices=["none", "pytorch", "slurm", "mpi"], 93 | default="none", 94 | help="job launcher", 95 | ) 96 | parser.add_argument("--local_rank", type=int, default=0) 97 | args = parser.parse_args() 98 | if "LOCAL_RANK" not in os.environ: 99 | os.environ["LOCAL_RANK"] = str(args.local_rank) 100 | 101 | if args.options and args.eval_options: 102 | raise ValueError( 103 | "--options and --eval-options cannot be both " 104 | "specified, --options is deprecated in favor of --eval-options" 105 | ) 106 | if args.options: 107 | warnings.warn("--options is deprecated in favor of --eval-options") 108 | args.eval_options = args.options 109 | return args 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | 115 | assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( 116 | "Please specify at least one operation (save/eval/format/show the " 117 | 'results / save the results) with the argument "--out", "--eval"' 118 | ', "--format-only", "--show" or "--show-dir"' 119 | ) 120 | 121 | if args.eval and args.format_only: 122 | raise ValueError("--eval and --format_only cannot be both specified") 123 | 124 | if args.out is not None and not args.out.endswith((".pkl", ".pickle")): 125 | raise ValueError("The output file must be a pkl file.") 126 | 127 | cfg = Config.fromfile(args.config) 128 | if args.cfg_options is not None: 129 | cfg.merge_from_dict(args.cfg_options) 130 | cfg.build() 131 | # import modules from string list. 
132 | if cfg.get("custom_imports", None): 133 | from mmcv.utils import import_modules_from_strings 134 | 135 | import_modules_from_strings(**cfg["custom_imports"]) 136 | # set cudnn_benchmark 137 | if cfg.get("cudnn_benchmark", False): 138 | torch.backends.cudnn.benchmark = True 139 | if hasattr(cfg.model, "pretrained"): 140 | cfg.model.pretrained = None 141 | if cfg.model.get("neck"): 142 | if isinstance(cfg.model.neck, list): 143 | for neck_cfg in cfg.model.neck: 144 | if neck_cfg.get("rfp_backbone"): 145 | if neck_cfg.rfp_backbone.get("pretrained"): 146 | neck_cfg.rfp_backbone.pretrained = None 147 | elif cfg.model.neck.get("rfp_backbone"): 148 | if cfg.model.neck.rfp_backbone.get("pretrained"): 149 | cfg.model.neck.rfp_backbone.pretrained = None 150 | 151 | # in case the test dataset is concatenated 152 | samples_per_gpu = 1 153 | if isinstance(cfg.data.test, dict): 154 | cfg.data.test.test_mode = True 155 | samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1) 156 | if samples_per_gpu > 1: 157 | # Replace 'ImageToTensor' with 'DefaultFormatBundle' 158 | cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) 159 | elif isinstance(cfg.data.test, list): 160 | for ds_cfg in cfg.data.test: 161 | ds_cfg.test_mode = True 162 | samples_per_gpu = max( 163 | [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test] 164 | ) 165 | if samples_per_gpu > 1: 166 | for ds_cfg in cfg.data.test: 167 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) 168 | 169 | # init distributed env first, since logger depends on the dist info. 170 | if args.launcher == "none": 171 | distributed = False 172 | else: 173 | distributed = True 174 | init_dist(args.launcher, **cfg.dist_params) 175 | 176 | # build the dataloader 177 | dataset = build_dataset(cfg.data.test) 178 | data_loader = build_dataloader( 179 | dataset, 180 | samples_per_gpu=samples_per_gpu, 181 | workers_per_gpu=cfg.data.workers_per_gpu, 182 | dist=distributed, 183 | shuffle=False, 184 | ) 185 | 186 | # build the model and load checkpoint 187 | cfg.model.train_cfg = None 188 | rank, _ = get_dist_info() 189 | if rank == 0: 190 | print(cfg.pretty_text) 191 | model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg")) 192 | if cfg.get("asym_wrap", False): 193 | if isinstance(model, TwoStageDetector): 194 | model = AsymTwoStageDetector(model) 195 | elif isinstance(model, SingleStageDetector): 196 | model = AsymSingleStageDetector(model) 197 | else: 198 | raise NotImplementedError() 199 | fp16_cfg = cfg.get("fp16", None) 200 | if fp16_cfg is not None: 201 | wrap_fp16_model(model) 202 | checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") 203 | if args.fuse_conv_bn: 204 | model = fuse_conv_bn(model) 205 | # old versions did not save class info in checkpoints, this workaround is 206 | # for backward compatibility 207 | if "CLASSES" in checkpoint.get("meta", {}): 208 | model.CLASSES = checkpoint["meta"]["CLASSES"] 209 | else: 210 | model.CLASSES = dataset.CLASSES 211 | 212 | if not distributed: 213 | model = MMDataParallel(model, device_ids=[0]) 214 | outputs = single_gpu_test( 215 | model, data_loader, args.show, args.show_dir, args.show_score_thr 216 | ) 217 | else: 218 | model = MMDistributedDataParallel( 219 | model.cuda(), 220 | device_ids=[torch.cuda.current_device()], 221 | broadcast_buffers=False, 222 | ) 223 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect) 224 | 225 | rank, _ = get_dist_info() 226 | if rank == 0: 227 | if args.out: 228 | print(f"\nwriting results to 
{args.out}") 229 | mmcv.dump(outputs, args.out) 230 | kwargs = {} if args.eval_options is None else args.eval_options 231 | if args.format_only: 232 | dataset.format_results(outputs, **kwargs) 233 | if args.eval: 234 | eval_kwargs = cfg.get("evaluation", {}).copy() 235 | # hard-code way to remove EvalHook args 236 | for key in [ 237 | "interval", 238 | "tmpdir", 239 | "start", 240 | "gpu_collect", 241 | "save_best", 242 | "rule", 243 | ]: 244 | eval_kwargs.pop(key, None) 245 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 246 | print(dataset.evaluate(outputs, **eval_kwargs)) 247 | 248 | 249 | if __name__ == "__main__": 250 | main() 251 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | import warnings 7 | 8 | import mmcv 9 | import torch 10 | from mmcv import DictAction 11 | from mmcv.runner import get_dist_info, init_dist 12 | from mmcv.utils import get_git_hash 13 | 14 | from mmdet import __version__ 15 | from mmdet.models import build_detector 16 | from mmdet.utils import collect_env, get_root_logger 17 | from src.apis import set_random_seed, train_detector 18 | from src.datasets import build_dataset 19 | from src.utils import Config, collect_model_info, find_latest_checkpoint 20 | from mmdet.models import TwoStageDetector, SingleStageDetector 21 | 22 | old_repr = torch.Tensor.__repr__ 23 | 24 | 25 | # def tensor_info(tensor): 26 | # return ( 27 | # repr(tensor.shape)[6:] 28 | # + " " 29 | # + repr(tensor.dtype)[6:] 30 | # + "@" 31 | # + str(tensor.device) 32 | # + "\n" 33 | # ) 34 | 35 | 36 | # torch.Tensor.__repr__ = tensor_info 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser(description="Train a detector") 41 | parser.add_argument("config", help="train config file path") 42 | parser.add_argument("--work-dir", help="the dir to save logs and models") 43 | parser.add_argument("--resume-from", help="the checkpoint file to resume from") 44 | parser.add_argument( 45 | "--no-validate", 46 | action="store_true", 47 | help="whether not to evaluate the checkpoint during training", 48 | ) 49 | group_gpus = parser.add_mutually_exclusive_group() 50 | group_gpus.add_argument( 51 | "--gpus", 52 | type=int, 53 | help="number of gpus to use " "(only applicable to non-distributed training)", 54 | ) 55 | group_gpus.add_argument( 56 | "--gpu-ids", 57 | type=int, 58 | nargs="+", 59 | help="ids of gpus to use " "(only applicable to non-distributed training)", 60 | ) 61 | parser.add_argument("--seed", type=int, default=None, help="random seed") 62 | parser.add_argument( 63 | "--deterministic", 64 | action="store_true", 65 | help="whether to set deterministic options for CUDNN backend.", 66 | ) 67 | parser.add_argument( 68 | "--options", 69 | nargs="+", 70 | action=DictAction, 71 | help="override some settings in the used config, the key-value pair " 72 | "in xxx=yyy format will be merged into config file (deprecate), " 73 | "change to --cfg-options instead.", 74 | ) 75 | parser.add_argument( 76 | "--cfg-options", 77 | nargs="+", 78 | action=DictAction, 79 | help="override some settings in the used config, the key-value pair " 80 | "in xxx=yyy format will be merged into config file. If the value to " 81 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 82 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 83 | "Note that the quotation marks are necessary and that no white space " 84 | "is allowed.", 85 | ) 86 | parser.add_argument( 87 | "--launcher", 88 | choices=["none", "pytorch", "slurm", "mpi"], 89 | default="none", 90 | help="job launcher", 91 | ) 92 | parser.add_argument("--local_rank", type=int, default=0) 93 | args = parser.parse_args() 94 | if "LOCAL_RANK" not in os.environ: 95 | os.environ["LOCAL_RANK"] = str(args.local_rank) 96 | 97 | if args.options and args.cfg_options: 98 | raise ValueError( 99 | "--options and --cfg-options cannot be both " 100 | "specified, --options is deprecated in favor of --cfg-options" 101 | ) 102 | if args.options: 103 | warnings.warn("--options is deprecated in favor of --cfg-options") 104 | args.cfg_options = args.options 105 | 106 | return args 107 | 108 | 109 | def main(): 110 | args = parse_args() 111 | 112 | cfg = Config.fromfile(args.config) 113 | extra_param = None 114 | if args.cfg_options is not None: 115 | cfg.merge_from_dict(args.cfg_options) 116 | extra_param = ("_").join( 117 | ["{}_{}".format(k, v) for k, v in args.cfg_options.items()] 118 | ) 119 | cfg.exp_name = cfg.exp_name + extra_param 120 | # import modules from string list. 121 | if cfg.get("custom_imports", None): 122 | from mmcv.utils import import_modules_from_strings 123 | 124 | import_modules_from_strings(**cfg["custom_imports"]) 125 | # set cudnn_benchmark 126 | if cfg.get("cudnn_benchmark", False): 127 | torch.backends.cudnn.benchmark = True 128 | 129 | # work_dir is determined in this priority: CLI > segment in file > filename 130 | if args.work_dir is not None: 131 | # update configs according to CLI args if args.work_dir is not None 132 | cfg.work_dir = args.work_dir 133 | elif cfg.get("work_dir", None) is None: 134 | # use config filename as default work_dir if cfg.work_dir is None 135 | cfg.work_dir = osp.join( 136 | "./work_dirs", 137 | ".".join(os.sep.join((args.config.split(os.sep))[1:]).split(".")[:-1]), 138 | ) 139 | if extra_param is not None: 140 | cfg.work_dir = osp.join(cfg.work_dir, extra_param) 141 | if args.resume_from is not None: 142 | cfg.resume_from = args.resume_from 143 | elif (cfg.resume_from is None) and cfg.get("auto_resume", True): 144 | cfg.resume_from = find_latest_checkpoint(cfg.work_dir) 145 | if args.gpu_ids is not None: 146 | cfg.gpu_ids = args.gpu_ids 147 | else: 148 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 149 | cfg.build() 150 | 151 | # init distributed env first, since logger depends on the dist info. 
152 | if args.launcher == "none": 153 | distributed = False 154 | else: 155 | distributed = True 156 | init_dist(args.launcher, **cfg.dist_params) 157 | # re-set gpu_ids with distributed training mode 158 | _, world_size = get_dist_info() 159 | cfg.gpu_ids = range(world_size) 160 | 161 | # create work_dir 162 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 163 | # dump config 164 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 165 | # init the logger before other steps 166 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 167 | log_file = osp.join(cfg.work_dir, f"{timestamp}.log") 168 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 169 | 170 | # init the meta dict to record some important information such as 171 | # environment info and seed, which will be logged 172 | meta = dict() 173 | # log env info 174 | env_info_dict = collect_env() 175 | env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()]) 176 | dash_line = "-" * 60 + "\n" 177 | logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) 178 | meta["env_info"] = env_info 179 | meta["config"] = cfg.pretty_text 180 | # log some basic info 181 | logger.info(f"Distributed training: {distributed}") 182 | logger.info(f"Config:\n{cfg.pretty_text}") 183 | 184 | # set random seeds 185 | if args.seed is not None: 186 | logger.info( 187 | f"Set random seed to {args.seed}, " f"deterministic: {args.deterministic}" 188 | ) 189 | set_random_seed(args.seed, deterministic=args.deterministic) 190 | cfg.seed = args.seed 191 | meta["seed"] = args.seed 192 | meta["exp_name"] = osp.basename(args.config) 193 | 194 | model = build_detector( 195 | cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") 196 | ) 197 | logger.info(collect_model_info(model)) 198 | datasets = [build_dataset(cfg.data.train)] 199 | if len(cfg.workflow) == 2: 200 | val_dataset = copy.deepcopy(cfg.data.val) 201 | val_dataset.pipeline = cfg.data.train.pipeline 202 | datasets.append(build_dataset(val_dataset)) 203 | if cfg.checkpoint_config is not None: 204 | # save mmdet version, config file content and class names in 205 | # checkpoints as meta data 206 | cfg.checkpoint_config.meta = dict( 207 | mmdet_version=__version__ + get_git_hash()[:7], CLASSES=datasets[0].CLASSES 208 | ) 209 | # add an attribute for visualization convenience 210 | model.CLASSES = datasets[0].CLASSES 211 | train_detector( 212 | model, 213 | datasets, 214 | cfg, 215 | distributed=distributed, 216 | validate=(not args.no_validate), 217 | timestamp=timestamp, 218 | meta=meta, 219 | ) 220 | # at the end of training, we may want to upload some files to wandb 221 | 222 | 223 | if __name__ == "__main__": 224 | main() 225 | --------------------------------------------------------------------------------
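A closing note on `compute_kl_loss` (defined in the wrapper module near the top of this section): despite the name, it computes a class-weighted cross-entropy against soft targets, which differs from the true KL divergence only by the entropy of the targets, a constant with respect to the logits, so the gradients are identical. Below is a minimal standalone sketch of the same computation, with hypothetical toy shapes (80 foreground classes plus one background column):
```python
import torch
import torch.nn.functional as F

# Hypothetical toy inputs: 8 RoIs, 80 foreground classes + 1 background column.
logits = torch.randn(8, 81)                          # raw classification scores
target_prob = torch.softmax(torch.randn(8, 81), -1)  # soft targets, e.g. teacher probs

# Same computation as compute_kl_loss(logits, target_prob, bg_weight=0.5, fg_weight=1.0):
# scale the last (background) column by bg_weight, take the soft-target
# cross-entropy per RoI, then average over RoIs.
class_weight = torch.ones(1, 81)
class_weight[:, -1] *= 0.5  # bg_weight
loss = (-class_weight * target_prob * F.log_softmax(logits, dim=-1)).sum(-1).mean()
print(float(loss))
```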