├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── configs
│   ├── BaseRetina.yaml
│   ├── coco
│   │   ├── querydet_test.yaml
│   │   ├── querydet_train.yaml
│   │   ├── retinanet_test.yaml
│   │   └── retinanet_train.yaml
│   ├── custom_config.py
│   └── visdrone
│       ├── querydet_test.yaml
│       ├── querydet_train.yaml
│       ├── retinanet_test.yaml
│       └── retinanet_train.yaml
├── eval_visdrone.sh
├── infer_coco.py
├── infer_visdrone.py
├── models
│   ├── querydet
│   │   ├── __pycache__
│   │   │   ├── det_head.cpython-36.pyc
│   │   │   ├── det_head.cpython-37.pyc
│   │   │   ├── detector.cpython-36.pyc
│   │   │   ├── detector.cpython-37.pyc
│   │   │   ├── qinfer.cpython-36.pyc
│   │   │   └── qinfer.cpython-37.pyc
│   │   ├── det_head.py
│   │   ├── detector.py
│   │   └── qinfer.py
│   └── retinanet
│       ├── __pycache__
│       │   ├── retinanet.cpython-36.pyc
│       │   └── retinanet.cpython-37.pyc
│       └── retinanet.py
├── train_coco.py
├── train_tools
│   ├── coco_infer.py
│   ├── coco_train.py
│   ├── visdrone_infer.py
│   └── visdrone_train.py
├── train_visdrone.py
├── utils
│   ├── anchor_gen.py
│   ├── coco_eval_fpn.py
│   ├── gradient_checkpoint.py
│   ├── json_evaluator.py
│   ├── loop_matcher.py
│   ├── merged_sync_bn.py
│   ├── soft_nms.py
│   ├── time_evaluator.py
│   ├── utils.py
│   └── val_mapper_with_ann.py
├── visdrone
│   ├── data_prepare.py
│   ├── dataloader.py
│   ├── json_to_txt.py
│   ├── mapper.py
│   └── utils.py
└── visdrone_eval
    ├── LICENSE
    ├── README.md
    ├── evaluate.py
    ├── requirements.txt
    ├── setup.py
    └── viseval
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-37.pyc
        │   ├── bbox_overlaps.cpython-37.pyc
        │   ├── calc_accuracy.cpython-37.pyc
        │   ├── drop_objects_in_igr.cpython-37.pyc
        │   └── eval_det.cpython-37.pyc
        ├── bbox_overlaps.py
        ├── calc_accuracy.py
        ├── drop_objects_in_igr.py
        └── eval_det.py
/.gitignore:
--------------------------------------------------------------------------------
1 | work_dirs
2 | work_dirs/*
3 | data
4 | data/*
5 | */__pycache__
6 | */__pycache__/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Chenhongyi Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QueryDet-PyTorch
2 |
3 |
4 |
5 |
6 |
7 | This repository is the official implementation of our paper: [QueryDet: Cascaded Sparse Query for Accelerating High-Resolution Small Object Detection, *Chenhongyi Yang*, *Zehao Huang*, *Naiyan Wang*. CVPR 2022](https://arxiv.org/abs/2103.09136)
8 |
9 |
10 |
11 | ## IMPORTANT UPDATE !!!
12 |
13 | We have updated the QueryDet repository to make it easier to use. Specifically:
14 |
15 | - QueryDet now supports newer versions of PyTorch and Detectron2.
16 | - You do not need APEX anymore. FP16 training is currently achieved through PyTorch AMP.
17 | - QueryDet now supports Spconv 2.1, which can be directly installed using pip.
18 | - We have improved the support for the VisDrone dataset.
19 | - We have re-organized the model configs to make them easier to use.
20 |
21 |
22 |
23 | ## Setup
24 |
25 | ### Environment setup
26 |
27 | We tested the new QueryDet with CUDA 10.2 on NVIDIA 2080Ti GPUs. A sample setup script is provided below:
28 |
29 | ```shell
30 | conda create -n querydet python=3.7 -y
31 | source activate querydet
32 | pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
33 | python -m pip install detectron2==0.4 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.7/index.html
34 | pip install spconv-cu102==2.1.25
35 |
36 | # Clone our repository and have fun with it!
37 | git clone https://github.com/ChenhongyiYang/QueryDet-PyTorch.git
38 |
39 | # OPTIONAL: Install the python evaluation tool for VisDrone
40 | # Reference: https://github.com/tjiiv-cprg/visdrone-det-toolkit-python
41 | cd visdrone_eval
42 | pip install -e .
43 |
44 | # OPTIONAL: Install detectron2_backbone if you want to use backbone networks like MobileNet
45 | # Reference: https://github.com/sxhxliang/detectron2_backbone
46 | git clone https://github.com/sxhxliang/detectron2_backbone.git
47 | cd detectron2_backbone
48 | python setup.py build develop
49 | ```
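
As an optional sanity check, the following one-liner should run without errors once the packages above are installed:

```shell
python -c "import torch, detectron2, spconv; print(torch.__version__, torch.cuda.is_available())"
```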
50 |
51 | ### COCO setup
52 |
53 | You need to set up COCO following the [official tutorial](https://detectron2.readthedocs.io/en/latest/tutorials/builtin_datasets.html) of Detectron2.
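
For reference, Detectron2 expects the COCO files to be organized roughly as follows under its datasets root (`./datasets` by default, or the directory set via the `DETECTRON2_DATASETS` environment variable); see the tutorial above for the authoritative description:

```
datasets
|-- coco
    |-- annotations
    |   |-- instances_train2017.json
    |   |-- instances_val2017.json
    |-- train2017
    |   |-- ...jpg
    |-- val2017
        |-- ...jpg
```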
54 |
55 | ### VisDrone setup
56 |
57 | We provide full support for the VisDrone dataset.
58 |
59 | - You need to download the VisDrone dataset from its [official website](http://aiskyeye.com/).
60 | - Unzip and place the downloaded dataset as follows:
61 |
62 | ```
63 | QueryDet-PyTorch
64 | |-- data
65 | |-- visdrone
66 | |-- VisDrone2019-DET-train
67 | | |-- images
68 | | | |-- ...jpg # 6471 .jpg files
69 | | |-- annotations
70 | | |-- ...txt # 6471 .txt files
71 | |-- VisDrone2019-DET-val
72 | |-- images
73 | | |-- ...jpg # 548 .jpg files
74 | |-- annotations
75 | |-- ...txt # 548 .txt files
76 | ```
77 |
78 | - Pre-process the dataset by running: `python visdrone/data_prepare.py --visdrone-root data/visdrone`.
79 | - The resulting file structure will be as follows:
80 |
81 | ```
82 | QueryDet-PyTorch
83 | |-- data
84 | |-- visdrone
85 | |-- VisDrone2019-DET-train
86 | | |-- images
87 | | | |-- ...jpg # 6471 .jpg files
88 | | |-- annotations
89 | | |-- ...txt # 6471 .txt files
90 | |-- VisDrone2019-DET-val
91 | | |-- images
92 | | | |-- ...jpg # 548 .jpg files
93 | | |-- annotations
94 | | |-- ...txt # 548 .txt files
95 | |-- coco_format
96 | |-- train_images
97 | | |-- ...jpg # 25884 .jpg files
98 | |-- val_images
99 | | |-- ...jpg # 548 .jpg files
100 | |-- annotations
101 | |-- train_label.json
102 | |-- val_label.json
103 | ```
104 |
105 | - After model training, you can evaluate your model by running `bash eval_visdrone.sh /path/to/visdrone_infer.json`.
106 |
107 |
108 |
109 | ## Usage
110 |
111 | Before training, we recommend creating a `work_dirs` directory under `QueryDet-PyTorch` to store all training results, as follows:
112 |
113 | ```
114 | QueryDet-PyTorch
115 | |-- work_dirs
116 | |-- ... # other stuffs
117 | ```
118 |
119 | If you want to store your training results in another place, you can instead run `ln -s /path/to/your/storage work_dirs` to create a symbolic link.
120 |
121 | In the following, we will assume you have created such a directory and introduce the training, testing, and evaluation commands.
122 |
123 | ### Training
124 |
125 | ```shell
126 | # train COCO RetinaNet baseline
127 | python train_coco.py --config-file configs/coco/retinanet_train.yaml --num-gpu 8 OUTPUT_DIR work_dirs/coco_retinanet
128 |
129 | # train COCO QueryDet
130 | python train_coco.py --config-file configs/coco/querydet_train.yaml --num-gpu 8 OUTPUT_DIR work_dirs/coco_querydet
131 |
132 | # train VisDrone RetinaNet baseline
133 | python train_visdrone.py --config-file configs/visdrone/retinanet_train.yaml --num-gpu 8 OUTPUT_DIR work_dirs/visdrone_retinanet
134 |
135 | # train VisDrone QueryDet
136 | python train_visdrone.py --config-file configs/visdrone/querydet_train.yaml --num-gpu 8 OUTPUT_DIR work_dirs/visdrone_querydet
137 | ```
138 |
139 | ### Testing
140 |
141 | ```shell
142 | # test COCO RetinaNet baseline
143 | python infer_coco.py --config-file configs/coco/retinanet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/coco_retinanet/model_final.pth OUTPUT_DIR work_dirs/model_test
144 |
145 | # test COCO QueryDet with Dense Inference
146 | python infer_coco.py --config-file configs/coco/querydet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/coco_querydet/model_final.pth OUTPUT_DIR work_dirs/model_test
147 |
148 | # test COCO QueryDet with CSQ
149 | export SPCONV_FILTER_HWIO="1"; python infer_coco.py --config-file configs/coco/querydet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/coco_querydet/model_final.pth OUTPUT_DIR work_dirs/model_test MODEL.QUERY.QUERY_INFER True
150 |
151 | # test VisDrone RetinaNet baseline
152 | python infer_visdrone.py --config-file configs/visdrone/retinanet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/visdrone_retinanet/model_final.pth OUTPUT_DIR work_dirs/model_test
153 |
154 | # test VisDrone QueryDet with Dense Inference
155 | python infer_visdrone.py --config-file configs/visdrone/querydet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/visdrone_querydet/model_final.pth OUTPUT_DIR work_dirs/model_test
156 |
157 | # test VisDrone QueryDet with CSQ
158 | export SPCONV_FILTER_HWIO="1"; python infer_visdrone.py --config-file configs/visdrone/querydet_test.yaml --num-gpu 8 --eval-only MODEL.WEIGHTS work_dirs/visdrone_querydet/model_final.pth OUTPUT_DIR work_dirs/model_test MODEL.QUERY.QUERY_INFER True
159 | ```
160 |
161 | ### Evaluation
162 |
163 | - For COCO, Detectron2 automatically evaluates the result when you run the inference command, so no extra step is needed.
164 | - For VisDrone, after running an inference command, you will get a result file named `visdrone_infer.json` in your output directory (e.g., `work_dirs/model_test` in the above commands). You then have two options to evaluate the result:
165 | - If you have installed the Python evaluation tool, then you can evaluate your result by running `bash eval_visdrone.sh work_dirs/model_test/visdrone_infer.json`
166 | - If you want to use the official Matlab evaluation tool, you can run `python visdrone/json_to_txt.py --out /path/to/result --gt-json data/visdrone/coco_format/annotations/val_label.json --det-json work_dirs/model_test/visdrone_infer.json` to convert the result to .txt files for Matlab evaluation.
167 |
168 |
169 |
170 | ## Citation
171 | ```
172 | @InProceedings{Yang_2022_CVPR_QueryDet,
173 | author = {{Yang, Chenhongyi and Huang, Zehao and Wang, Naiyan}},
174 | title = {{QueryDet: Cascaded Sparse Query for Accelerating High-Resolution Small Object Detection}},
175 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
176 | year = {2022}
177 | }
178 | ```
179 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/assets/teaser.png
--------------------------------------------------------------------------------
/configs/BaseRetina.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | BACKBONE:
3 | NAME: "build_retinanet_resnet_fpn_backbone"
4 | RESNETS:
5 | OUT_FEATURES: ["res3", "res4", "res5"]
6 | ANCHOR_GENERATOR:
7 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
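# i.e. three anchor scales per octave at each level: x, x * 2^(1/3), x * 2^(2/3) for x in {32, 64, 128, 256, 512}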
8 | FPN:
9 | IN_FEATURES: ["res3", "res4", "res5"]
10 | RETINANET:
11 | IOU_THRESHOLDS: [0.4, 0.5]
12 | IOU_LABELS: [0, -1, 1]
13 | DATASETS:
14 | TRAIN: ("coco_2017_train",)
15 | TEST: ("coco_2017_val",)
16 | SOLVER:
17 | IMS_PER_BATCH: 16
18 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate
19 | STEPS: (60000, 80000)
20 | MAX_ITER: 90000
21 | INPUT:
22 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
23 | VERSION: 2
24 |
--------------------------------------------------------------------------------
/configs/coco/querydet_test.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/model_test"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNetQueryDet"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 |
8 | RESNETS:
9 | DEPTH: 50
10 |
11 | ANCHOR_GENERATOR:
12 | NAME: "AnchorGeneratorWithCenter"
13 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [16, 32, 64, 128, 256, 512]]"]
14 |
15 | RETINANET:
16 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6", "p7"]
17 | SCORE_THRESH_TEST: 0.0001
18 |
19 | RESNETS:
20 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
21 |
22 | FPN:
23 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
24 |
25 | QUERY:
26 | FEATURES_WHOLE_TEST: [2, 3, 4, 5]
27 | FEATURES_VALUE_TEST: [0, 1]
28 | Q_FEATURE_TRAIN: [1, 2]
29 | Q_FEATURE_TEST: [1, 2]
30 | THRESHOLD: 0.12
31 | QUERY_INFER: False
32 |
33 | ENCODE_CENTER_DIS_COEFF: [1., 1.]
34 | ENCODE_SMALL_OBJ_SCALE: [[0, 32], [0, 64]]
35 |
36 | CUSTOM:
37 | USE_SOFT_NMS: False
38 | SOFT_NMS_METHOD: 'gaussian'
39 | SOFT_NMS_SIGMA: 0.7
40 | SOFT_NMS_THRESHOLD: 0.4
41 | SOFT_NMS_PRUND: 0.0001
42 |
43 | TEST:
44 | DETECTIONS_PER_IMAGE: 200
45 |
46 | META_INFO:
47 | EVAL_SMALL_CLS: False
48 | EVAL_GPU_TIME: True
49 |
50 | # DATASETS:
51 | # TEST: ("coco_2017_test-dev",)
52 |
--------------------------------------------------------------------------------
/configs/coco/querydet_train.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/coco_querydet"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNetQueryDet"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 |
8 | ANCHOR_GENERATOR:
9 | NAME: "AnchorGeneratorWithCenter"
10 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [16, 32, 64, 128, 256, 512]]"]
11 |
12 | RETINANET:
13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6", "p7"]
14 |
15 | RESNETS:
16 | DEPTH: 50
17 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
18 |
19 | FPN:
20 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
21 |
22 | QUERY:
23 | Q_FEATURE_TRAIN: [1, 2]
24 | FEATURES_WHOLE_TEST: [2, 3, 4, 5]
25 | FEATURES_VALUE_TEST: [0, 1]
26 | Q_FEATURE_TEST: [1, 2]
27 |
28 | QUERY_LOSS_WEIGHT: [10., 10.]
29 | QUERY_LOSS_GAMMA: [1.2, 1.2]
30 |
31 | ENCODE_CENTER_DIS_COEFF: [1., 1.]
32 | ENCODE_SMALL_OBJ_SCALE: [[0, 32], [0, 64]]
33 |
34 | QUERY_INFER: False
35 |
36 | CUSTOM:
37 | CLEAR_CUDA_CACHE: True
38 | USE_LOOP_MATCHER: True
39 | FOCAL_LOSS_ALPHAS: [0.25, 0.25, 0.25, 0.25, 0.25, 0.25]
40 | FOCAL_LOSS_GAMMAS: [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
41 | CLS_WEIGHTS: [1.0, 1.4, 2.1, 2.5, 2.9, 3.2]
42 | REG_WEIGHTS: [1.0, 1.4, 2.1, 2.5, 2.9, 3.2]
43 |
44 | SOLVER:
45 | # 3x
46 | # STEPS: (210000, 250000)
47 | # MAX_ITER: 270000
48 |
49 | # 1x
50 | BASE_LR: 0.01
51 | STEPS: (60000, 80000)
52 | MAX_ITER: 90000
53 | IMS_PER_BATCH: 16
54 | AMP:
55 | ENABLED: True
56 |
57 |
58 | TEST:
59 | EVAL_PERIOD: 0
60 | DETECTIONS_PER_IMAGE: 200
61 |
62 | META_INFO:
63 | EVAL_GPU_TIME: False
64 | EVAL_AP: True
65 |
66 | VIS_PERIOD: 0
--------------------------------------------------------------------------------
/configs/coco/retinanet_test.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/model_test"
3 | MODEL:
4 | META_ARCHITECTURE: "RetinaNet_D2"
5 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
6 | RESNETS:
7 | DEPTH: 50
8 |
9 | ANCHOR_GENERATOR:
10 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [32, 64, 128, 256, 512]]"]
11 |
12 | RETINANET:
13 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
14 |
15 | SOLVER:
16 | # 3x
17 | # STEPS: (210000, 250000)
18 | # MAX_ITER: 270000
19 |
20 | # 1x
21 | STEPS: (60000, 80000)
22 | MAX_ITER: 90000
23 | CLIP_GRADIENTS:
24 | ENABLED: False
25 |
26 | META_INFO:
27 | EVAL_GPU_TIME: True
28 |
29 |
30 | TEST:
31 | EVAL_PERIOD: 5000
32 | DETECTIONS_PER_IMAGE: 200
33 |
--------------------------------------------------------------------------------
/configs/coco/retinanet_train.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/coco_retinanet"
3 | MODEL:
4 | META_ARCHITECTURE: "RetinaNet_D2"
5 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
6 | RESNETS:
7 | DEPTH: 50
8 |
9 | ANCHOR_GENERATOR:
10 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [32, 64, 128, 256, 512]]"]
11 |
12 | RETINANET:
13 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
14 |
15 | CUSTOM:
16 | CLS_WEIGHTS: [1., 1., 1., 1., 1.]
17 | REG_WEIGHTS: [1., 1., 1., 1., 1.]
18 | FOCAL_LOSS_ALPHAS: [0.25, 0.25, 0.25, 0.25, 0.25, 0.25]
19 | FOCAL_LOSS_GAMMAS: [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
20 |
21 | SOLVER:
22 | # 3x
23 | # STEPS: (210000, 250000)
24 | # MAX_ITER: 270000
25 |
26 | # 1x
27 | STEPS: (60000, 80000)
28 | MAX_ITER: 90000
29 | IMS_PER_BATCH: 16
30 | BASE_LR: 0.02
31 | AMP:
32 | ENABLED: True
33 |
34 | TEST:
35 | EVAL_PERIOD: 0
36 | DETECTIONS_PER_IMAGE: 200
37 |
38 | VIS_PERIOD: 0
--------------------------------------------------------------------------------
/configs/custom_config.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode as CN
2 |
3 | INF = 1e8
4 |
5 | def add_custom_config(cfg):
6 | cfg.MODEL.FPN.TOP_LEVELS = 2
7 |
8 | #----------------------------------------------------------------------------------------------
9 | # CUSTOM
10 | #----------------------------------------------------------------------------------------------
11 | cfg.MODEL.CUSTOM = CN()
12 |
13 | cfg.MODEL.CUSTOM.FOCAL_LOSS_GAMMAS = []
14 | cfg.MODEL.CUSTOM.FOCAL_LOSS_ALPHAS = []
15 |
16 | cfg.MODEL.CUSTOM.CLS_WEIGHTS = []
17 | cfg.MODEL.CUSTOM.REG_WEIGHTS = []
18 |
19 | cfg.MODEL.CUSTOM.USE_LOOP_MATCHER = False
20 | cfg.MODEL.CUSTOM.GRADIENT_CHECKPOINT = False
21 | cfg.MODEL.CUSTOM.CLEAR_CUDA_CACHE = False
22 |
23 | # soft nms
24 | cfg.MODEL.CUSTOM.USE_SOFT_NMS = False
25 | cfg.MODEL.CUSTOM.GIOU_LOSS = False
26 | cfg.MODEL.CUSTOM.SOFT_NMS_METHOD = 'linear' # gaussian
27 | cfg.MODEL.CUSTOM.SOFT_NMS_SIGMA = 0.5
28 | cfg.MODEL.CUSTOM.SOFT_NMS_THRESHOLD = 0.5
29 | cfg.MODEL.CUSTOM.SOFT_NMS_PRUND = 0.001
30 |
31 | cfg.MODEL.CUSTOM.HEAD_BN = False
32 |
33 | #----------------------------------------------------------------------------------------------
34 | # QUERY
35 | #----------------------------------------------------------------------------------------------
36 | cfg.MODEL.QUERY = CN()
37 |
38 | cfg.MODEL.QUERY.FEATURES_WHOLE_TRAIN = [2, 3, 4, 5]
39 | cfg.MODEL.QUERY.FEATURES_VALUE_TRAIN = [0, 1]
40 | cfg.MODEL.QUERY.Q_FEATURE_TRAIN = [2]
41 |
42 | cfg.MODEL.QUERY.FEATURES_WHOLE_TEST = [2, 3, 4, 5]
43 | cfg.MODEL.QUERY.FEATURES_VALUE_TEST = [0, 1]
44 | cfg.MODEL.QUERY.Q_FEATURE_TEST = [2]
45 |
46 | cfg.MODEL.QUERY.QUERY_LOSS_WEIGHT = []
47 | cfg.MODEL.QUERY.QUERY_LOSS_GAMMA = []
48 |
49 | cfg.MODEL.QUERY.ENCODE_CENTER_DIS_COEFF = [1.]
50 | cfg.MODEL.QUERY.ENCODE_SMALL_OBJ_SCALE = []
51 |
52 | cfg.MODEL.QUERY.THRESHOLD = 0.12
53 | cfg.MODEL.QUERY.CONTEXT = 2
54 |
55 | cfg.MODEL.QUERY.QUERY_INFER = False
56 |
57 |
58 | #----------------------------------------------------------------------------------------------
59 | # Meta Info
60 | #----------------------------------------------------------------------------------------------
61 | cfg.META_INFO = CN()
62 |
63 | cfg.META_INFO.VIS_ROOT = ''
64 | cfg.META_INFO.EVAL_GPU_TIME = False
65 | cfg.META_INFO.EVAL_AP = True
66 |
67 | #----------------------------------------------------------------------------------------------
68 | # VisDrone2018
69 | #----------------------------------------------------------------------------------------------
70 | cfg.VISDRONE = CN()
71 |
72 | cfg.VISDRONE.TRAIN_JSON = 'data/visdrone/coco_format/annotations/train_label.json'
73 | cfg.VISDRONE.TRING_IMG_ROOT = 'data/visdrone/coco_format/train_images'
74 |
75 | cfg.VISDRONE.TEST_JSON = 'data/visdrone/coco_format/annotations/val_label.json'
76 | cfg.VISDRONE.TEST_IMG_ROOT = 'data/visdrone/coco_format/val_images'
77 |
78 | cfg.VISDRONE.SHORT_LENGTH = [1200]
79 | cfg.VISDRONE.MAX_LENGTH = 1999
80 |
81 | cfg.VISDRONE.TEST_LENGTH = 3999
82 |
83 |
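# A minimal usage sketch (assumed, following the standard Detectron2 config pattern):
# register these keys on a base config before merging one of the YAML files in configs/, e.g.
#   from detectron2.config import get_cfg
#   cfg = get_cfg()
#   add_custom_config(cfg)
#   cfg.merge_from_file('configs/visdrone/querydet_train.yaml')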
--------------------------------------------------------------------------------
/configs/visdrone/querydet_test.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/model_test"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNetQueryDet"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 |
8 | RESNETS:
9 | DEPTH: 50
10 |
11 | ANCHOR_GENERATOR:
12 | NAME: "AnchorGeneratorWithCenter"
13 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [16, 32, 64, 128, 256, 512]]"]
14 |
15 | RETINANET:
16 | IOU_THRESHOLDS: [0.4, 0.5]
17 | IOU_LABELS: [0, -1, 1]
18 | NUM_CLASSES: 10
19 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6", "p7"]
20 | SCORE_THRESH_TEST: 0.0001
21 |
22 | RESNETS:
23 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
24 |
25 | FPN:
26 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
27 |
28 | QUERY:
29 | FEATURES_WHOLE_TEST: [2, 3, 4, 5]
30 | FEATURES_VALUE_TEST: [0, 1]
31 | Q_FEATURE_TRAIN: [1, 2]
32 | Q_FEATURE_TEST: [1, 2]
33 |
34 | ENCODE_CENTER_DIS_COEFF: [1., 1.]
35 | ENCODE_SMALL_OBJ_SCALE: [[0, 32], [0, 64]]
36 |
37 | THRESHOLD: 0.12
38 | QUERY_INFER: False
39 |
40 | CUSTOM:
41 | USE_SOFT_NMS: False
42 | SOFT_NMS_METHOD: 'gaussian'
43 | SOFT_NMS_SIGMA: 0.6
44 | SOFT_NMS_THRESHOLD: 0.4
45 | SOFT_NMS_PRUND: 0.0001
46 |
47 | VISDRONE:
48 | TEST_LENGTH: 3999
49 |
50 | TEST:
51 | DETECTIONS_PER_IMAGE: 500
52 |
53 | META_INFO:
54 | EVAL_GPU_TIME: True
55 |
--------------------------------------------------------------------------------
/configs/visdrone/querydet_train.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/visdrone_querydet"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNetQueryDet"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 |
8 | RESNETS:
9 | DEPTH: 50
10 |
11 | ANCHOR_GENERATOR:
12 | NAME: "AnchorGeneratorWithCenter"
13 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [16, 32, 64, 128, 256, 512]]"]
14 |
15 | RETINANET:
16 | IOU_THRESHOLDS: [0.4, 0.5]
17 | IOU_LABELS: [0, -1, 1]
18 | NUM_CLASSES: 10
19 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6", "p7"]
20 |
21 | RESNETS:
22 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
23 |
24 | FPN:
25 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
26 |
27 | QUERY:
28 | Q_FEATURE_TRAIN: [1, 2]
29 | FEATURES_WHOLE_TEST: [2, 3, 4, 5]
30 | FEATURES_VALUE_TEST: [0, 1]
31 | Q_FEATURE_TEST: [1, 2]
32 |
33 | QUERY_LOSS_WEIGHT: [10., 10.]
34 | QUERY_LOSS_GAMMA: [1.3, 1.3]
35 |
36 | ENCODE_CENTER_DIS_COEFF: [1., 1.]
37 | ENCODE_SMALL_OBJ_SCALE: [[0, 32], [0, 64]]
38 |
39 | QUERY_INFER: False
40 |
41 | CUSTOM:
42 | GRADIENT_CHECKPOINT: False
43 | USE_LOOP_MATCHER: True
44 | FOCAL_LOSS_ALPHAS: [0.25, 0.25, 0.25, 0.25, 0.25, 0.25]
45 | FOCAL_LOSS_GAMMAS: [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
46 | CLS_WEIGHTS: [1.0, 1.4, 1.8, 2.2, 2.6, 2.6]
47 | REG_WEIGHTS: [1.0, 1.4, 1.8, 2.2, 2.6, 2.6]
48 |
49 |
50 | SOLVER:
51 | BASE_LR: 0.01
52 | STEPS: (30000, 40000)
53 | MAX_ITER: 50000
54 | IMS_PER_BATCH: 8
55 | AMP:
56 | ENABLED: True
57 | CLIP_GRADIENTS:
58 | ENABLED: True
59 | CLIP_TYPE: value
60 | CLIP_VALUE: 35.0
61 | NORM_TYPE: 2.0
62 |
63 | VISDRONE:
64 | SHORT_LENGTH: [1200]
65 | MAX_LENGTH: 1999
66 |
67 | TEST:
68 | EVAL_PERIOD: 0
69 | DETECTIONS_PER_IMAGE: 500
70 |
71 | META_INFO:
72 | EVAL_GPU_TIME: True
73 |
74 | VIS_PERIOD: 0
--------------------------------------------------------------------------------
/configs/visdrone/retinanet_test.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/model_test"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNet_D2"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 | RESNETS:
8 | DEPTH: 50
9 |
10 | ANCHOR_GENERATOR:
11 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [32, 64, 128, 256, 512]]"]
12 |
13 | RESNETS:
14 | OUT_FEATURES: ["res3", "res4", "res5"]
15 |
16 | FPN:
17 | IN_FEATURES: ["res3", "res4", "res5"]
18 |
19 | RETINANET:
20 | IOU_THRESHOLDS: [0.4, 0.5]
21 | IOU_LABELS: [0, -1, 1]
22 | NUM_CLASSES: 10
23 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
24 | SCORE_THRESH_TEST: 0.005
25 |
26 | META_INFO:
27 | EVAL_GPU_TIME: True
28 |
29 |
30 | TEST:
31 | DETECTIONS_PER_IMAGE: 500
32 |
--------------------------------------------------------------------------------
/configs/visdrone/retinanet_train.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../BaseRetina.yaml"
2 | OUTPUT_DIR: "work_dirs/visdrone_retinanet"
3 |
4 | MODEL:
5 | META_ARCHITECTURE: "RetinaNet_D2"
6 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
7 | RESNETS:
8 | DEPTH: 50
9 |
10 | ANCHOR_GENERATOR:
11 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3)] for x in [32, 64, 128, 256, 512]]"]
12 |
13 | RETINANET:
14 | IOU_THRESHOLDS: [0.4, 0.5]
15 | IOU_LABELS: [0, -1, 1]
16 | NUM_CLASSES: 10
17 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
18 |
19 | CUSTOM:
20 | FOCAL_LOSS_ALPHAS: [0.25, 0.25, 0.25, 0.25, 0.25]
21 | FOCAL_LOSS_GAMMAS: [2.0, 2.0, 2.0, 2.0, 2.0]
22 | CLS_WEIGHTS: [1., 1., 1., 1., 1.]
23 | REG_WEIGHTS: [1., 1., 1., 1., 1.]
24 |
25 | SOLVER:
26 | BASE_LR: 0.01
27 | STEPS: (30000, 40000)
28 | MAX_ITER: 50000
29 | IMS_PER_BATCH: 8
30 | AMP:
31 | ENABLED: True
32 |
33 | TEST:
34 | EVAL_PERIOD: 0
35 | DETECTIONS_PER_IMAGE: 500
36 |
37 | VIS_PERIOD: 0
--------------------------------------------------------------------------------
/eval_visdrone.sh:
--------------------------------------------------------------------------------
1 | DetJSON=$1
2 |
3 | python visdrone/json_to_txt.py --out .visdrone_det_txt --gt-json data/visdrone/coco_format/annotations/val_label.json --det-json "$DetJSON"
4 | python visdrone_eval/evaluate.py --dataset-dir data/visdrone/VisDrone2019-DET-val --res-dir .visdrone_det_txt
5 | rm -rf .visdrone_det_txt
--------------------------------------------------------------------------------
/infer_coco.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from detectron2.engine import launch
4 | from train_tools.coco_infer import default_argument_parser, start_train
5 |
6 | from models.retinanet.retinanet import RetinaNet_D2
7 | from models.querydet.detector import RetinaNetQueryDet
8 |
9 |
10 |
11 | if __name__ == '__main__':
12 | args = default_argument_parser().parse_args()
13 | print("Command Line Args:", args)
14 | launch(
15 | start_train,
16 | args.num_gpus,
17 | num_machines=args.num_machines,
18 | machine_rank=args.machine_rank,
19 | dist_url=args.dist_url,
20 | args=(args,),
21 | )
22 |
--------------------------------------------------------------------------------
/infer_visdrone.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from detectron2.engine import launch
4 | from train_tools.visdrone_infer import default_argument_parser, start_train
5 |
6 | import logging
7 |
8 | from models.retinanet.retinanet import RetinaNet_D2
9 | from models.querydet.detector import RetinaNetQueryDet
10 |
11 |
12 |
13 | if __name__ == '__main__':
14 | args = default_argument_parser().parse_args()
15 | print("Command Line Args:", args)
16 | launch(
17 | start_train,
18 | args.num_gpus,
19 | num_machines=args.num_machines,
20 | machine_rank=args.machine_rank,
21 | dist_url=args.dist_url,
22 | args=(args,),
23 | )
--------------------------------------------------------------------------------
/models/querydet/__pycache__/det_head.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/det_head.cpython-36.pyc
--------------------------------------------------------------------------------
/models/querydet/__pycache__/det_head.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/det_head.cpython-37.pyc
--------------------------------------------------------------------------------
/models/querydet/__pycache__/detector.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/detector.cpython-36.pyc
--------------------------------------------------------------------------------
/models/querydet/__pycache__/detector.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/detector.cpython-37.pyc
--------------------------------------------------------------------------------
/models/querydet/__pycache__/qinfer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/qinfer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/querydet/__pycache__/qinfer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/querydet/__pycache__/qinfer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/querydet/det_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import logging
3 | import math
4 | import numpy as np
5 | from typing import List
6 | import torch
7 | from fvcore.nn import sigmoid_focal_loss_jit, smooth_l1_loss
8 | from torch import nn
9 | import torch.nn.functional as F
10 |
11 | from detectron2.layers import ShapeSpec, batched_nms, cat, Conv2d, get_norm
12 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
13 | from detectron2.utils.events import get_event_storage
14 | from detectron2.utils.logger import log_first_n
15 |
16 | from detectron2.modeling.anchor_generator import build_anchor_generator
17 | from detectron2.modeling.backbone import build_backbone
18 | from detectron2.modeling.box_regression import Box2BoxTransform
19 | from detectron2.modeling.matcher import Matcher
20 | from detectron2.modeling.postprocessing import detector_postprocess
21 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
22 |
23 | from detectron2.modeling.roi_heads.roi_heads import ROIHeads
24 | from detectron2.modeling.poolers import ROIPooler
25 |
26 |
27 | class RetinaNetHead_3x3(nn.Module):
28 | def __init__(self, cfg, in_channels, conv_channels, num_convs, num_anchors):
29 | super().__init__()
30 | # fmt: off
31 | num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
32 | prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
33 | self.num_convs = num_convs
34 | # fmt: on
35 |
36 | self.cls_subnet = []
37 | self.bbox_subnet = []
38 | channels = in_channels
39 | for i in range(self.num_convs):
40 | cls_layer = nn.Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1)
41 | bbox_layer = nn.Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1)
42 |
43 | torch.nn.init.normal_(cls_layer.weight, mean=0, std=0.01)
44 | torch.nn.init.normal_(bbox_layer.weight, mean=0, std=0.01)
45 |
46 | torch.nn.init.constant_(cls_layer.bias, 0)
47 | torch.nn.init.constant_(bbox_layer.bias, 0)
48 |
49 | self.add_module('cls_layer_{}'.format(i), cls_layer)
50 | self.add_module('bbox_layer_{}'.format(i), bbox_layer)
51 |
52 | self.cls_subnet.append(cls_layer)
53 | self.bbox_subnet.append(bbox_layer)
54 |
55 | channels = conv_channels
56 |
57 | self.cls_score = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
58 | self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
59 |
60 | torch.nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)
61 | torch.nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.01)
62 |
63 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
64 | torch.nn.init.constant_(self.cls_score.bias, bias_value)
65 |
66 | def forward(self, features):
67 | logits = []
68 | bbox_reg = []
69 |
70 | for feature in features:
71 | cls_f = feature
72 | bbox_f = feature
73 | for i in range(self.num_convs):
74 | cls_f = F.relu(self.cls_subnet[i](cls_f))
75 | bbox_f = F.relu(self.bbox_subnet[i](bbox_f))
76 |
77 | logits.append(self.cls_score(cls_f))
78 | bbox_reg.append(self.bbox_pred(bbox_f))
79 |
80 | return logits, bbox_reg
81 |
82 | def get_params(self):
83 | cls_weights = [x.weight for x in self.cls_subnet] + [self.cls_score.weight.data]
84 | cls_biases = [x.bias for x in self.cls_subnet] + [self.cls_score.bias.data]
85 |
86 | bbox_weights = [x.weight for x in self.bbox_subnet] + [self.bbox_pred.weight.data]
87 | bbox_biases = [x.bias for x in self.bbox_subnet] + [self.bbox_pred.bias.data]
88 | return cls_weights, cls_biases, bbox_weights, bbox_biases
89 |
90 |
91 | class Head_3x3(nn.Module):
92 | def __init__(self, in_channels, conv_channels, num_convs, pred_channels, pred_prior=None):
93 | super().__init__()
94 | self.num_convs = num_convs
95 |
96 | self.subnet = []
97 | channels = in_channels
98 | for i in range(self.num_convs):
99 | layer = nn.Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1)
100 | torch.nn.init.xavier_normal_(layer.weight)
101 | torch.nn.init.constant_(layer.bias, 0)
102 | self.add_module('layer_{}'.format(i), layer)
103 | self.subnet.append(layer)
104 | channels = conv_channels
105 |
106 | self.pred_net = nn.Conv2d(channels, pred_channels, kernel_size=3, stride=1, padding=1)
107 |
108 | torch.nn.init.xavier_normal_(self.pred_net.weight)
109 | if pred_prior is not None:
110 | bias_value = -(math.log((1 - pred_prior) / pred_prior))
111 | torch.nn.init.constant_(self.pred_net.bias, bias_value)
112 | else:
113 | torch.nn.init.constant_(self.pred_net.bias, 0)
114 |
115 | def forward(self, features):
116 | preds = []
117 | for feature in features:
118 | x = feature
119 | for i in range(self.num_convs):
120 | x = F.relu(self.subnet[i](x))
121 | preds.append(self.pred_net(x))
122 | return preds
123 |
124 | def get_params(self):
125 | weights = [x.weight for x in self.subnet] + [self.pred_net.weight]
126 | biases = [x.bias for x in self.subnet] + [self.pred_net.bias]
127 | return weights, biases
128 |
129 |
130 | from utils.merged_sync_bn import MergedSyncBatchNorm
131 |
132 | class RetinaNetHead_3x3_MergeBN(nn.Module):
133 | def __init__(self, cfg, in_channels, conv_channels, num_convs, num_anchors):
134 | super().__init__()
135 | # fmt: off
136 | num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
137 | prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
138 | num_anchors = 1
139 | self.num_convs = num_convs
140 | self.bn_converted = False
141 | # fmt: on
142 |
143 | self.cls_subnet = []
144 | self.bbox_subnet = []
145 | self.cls_bns = []
146 | self.bbox_bns = []
147 |
148 | channels = in_channels
149 | for i in range(self.num_convs):
150 | cls_layer = Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=None, norm=None)
151 | bbox_layer = Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=None, norm=None)
152 | torch.nn.init.normal_(cls_layer.weight, mean=0, std=0.01)
153 | torch.nn.init.normal_(bbox_layer.weight, mean=0, std=0.01)
154 |
155 | cls_bn = MergedSyncBatchNorm(conv_channels)
156 | bbox_bn = MergedSyncBatchNorm(conv_channels)
157 |
158 | self.add_module('cls_layer_{}'.format(i), cls_layer)
159 | self.add_module('bbox_layer_{}'.format(i), bbox_layer)
160 | self.add_module('cls_bn_{}'.format(i), cls_bn)
161 | self.add_module('bbox_bn_{}'.format(i), bbox_bn)
162 |
163 | self.cls_subnet.append(cls_layer)
164 | self.bbox_subnet.append(bbox_layer)
165 | self.cls_bns.append(cls_bn)
166 | self.bbox_bns.append(bbox_bn)
167 |
168 | channels = conv_channels
169 |
170 | self.cls_score = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
171 | self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
172 |
173 | torch.nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)
174 | torch.nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.01)
175 |
176 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
177 | torch.nn.init.constant_(self.cls_score.bias, bias_value)
178 |
179 |
180 | def forward(self, features, lvl_start):
181 | if self.training:
182 | return self._forward_train(features, lvl_start)
183 | else:
184 | return self._forward_eval(features, lvl_start)
185 |
186 | def _forward_train(self, features, lvl_start):
187 | cls_features = features
188 | bbox_features = features
189 | len_feats = len(features)
190 |
191 | for i in range(self.num_convs):
192 | cls_features = [self.cls_subnet[i](x) for x in cls_features]
193 | bbox_features = [self.bbox_subnet[i](x) for x in bbox_features]
194 |
195 | cls_features = self.cls_bns[i](cls_features)
196 | bbox_features = self.bbox_bns[i](bbox_features)
197 |
198 | cls_features = [F.relu(x) for x in cls_features]
199 | bbox_features = [F.relu(x) for x in bbox_features]
200 |
201 | logits = [self.cls_score(x) for x in cls_features]
202 | bbox_pred = [self.bbox_pred(x) for x in bbox_features]
203 | return logits, bbox_pred
204 |
205 |
206 | def _forward_eval(self, features, lvl_start):
207 | if not self.bn_converted:
208 | self._bn_convert()
209 |
210 | cls_features = features
211 | bbox_features = features
212 | len_feats = len(features)
213 |
214 | for i in range(self.num_convs):
215 | cls_features = [F.relu(self.cls_subnet[i](x)) for x in cls_features]
216 | bbox_features = [F.relu(self.bbox_subnet[i](x)) for x in bbox_features]
217 |
218 | logits = [self.cls_score(x) for x in cls_features]
219 | bbox_pred = [self.bbox_pred(x) for x in bbox_features]
220 |
221 | return logits, bbox_pred
222 |
223 | def _bn_convert(self):
224 | # merge BN into head weights
225 | assert not self.training
226 | if self.bn_converted:
227 | return
228 |
229 | for i in range(self.num_convs):
230 | cls_running_mean = self.cls_bns[i].running_mean.data
231 | cls_running_var = self.cls_bns[i].running_var.data
232 | cls_gamma = self.cls_bns[i].weight.data
233 | cls_beta = self.cls_bns[i].bias.data
234 |
235 | bbox_running_mean = self.bbox_bns[i].running_mean.data
236 | bbox_running_var = self.bbox_bns[i].running_var.data
237 | bbox_gamma = self.bbox_bns[i].weight.data
238 | bbox_beta = self.bbox_bns[i].bias.data
239 |
240 | cls_bn_scale = cls_gamma * torch.rsqrt(cls_running_var + 1e-10)
241 | cls_bn_bias = cls_beta - cls_bn_scale * cls_running_mean
242 |
243 | bbox_bn_scale = bbox_gamma * torch.rsqrt(bbox_running_var + 1e-10)
244 | bbox_bn_bias = bbox_beta - bbox_bn_scale * bbox_running_mean
245 |
246 | self.cls_subnet[i].weight.data = self.cls_subnet[i].weight.data * cls_bn_scale.view(-1, 1, 1, 1)
247 | self.cls_subnet[i].bias = torch.nn.Parameter(cls_bn_bias)
248 | self.bbox_subnet[i].weight.data = self.bbox_subnet[i].weight.data * bbox_bn_scale.view(-1, 1, 1, 1)
249 | self.bbox_subnet[i].bias = torch.nn.Parameter(bbox_bn_bias)
250 |
251 | self.bn_converted = True
252 |
253 | def get_params(self):
254 | if not self.bn_converted:
255 | self._bn_convert()
256 |
257 | cls_ws = [x.weight.data for x in self.cls_subnet] + [self.cls_score.weight.data]
258 | bbox_ws = [x.weight.data for x in self.bbox_subnet] + [self.bbox_pred.weight.data]
259 |
260 | cls_bs = [x.bias.data for x in self.cls_subnet] + [self.cls_score.bias.data]
261 | bbox_bs = [x.bias.data for x in self.bbox_subnet] + [self.bbox_pred.bias.data]
262 |
263 | return cls_ws, cls_bs, bbox_ws, bbox_bs
264 |
265 |
266 | class Head_3x3_MergeBN(nn.Module):
267 | def __init__(self, in_channels, conv_channels, num_convs, pred_channels, pred_prior=None):
268 | super().__init__()
269 | self.num_convs = num_convs
270 | self.bn_converted = False
271 |
272 | self.subnet = []
273 | self.bns = []
274 |
275 | channels = in_channels
276 | for i in range(self.num_convs):
277 | layer = Conv2d(channels, conv_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=None, norm=None)
278 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
279 | bn = MergedSyncBatchNorm(conv_channels)
280 |
281 | self.add_module('layer_{}'.format(i), layer)
282 | self.add_module('bn_{}'.format(i), bn)
283 |
284 | self.subnet.append(layer)
285 | self.bns.append(bn)
286 |
287 | channels = conv_channels
288 |
289 | self.pred_net = nn.Conv2d(channels, pred_channels, kernel_size=3, stride=1, padding=1)
290 |
291 | torch.nn.init.normal_(self.pred_net.weight, mean=0, std=0.01)
292 | if pred_prior is not None:
293 | bias_value = -(math.log((1 - pred_prior) / pred_prior))
294 | torch.nn.init.constant_(self.pred_net.bias, bias_value)
295 | else:
296 | torch.nn.init.constant_(self.pred_net.bias, 0)
297 |
298 | def forward(self, features):
299 | if self.training:
300 | return self._forward_train(features)
301 | else:
302 | return self._forward_eval(features)
303 |
304 | def _forward_train(self, features):
305 | for i in range(self.num_convs):
306 | features = [self.subnet[i](x) for x in features]
307 | features = self.bns[i](features)
308 | features = [F.relu(x) for x in features]
309 | preds = [self.pred_net(x) for x in features]
310 | return preds
311 |
312 | def _forward_eval(self, features):
313 | if not self.bn_converted:
314 | self._bn_convert()
315 |
316 | for i in range(self.num_convs):
317 | features = [F.relu(self.subnet[i](x)) for x in features]
318 |
319 | preds = [self.pred_net(x) for x in features]
320 | return preds
321 |
322 | def _bn_convert(self):
323 | # merge BN into head weights
324 | assert not self.training
325 | if self.bn_converted:
326 | return
327 | for i in range(self.num_convs):
328 | running_mean = self.bns[i].running_mean.data
329 | running_var = self.bns[i].running_var.data
330 | gamma = self.bns[i].weight.data
331 | beta = self.bns[i].bias.data
332 | bn_scale = gamma * torch.rsqrt(running_var + 1e-10)
333 | bn_bias = beta - bn_scale * running_mean
334 | self.subnet[i].weight.data = self.subnet[i].weight.data * bn_scale.view(-1, 1, 1, 1)
335 | self.subnet[i].bias = torch.nn.Parameter(bn_bias)
336 | self.bn_converted = True
337 |
338 | def get_params(self):
339 | if not self.bn_converted:
340 | self._bn_convert()
341 | weights = [x.weight.data for x in self.subnet] + [self.pred_net.weight.data]
342 | biases = [x.bias.data for x in self.subnet] + [self.pred_net.bias.data]
343 | return weights, biases
344 |
345 |
--------------------------------------------------------------------------------
/models/querydet/detector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | import sys
4 | import time
5 | from pathlib import Path
6 | sys.path.append(os.path.abspath(Path(__file__).parent.parent))
7 |
8 | import logging
9 | import math
10 | import numpy as np
11 | from typing import List
12 | import torch
13 | import torch.nn.functional as F
14 | from fvcore.nn import sigmoid_focal_loss_jit, smooth_l1_loss, sigmoid_focal_loss, giou_loss
15 | from torch import nn
16 |
17 | from detectron2.layers import ShapeSpec, batched_nms, cat
18 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
19 | from detectron2.utils.events import get_event_storage
20 | from detectron2.utils.logger import log_first_n
21 |
22 | from detectron2.modeling.anchor_generator import build_anchor_generator
23 | from detectron2.modeling.backbone import build_backbone
24 | from detectron2.modeling.box_regression import Box2BoxTransform
25 | from detectron2.modeling.matcher import Matcher
26 | from detectron2.modeling.postprocessing import detector_postprocess
27 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
28 |
29 |
30 | from torch.cuda import Event
31 | ###########################################################################################
32 | from utils.utils import *
33 | from utils.loop_matcher import LoopMatcher
34 | from utils.soft_nms import SoftNMSer
35 | from utils.anchor_gen import AnchorGeneratorWithCenter
36 | from utils.gradient_checkpoint import checkpoint
37 | import models.querydet.det_head as dh
38 | import models.querydet.qinfer as qf
39 |
40 | from torch.cuda.amp import autocast
41 |
42 | __all__ = ["RetinaNetQueryDet"]
43 |
44 |
45 | def permute_to_N_HWA_K(tensor, K):
46 | assert tensor.dim() == 4, tensor.shape
47 | N, _, H, W = tensor.shape
48 | tensor = tensor.view(N, -1, K, H, W)
49 | tensor = tensor.permute(0, 3, 4, 1, 2)
50 | tensor = tensor.reshape(N, -1, K) # Size=(N,HWA,K)
51 | return tensor
52 |
53 |
54 | def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta, num_classes=80):
55 | box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
56 | box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
57 | box_cls = cat(box_cls_flattened, dim=1).view(-1, num_classes)
58 | box_delta = cat(box_delta_flattened, dim=1).view(-1, 4)
59 | return box_cls, box_delta
60 |
61 |
62 | def permute_all_to_NHWA_K_not_concat(box_cls, box_delta, num_classes=80):
63 | box_cls_flattened = [permute_to_N_HWA_K(x, num_classes).reshape(-1, num_classes) for x in box_cls]
64 | box_delta_flattened = [permute_to_N_HWA_K(x, 4).reshape(-1, 4) for x in box_delta]
65 | return box_cls_flattened, box_delta_flattened
66 |
67 |
68 | @META_ARCH_REGISTRY.register()
69 | class RetinaNetQueryDet(nn.Module):
70 | """
71 | Implement Our QueryDet
72 | """
73 | def __init__(self, cfg):
74 | super().__init__()
75 |
76 | # fmt: off
77 | self.num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
78 | self.in_features = cfg.MODEL.RETINANET.IN_FEATURES
79 | self.query_layer_train = cfg.MODEL.QUERY.Q_FEATURE_TRAIN
80 | self.layers_whole_test = cfg.MODEL.QUERY.FEATURES_WHOLE_TEST
81 | self.layers_value_test = cfg.MODEL.QUERY.FEATURES_VALUE_TEST
82 | self.query_layer_test = cfg.MODEL.QUERY.Q_FEATURE_TEST
83 | # Loss parameters:
84 | self.focal_loss_alpha = cfg.MODEL.CUSTOM.FOCAL_LOSS_ALPHAS
85 | self.focal_loss_gamma = cfg.MODEL.CUSTOM.FOCAL_LOSS_GAMMAS
86 | self.smooth_l1_loss_beta = cfg.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA
87 | self.use_giou_loss = cfg.MODEL.CUSTOM.GIOU_LOSS
88 | self.cls_weights = cfg.MODEL.CUSTOM.CLS_WEIGHTS
89 | self.reg_weights = cfg.MODEL.CUSTOM.REG_WEIGHTS
90 | # training query head
91 | self.small_obj_scale = cfg.MODEL.QUERY.ENCODE_SMALL_OBJ_SCALE
92 | self.query_loss_weights = cfg.MODEL.QUERY.QUERY_LOSS_WEIGHT
93 | self.query_loss_gammas = cfg.MODEL.QUERY.QUERY_LOSS_GAMMA
94 | self.small_center_dis_coeff = cfg.MODEL.QUERY.ENCODE_CENTER_DIS_COEFF
95 | # Inference parameters:
96 | self.score_threshold = cfg.MODEL.RETINANET.SCORE_THRESH_TEST
97 | self.topk_candidates = cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST
98 | self.use_soft_nms = cfg.MODEL.CUSTOM.USE_SOFT_NMS
99 | self.nms_threshold = cfg.MODEL.RETINANET.NMS_THRESH_TEST
100 | self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE
101 | # query inference
102 | self.query_infer = cfg.MODEL.QUERY.QUERY_INFER
103 | self.query_threshold = cfg.MODEL.QUERY.THRESHOLD
104 | self.query_context = cfg.MODEL.QUERY.CONTEXT
105 | # other settings
106 | self.clear_cuda_cache = cfg.MODEL.CUSTOM.CLEAR_CUDA_CACHE
107 | self.anchor_num = len(cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS[0]) * \
108 | len(cfg.MODEL.ANCHOR_GENERATOR.SIZES[0])
109 | self.with_cp = cfg.MODEL.CUSTOM.GRADIENT_CHECKPOINT
110 | # fmt: on
111 | assert 'p2' in self.in_features
112 |
113 | self.backbone = build_backbone(cfg)
114 | if cfg.MODEL.CUSTOM.HEAD_BN:
115 | self.det_head = dh.RetinaNetHead_3x3_MergeBN(cfg, 256, 256, 4, self.anchor_num)
116 | self.query_head = dh.Head_3x3_MergeBN(256, 256, 4, 1)
117 | else:
118 | self.det_head = dh.RetinaNetHead_3x3(cfg, 256, 256, 4, self.anchor_num)
119 | self.query_head = dh.Head_3x3(256, 256, 4, 1)
120 |
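# 9 = anchors per location: 3 scales per octave x 3 aspect ratios (equals self.anchor_num under this repo's anchor settings)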
121 | self.qInfer = qf.QueryInfer(9, self.num_classes, self.query_threshold, self.query_context)
122 |
123 | backbone_shape = self.backbone.output_shape()
124 | all_det_feature_shapes = [backbone_shape[f] for f in self.in_features]
125 |
126 | self.anchor_generator = build_anchor_generator(cfg, all_det_feature_shapes)
127 | self.query_anchor_generator = AnchorGeneratorWithCenter(sizes=[128], aspect_ratios=[1.0],
128 | strides=[2**(x+2) for x in self.query_layer_train], offset=0.5)
129 | # Matching and loss
130 | self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
131 |
132 | self.soft_nmser = SoftNMSer(
133 | cfg.MODEL.CUSTOM.SOFT_NMS_METHOD,
134 | cfg.MODEL.CUSTOM.SOFT_NMS_SIGMA,
135 | cfg.MODEL.CUSTOM.SOFT_NMS_THRESHOLD,
136 | cfg.MODEL.CUSTOM.SOFT_NMS_PRUND
137 | )
138 |
139 | if cfg.MODEL.CUSTOM.USE_LOOP_MATCHER:
140 | self.matcher = LoopMatcher(
141 | cfg.MODEL.RETINANET.IOU_THRESHOLDS,
142 | cfg.MODEL.RETINANET.IOU_LABELS,
143 | allow_low_quality_matches=True,
144 | )
145 | else:
146 | self.matcher = Matcher(
147 | cfg.MODEL.RETINANET.IOU_THRESHOLDS,
148 | cfg.MODEL.RETINANET.IOU_LABELS,
149 | allow_low_quality_matches=True,
150 | )
151 |
152 | self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
153 | self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))
154 |
155 | # initialize with any reasonable #fg that's not too small
156 | self.loss_normalizer = 100
157 | self.loss_normalizer_momentum = 0.9
158 |
159 | @property
160 | def device(self):
161 | return self.pixel_mean.device
162 |
163 | def forward(self, batched_inputs, just_forward=False):
164 | if self.training:
165 | return self.train_forward(batched_inputs, just_forward)
166 | else:
167 | return self.test(batched_inputs)
168 |
169 | def train_forward(self, batched_inputs, just_forward=False):
170 | if self.clear_cuda_cache:
171 | torch.cuda.empty_cache()
172 |
173 | if "instances" in batched_inputs[0]:
174 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
175 | elif "targets" in batched_inputs[0]:
176 | log_first_n(
177 | logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
178 | )
179 | gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
180 | else:
181 | gt_instances = None
182 |
183 | images = self.preprocess_image(batched_inputs)
184 | features = self.backbone(images.tensor)
185 | all_features = [features[f] for f in self.in_features]
186 | all_anchors, all_centers = self.anchor_generator(all_features)
187 |
188 | query_feature = [all_features[x] for x in self.query_layer_train]
189 | _, query_centers = self.query_anchor_generator(query_feature)
190 |
191 | # make prediction
192 | det_cls, det_delta = self.det_head(all_features)
193 | query_logits = self.query_head(query_feature)
194 |
195 | if just_forward:
196 | return None
197 |
198 | gt_classes, gt_reg_targets = self.get_det_gt(all_anchors, gt_instances)
199 | losses = self.det_loss(gt_classes, gt_reg_targets, det_cls, det_delta, all_anchors)
200 |
201 | # query loss
202 | gt_query = self.get_query_gt(query_centers, gt_instances)
203 | query_foregrounds = [gt.sum().item() for gt in gt_query]
204 | _query_loss = self.query_loss(gt_query, query_logits, self.query_loss_gammas, self.query_loss_weights)
205 | losses.update(_query_loss)
206 | return losses
207 |
208 | def test(self, batched_inputs):
209 | images = self.preprocess_image(batched_inputs)
210 | results, total_time = self.test_forward(images) # normal test
211 | processed_results = []
212 | for results_per_image, input_per_image, image_size in zip(
213 | results, batched_inputs, images.image_sizes
214 | ):
215 | height = input_per_image.get("height", image_size[0])
216 | width = input_per_image.get("width", image_size[1])
217 | r = detector_postprocess(results_per_image, height, width)
218 | processed_results.append({"instances": r, 'time':total_time})
219 | return processed_results
220 |
221 | def test_forward(self, images):
222 | start_event = Event(enable_timing=True)
223 | end_event = Event(enable_timing=True)
224 |
225 | start_event.record()
226 | features = self.backbone(images.tensor[:, :, :])
227 |
228 | all_features = [features[f] for f in self.in_features]
229 |
230 | all_anchors, all_centers = self.anchor_generator(all_features)
231 |
232 | features_whole = [all_features[x] for x in self.layers_whole_test]
233 | features_value = [all_features[x] for x in self.layers_value_test]
234 | features_key = [all_features[x] for x in self.query_layer_test]
235 |
236 | anchors_whole = [all_anchors[x] for x in self.layers_whole_test]
237 | anchors_value = [all_anchors[x] for x in self.layers_value_test]
238 |
239 | det_cls_whole, det_delta_whole = self.det_head(features_whole)
240 |
241 |
242 | if not self.query_infer:
243 | det_cls_query, det_bbox_query = self.det_head(features_value)
244 | det_cls_query = [permute_to_N_HWA_K(x, self.num_classes) for x in det_cls_query]
245 | det_bbox_query = [permute_to_N_HWA_K(x, 4) for x in det_bbox_query]
246 | query_anchors = anchors_value
247 | else:
248 | if not self.qInfer.initialized:
249 | cls_weights, cls_biases, bbox_weights, bbox_biases = self.det_head.get_params()
250 | qcls_weights, qcls_bias = self.query_head.get_params()
251 | params = [cls_weights, cls_biases, bbox_weights, bbox_biases, qcls_weights, qcls_bias]
252 | else:
253 | params = None
254 |
255 | det_cls_query, det_bbox_query, query_anchors = self.qInfer.run_qinfer(params, features_key, features_value, anchors_value)
256 |
257 | results = self.inference(det_cls_whole, det_delta_whole, anchors_whole,
258 | det_cls_query, det_bbox_query, query_anchors,
259 | images.image_sizes)
260 |
261 | end_event.record()
262 | torch.cuda.synchronize()
263 | total_time = start_event.elapsed_time(end_event)
264 | return results, total_time
265 |
266 | # @float_function
267 | def _giou_loss(self, pred_deltas, anchors, gt_boxes):
268 | with autocast(False):
269 | pred_boxes = self.box2box_transform.apply_deltas(pred_deltas, anchors)
270 | loss = giou_loss(pred_boxes, gt_boxes, reduction='sum')
271 | return loss
272 |
273 |
274 | def det_loss(self, gt_classes, gt_anchors_targets, pred_logits, pred_deltas, all_anchors):
275 | def convert_gt_cls(logits, gt_class, f_idxs):
276 | gt_classes_target = torch.zeros_like(logits)
277 | gt_classes_target[f_idxs, gt_class[f_idxs]] = 1
278 | return gt_classes_target
279 |
280 | alphas = self.focal_loss_alpha
281 | gammas = self.focal_loss_gamma
282 | cls_weights = self.cls_weights
283 | reg_weights = self.reg_weights
284 |
285 | assert len(cls_weights) == len(pred_logits)
286 | assert len(cls_weights) == len(reg_weights)
287 |
288 | batch_size = pred_logits[0].size(0)
289 | pred_logits, pred_deltas = permute_all_to_NHWA_K_not_concat(pred_logits, pred_deltas, self.num_classes)
290 |
291 | lengths = [x.shape[0] for x in pred_logits]
292 | start_inds = [0] + [sum(lengths[:i]) for i in range(1, len(lengths))]
293 | end_inds = [sum(lengths[:i+1]) for i in range(len(lengths))]
294 |
295 | gt_classes = gt_classes.flatten()
296 | gt_anchors_targets = gt_anchors_targets.view(-1, 4)
297 |
298 | valid_idxs = gt_classes >= 0
299 | foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
300 | num_foreground = foreground_idxs.sum().item()
301 | get_event_storage().put_scalar("num_foreground", num_foreground)
302 | self.loss_normalizer = (
303 | self.loss_normalizer_momentum * self.loss_normalizer
304 | + (1 - self.loss_normalizer_momentum) * num_foreground
305 | )
306 | all_anchor_lists = [torch.cat([x.tensor.reshape(-1, 4) for _ in range(batch_size)]) for x in all_anchors]
307 |         gt_classes_list = [gt_classes[s:e] for s, e in zip(start_inds, end_inds)]
308 | gt_anchors_targets_list = [gt_anchors_targets[s:e] for s, e in zip(start_inds, end_inds)]
309 | valid_idxs_list = [valid_idxs[s:e] for s, e in zip(start_inds, end_inds)]
310 | foreground_idxs_list = [foreground_idxs[s:e] for s, e in zip(start_inds, end_inds)]
311 |
312 | loss_cls = [
313 | w * sigmoid_focal_loss_jit(
314 | x[v],
315 | convert_gt_cls(x, g, f)[v].detach(),
316 | alpha=alpha,
317 | gamma=gamma,
318 | reduction="sum"
319 | )
320 |             for w, x, g, v, f, alpha, gamma in zip(cls_weights, pred_logits, gt_classes_list, valid_idxs_list, foreground_idxs_list, alphas, gammas)
321 | ]
322 |
323 | if self.use_giou_loss:
324 | loss_box_reg = [
325 | w * self._giou_loss(
326 | x[f],
327 | a[f].detach(),
328 | g[f].detach(),
329 | )
330 | for w, x, a, g, f in zip(reg_weights, pred_deltas, all_anchor_lists, gt_anchors_targets_list, foreground_idxs_list)
331 | ]
332 | else:
333 | loss_box_reg = [
334 | w * smooth_l1_loss(
335 | x[f],
336 | g[f].detach(),
337 | beta=self.smooth_l1_loss_beta,
338 | reduction="sum"
339 | )
340 | for w, x, g, f in zip(reg_weights, pred_deltas, gt_anchors_targets_list, foreground_idxs_list)
341 | ]
342 |
343 | loss_cls = sum(loss_cls) / max(1., self.loss_normalizer)
344 | loss_box_reg = sum(loss_box_reg) / max(1., self.loss_normalizer)
345 | return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
346 |
347 | def query_loss(self, gt_small_obj, pred_small_obj, gammas, weights):
348 | pred_logits = [permute_to_N_HWA_K(x, 1).flatten() for x in pred_small_obj]
349 | gts = [x.flatten() for x in gt_small_obj]
350 | loss = sum([sigmoid_focal_loss_jit(x, y, alpha=0.25, gamma=g, reduction="mean") * w for (x, y, g, w) in zip(pred_logits, gts, gammas, weights)])
351 | return {'loss_query': loss}
352 |
353 | @torch.no_grad()
354 | def get_det_gt(self, anchors, targets):
355 | gt_classes = []
356 | gt_anchors_targets = []
357 | anchor_layers = len(anchors)
358 | anchor_lens = [len(x) for x in anchors]
359 | start_inds = [0] + [sum(anchor_lens[:i]) for i in range(1, len(anchor_lens))]
360 | end_inds = [sum(anchor_lens[:i+1]) for i in range(len(anchor_lens))]
361 | all_anchors = Boxes.cat(anchors) # Rx4
362 |
363 | for targets_per_image in targets:
364 |
365 | if type(self.matcher) == Matcher:
366 | match_quality_matrix = pairwise_iou(targets_per_image.gt_boxes, all_anchors)
367 | gt_matched_idxs, anchor_labels = self.matcher(match_quality_matrix)
368 | del(match_quality_matrix)
369 |             elif type(self.matcher) == LoopMatcher: # memory-friendly matching for images with many ground-truth boxes
370 | gt_matched_idxs, anchor_labels = self.matcher(targets_per_image.gt_boxes, all_anchors)
371 | else:
372 | raise NotImplementedError
373 |
374 | has_gt = len(targets_per_image) > 0
375 | if has_gt:
376 | # ground truth box regression
377 | matched_gt_boxes = targets_per_image.gt_boxes[gt_matched_idxs]
378 |
379 | if not self.use_giou_loss:
380 | gt_anchors_reg_targets_i = self.box2box_transform.get_deltas(
381 | all_anchors.tensor, matched_gt_boxes.tensor
382 | )
383 | else:
384 | gt_anchors_reg_targets_i = matched_gt_boxes.tensor
385 |
386 | gt_classes_i = targets_per_image.gt_classes[gt_matched_idxs]
387 | # Anchors with label 0 are treated as background.
388 | gt_classes_i[anchor_labels == 0] = self.num_classes
389 | # Anchors with label -1 are ignored.
390 | gt_classes_i[anchor_labels == -1] = -1
391 |
392 | else:
393 | gt_classes_i = torch.zeros_like(gt_matched_idxs) + self.num_classes
394 | gt_anchors_reg_targets_i = torch.zeros_like(all_anchors.tensor)
395 |
396 | gt_classes.append([gt_classes_i[s:e] for s, e in zip(start_inds, end_inds)])
397 | gt_anchors_targets.append([gt_anchors_reg_targets_i[s:e] for s, e in zip(start_inds, end_inds)])
398 |
399 | gt_classes = [torch.stack([x[i] for x in gt_classes]) for i in range(anchor_layers)]
400 | gt_anchors_targets = [torch.stack([x[i] for x in gt_anchors_targets]) for i in range(anchor_layers)]
401 |
402 | gt_classes = torch.cat([x.flatten() for x in gt_classes])
403 | gt_anchors_targets = torch.cat([x.reshape(-1, 4) for x in gt_anchors_targets])
404 |
405 | return gt_classes, gt_anchors_targets
406 |
407 |
408 | @torch.no_grad()
409 | def get_query_gt(self, small_anchor_centers, targets):
410 | small_gt_cls = []
411 | for lind, anchor_center in enumerate(small_anchor_centers):
412 | per_layer_small_gt = []
413 | for target_per_image in targets:
414 | target_box_scales = get_box_scales(target_per_image.gt_boxes)
415 |
416 | small_inds = (target_box_scales < self.small_obj_scale[lind][1]) & (target_box_scales >= self.small_obj_scale[lind][0])
417 | small_boxes = target_per_image[small_inds]
418 | center_dis, minarg = get_anchor_center_min_dis(small_boxes.gt_boxes.get_centers(), anchor_center)
419 | small_obj_target = torch.zeros_like(center_dis)
420 |
421 | if len(small_boxes) != 0:
422 | min_small_target_scale = (target_box_scales[small_inds])[minarg]
423 | small_obj_target[center_dis < min_small_target_scale * self.small_center_dis_coeff[lind]] = 1
424 |
425 | per_layer_small_gt.append(small_obj_target)
426 | small_gt_cls.append(torch.stack(per_layer_small_gt))
427 |
428 | return small_gt_cls
429 |
430 |
431 | def inference(self,
432 | retina_box_cls, retina_box_delta, retina_anchors,
433 | small_det_logits, small_det_delta, small_det_anchors,
434 | image_sizes
435 | ):
436 | results = []
437 |
438 | N, _, _, _ = retina_box_cls[0].size()
439 | retina_box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in retina_box_cls]
440 | retina_box_delta = [permute_to_N_HWA_K(x, 4) for x in retina_box_delta]
441 | small_det_logits = [x.view(N, -1, self.num_classes) for x in small_det_logits]
442 | small_det_delta = [x.view(N, -1, 4) for x in small_det_delta]
443 |
444 | for img_idx, image_size in enumerate(image_sizes):
445 |
446 | retina_box_cls_per_image = [box_cls_per_level[img_idx] for box_cls_per_level in retina_box_cls]
447 | retina_box_reg_per_image = [box_reg_per_level[img_idx] for box_reg_per_level in retina_box_delta]
448 | small_det_logits_per_image = [small_det_cls_per_level[img_idx] for small_det_cls_per_level in small_det_logits]
449 | small_det_reg_per_image = [small_det_reg_per_level[img_idx] for small_det_reg_per_level in small_det_delta]
450 |
451 | if len(small_det_anchors) == 0 or type(small_det_anchors[0]) == torch.Tensor:
452 | small_det_anchor_per_image = [small_det_anchor_per_level[img_idx] for small_det_anchor_per_level in small_det_anchors]
453 | else:
454 | small_det_anchor_per_image = small_det_anchors
455 |
456 | results_per_img = self.inference_single_image(
457 | retina_box_cls_per_image, retina_box_reg_per_image, retina_anchors,
458 | small_det_logits_per_image, small_det_reg_per_image, small_det_anchor_per_image,
459 | tuple(image_size))
460 | results.append(results_per_img)
461 |
462 | return results
463 |
464 |
465 | def inference_single_image(self,
466 | retina_box_cls, retina_box_delta, retina_anchors,
467 | small_det_logits, small_det_delta, small_det_anchors,
468 | image_size
469 | ):
470 | with autocast(False):
471 | # small pos cls inference
472 | all_cls = small_det_logits + retina_box_cls
473 | all_delta = small_det_delta + retina_box_delta
474 | all_anchors = small_det_anchors + retina_anchors
475 |
476 | boxes_all, scores_all, class_idxs_all = self.decode_dets(all_cls, all_delta, all_anchors)
477 | boxes_all, scores_all, class_idxs_all = [cat(x) for x in [boxes_all, scores_all, class_idxs_all]]
478 |
479 | if self.use_soft_nms:
480 | keep, soft_nms_scores = self.soft_nmser(boxes_all, scores_all, class_idxs_all)
481 | else:
482 | keep = batched_nms(boxes_all, scores_all, class_idxs_all, self.nms_threshold)
483 | result = Instances(image_size)
484 |
485 | keep = keep[: self.max_detections_per_image]
486 | result.pred_boxes = Boxes(boxes_all[keep])
487 | result.scores = scores_all[keep]
488 | result.pred_classes = class_idxs_all[keep]
489 | return result
490 |
491 |
492 | def preprocess_image(self, batched_inputs):
493 | images = [x["image"].to(self.device) for x in batched_inputs]
494 | images = [(x - self.pixel_mean) / self.pixel_std for x in images]
495 | images = ImageList.from_tensors(images, self.backbone.size_divisibility)
496 | return images
497 |
498 | def decode_dets(self, cls_results, reg_results, anchors):
499 | boxes_all = []
500 | scores_all = []
501 | class_idxs_all = []
502 |
503 | for cls_i, reg_i, anchors_i in zip(cls_results, reg_results, anchors):
504 | cls_i = cls_i.view(-1, self.num_classes)
505 | reg_i = reg_i.view(-1, 4)
506 |
507 | cls_i = cls_i.flatten().sigmoid_() # (HxWxAxK,)
508 | num_topk = min(self.topk_candidates, reg_i.size(0))
509 |
510 | predicted_prob, topk_idxs = cls_i.sort(descending=True)
511 | predicted_prob = predicted_prob[:num_topk]
512 | topk_idxs = topk_idxs[:num_topk]
513 |
514 | # filter out the proposals with low confidence score
515 | keep_idxs = predicted_prob > self.score_threshold
516 | predicted_prob = predicted_prob[keep_idxs]
517 | topk_idxs = topk_idxs[keep_idxs]
518 |
519 | anchor_idxs = topk_idxs // self.num_classes
520 | classes_idxs = topk_idxs % self.num_classes
521 | predicted_class = classes_idxs
522 |
523 | reg_i = reg_i[anchor_idxs]
524 | anchors_i = anchors_i[anchor_idxs]
525 |
526 | if type(anchors_i) != torch.Tensor:
527 | anchors_i = anchors_i.tensor
528 |
529 | predicted_boxes = self.box2box_transform.apply_deltas(reg_i, anchors_i)
530 |
531 | boxes_all.append(predicted_boxes)
532 | scores_all.append(predicted_prob)
533 | class_idxs_all.append(predicted_class)
534 |
535 | return boxes_all, scores_all, class_idxs_all
536 |
537 |
538 |
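539 | if __name__ == "__main__":
540 |     # A minimal sketch, not part of the original model code: it mirrors the per-level
541 |     # index math in `decode_dets` above on random tensors, so the top-k selection,
542 |     # score threshold, and anchor/class index recovery can be checked in isolation.
543 |     # The shapes and the score threshold below are illustrative assumptions.
544 |     import torch
545 |
546 |     num_classes, num_anchors_total, topk, score_th = 80, 1000, 300, 0.05
547 |     cls_i = torch.randn(num_anchors_total, num_classes)   # dummy per-level class logits
548 |
549 |     flat = cls_i.flatten().sigmoid()                       # (num_anchors_total * num_classes,)
550 |     prob, topk_idxs = flat.sort(descending=True)
551 |     prob, topk_idxs = prob[:topk], topk_idxs[:topk]        # keep the top-k candidates
552 |
553 |     keep = prob > score_th                                 # drop low-confidence candidates
554 |     prob, topk_idxs = prob[keep], topk_idxs[keep]
555 |
556 |     anchor_idxs = topk_idxs // num_classes                 # which anchor each candidate came from
557 |     class_idxs = topk_idxs % num_classes                   # which class it scores
558 |     assert (anchor_idxs < num_anchors_total).all()
559 |     print(prob.shape, anchor_idxs.shape, class_idxs.shape)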
--------------------------------------------------------------------------------
/models/querydet/qinfer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import torch.nn.functional as F
4 | import spconv.pytorch as spconv
5 |
6 |
7 | def permute_to_N_HWA_K(tensor, K):
8 | assert tensor.dim() == 4, tensor.shape
9 | N, _, H, W = tensor.shape
10 | tensor = tensor.view(N, -1, K, H, W)
11 | tensor = tensor.permute(0, 3, 4, 1, 2)
12 | tensor = tensor.reshape(N, -1, K) # Size=(N,HWA,K)
13 | return tensor
14 |
15 | def run_conv2d(x, weights, bias):
16 | n_conv = len(weights)
17 | for i in range(n_conv):
18 | x = F.conv2d(x, weights[i], bias[i])
19 | if i != n_conv - 1:
20 | x = F.relu(x)
21 | return x
22 |
23 |
24 | class QueryInfer(object):
25 | def __init__(self, anchor_num, num_classes, score_th=0.12, context=2):
26 |
27 | self.anchor_num = anchor_num
28 | self.num_classes = num_classes
29 | self.score_th = score_th
30 | self.context = context
31 |
32 | self.initialized = False
33 | self.cls_spconv = None
34 | self.bbox_spconv = None
35 | self.qcls_spconv = None
36 | self.qcls_conv = None
37 | self.n_conv = None
38 |
39 |
40 | def _make_sparse_tensor(self, query_logits, last_ys, last_xs, anchors, feature_value):
41 | if last_ys is None:
42 | N, _, qh, qw = query_logits.size()
43 | assert N == 1
44 | prob = torch.sigmoid_(query_logits).view(-1)
45 |             pidxs = torch.where(prob > self.score_th)[0]
46 | y = torch.div(pidxs, qw).int()
47 | x = torch.remainder(pidxs, qw).int()
48 | else:
49 | prob = torch.sigmoid_(query_logits).view(-1)
50 | pidxs = prob > self.score_th
51 | y = last_ys[pidxs]
52 | x = last_xs[pidxs]
53 |
54 | if y.size(0) == 0:
55 | return None, None, None, None, None, None
56 |
57 | _, fc, fh, fw = feature_value.shape
58 |
59 | ys, xs = [], []
60 | for i in range(2):
61 | for j in range(2):
62 | ys.append(y * 2 + i)
63 | xs.append(x * 2 + j)
64 |
65 | ys = torch.cat(ys, dim=0)
66 | xs = torch.cat(xs, dim=0)
67 | inds = (ys * fw + xs).long()
68 |
69 | sparse_ys = []
70 | sparse_xs = []
71 |
72 | for i in range(-1*self.context, self.context+1):
73 | for j in range(-1*self.context, self.context+1):
74 | sparse_ys.append(ys+i)
75 | sparse_xs.append(xs+j)
76 |
77 | sparse_ys = torch.cat(sparse_ys, dim=0)
78 | sparse_xs = torch.cat(sparse_xs, dim=0)
79 |
80 |
81 | good_idx = (sparse_ys >= 0) & (sparse_ys < fh) & (sparse_xs >= 0) & (sparse_xs < fw)
82 | sparse_ys = sparse_ys[good_idx]
83 | sparse_xs = sparse_xs[good_idx]
84 |
85 | sparse_yx = torch.stack((sparse_ys, sparse_xs), dim=0).t()
86 | sparse_yx = torch.unique(sparse_yx, sorted=False, dim=0)
87 |
88 | sparse_ys = sparse_yx[:, 0]
89 | sparse_xs = sparse_yx[:, 1]
90 |
91 | sparse_inds = (sparse_ys * fw + sparse_xs).long()
92 |
93 | sparse_features = feature_value.view(fc, -1).transpose(0, 1)[sparse_inds].view(-1, fc)
94 | sparse_indices = torch.stack((torch.zeros_like(sparse_ys), sparse_ys, sparse_xs), dim=-1)
95 | sparse_tensor = spconv.SparseConvTensor(sparse_features, sparse_indices.int(), (fh, fw), 1)
96 |
97 | anchors = anchors.tensor.view(-1, self.anchor_num, 4)
98 | selected_anchors = anchors[inds].view(1, -1, 4)
99 | return sparse_tensor, ys, xs, inds, selected_anchors, sparse_indices.size(0)
100 |
101 | def _make_spconv(self, weights, biases):
102 | nets = []
103 | for i in range(len(weights)):
104 | in_channel = weights[i].shape[1]
105 | out_channel = weights[i].shape[0]
106 | k_size = weights[i].shape[2]
107 | filter = spconv.SubMConv2d(in_channel, out_channel, k_size, 1, padding=k_size//2, indice_key="asd", algo=spconv.ConvAlgo.Native).to(device=weights[i].device)
108 |             filter.weight.data[:] = weights[i].permute(2, 3, 1, 0).contiguous()[:]  # copy the dense Conv2d weights into spconv's expected layout
109 | filter.bias.data = biases[i]
110 | nets.append(filter)
111 | if i != len(weights) - 1:
112 | nets.append(torch.nn.ReLU(inplace=True))
113 | return spconv.SparseSequential(*nets)
114 |
115 | def _make_conv(self, weights, biases):
116 | nets = []
117 | for i in range(len(weights)):
118 |             in_channel = weights[i].shape[1]   # dense Conv2d weights have shape (out_channels, in_channels, k, k)
119 |             out_channel = weights[i].shape[0]
120 | k_size = weights[i].shape[2]
121 | filter = torch.nn.Conv2d(in_channel, out_channel, k_size, 1, padding=k_size//2)
122 | filter.weight.data = weights[i]
123 | filter.bias.data = biases[i]
124 | nets.append(filter)
125 | if i != len(weights) - 1:
126 | nets.append(torch.nn.ReLU())
127 | return torch.nn.Sequential(*nets)
128 |
129 | def _run_spconvs(self, x, filters):
130 | y = filters(x)
131 | return y.dense(channels_first=False)
132 |
133 | def _run_convs(self, x, filters):
134 | return filters(x)
135 |
136 | def run_qinfer(self, model_params, features_key, features_value, anchors_value):
137 |
138 | if not self.initialized:
139 | cls_weights, cls_biases, bbox_weights, bbox_biases, qcls_weights, qcls_biases = model_params
140 | assert len(cls_weights) == len(qcls_weights)
141 | self.n_conv = len(cls_weights)
142 | self.cls_spconv = self._make_spconv(cls_weights, cls_biases)
143 | self.bbox_spconv = self._make_spconv(bbox_weights, bbox_biases)
144 | self.qcls_spconv = self._make_spconv(qcls_weights, qcls_biases)
145 | self.qcls_conv = self._make_conv(qcls_weights, qcls_biases)
146 | self.initialized = True
147 |
148 | last_ys, last_xs = None, None
149 | query_logits = self._run_convs(features_key[-1], self.qcls_conv)
150 | det_cls_query, det_bbox_query, query_anchors = [], [], []
151 |
152 | n_inds_all = []
153 |
154 | for i in range(len(features_value)-1, -1, -1):
155 | x, last_ys, last_xs, inds, selected_anchors, n_inds = self._make_sparse_tensor(query_logits, last_ys, last_xs, anchors_value[i], features_value[i])
156 | n_inds_all.append(n_inds)
157 |             if x is None:
158 | break
159 | cls_result = self._run_spconvs(x, self.cls_spconv).view(-1, self.anchor_num*self.num_classes)[inds]
160 | bbox_result = self._run_spconvs(x, self.bbox_spconv).view(-1, self.anchor_num*4)[inds]
161 | query_logits = self._run_spconvs(x, self.qcls_spconv).view(-1)[inds]
162 |
163 | query_anchors.append(selected_anchors)
164 | det_cls_query.append(torch.unsqueeze(cls_result, 0))
165 | det_bbox_query.append(torch.unsqueeze(bbox_result, 0))
166 |
167 | return det_cls_query, det_bbox_query, query_anchors
168 |
169 |
170 |
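171 | if __name__ == "__main__":
172 |     # A minimal sketch, not part of the original module: it reproduces the index
173 |     # bookkeeping of `_make_sparse_tensor` above. Each coarse position (y, x) that
174 |     # passed the query threshold is mapped to a 2x2 block on the next finer level,
175 |     # dilated by a `context` window, and deduplicated. The grid size and the
176 |     # example positions are illustrative assumptions.
177 |     fh, fw, context = 8, 8, 2
178 |     y = torch.tensor([1, 3])                   # coarse rows that passed the query threshold
179 |     x = torch.tensor([2, 0])                   # corresponding coarse cols
180 |
181 |     ys, xs = [], []
182 |     for i in range(2):                         # 2x upsampling: each coarse cell covers 2x2 fine cells
183 |         for j in range(2):
184 |             ys.append(y * 2 + i)
185 |             xs.append(x * 2 + j)
186 |     ys, xs = torch.cat(ys), torch.cat(xs)
187 |
188 |     sparse_ys, sparse_xs = [], []
189 |     for i in range(-context, context + 1):     # dilate by the context window
190 |         for j in range(-context, context + 1):
191 |             sparse_ys.append(ys + i)
192 |             sparse_xs.append(xs + j)
193 |     sparse_ys, sparse_xs = torch.cat(sparse_ys), torch.cat(sparse_xs)
194 |
195 |     good = (sparse_ys >= 0) & (sparse_ys < fh) & (sparse_xs >= 0) & (sparse_xs < fw)
196 |     yx = torch.stack((sparse_ys[good], sparse_xs[good]), dim=0).t()
197 |     yx = torch.unique(yx, dim=0)               # deduplicate overlapping context windows
198 |     print(yx.shape[0], "active fine-level positions")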
--------------------------------------------------------------------------------
/models/retinanet/__pycache__/retinanet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/retinanet/__pycache__/retinanet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/retinanet/__pycache__/retinanet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/models/retinanet/__pycache__/retinanet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/retinanet/retinanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import logging
3 | import math
4 | import time
5 | import numpy as np
6 | from typing import List
7 | import torch
8 | from fvcore.nn import sigmoid_focal_loss_jit, smooth_l1_loss
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, Conv2d
13 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
14 | from detectron2.utils.events import get_event_storage
15 | from detectron2.utils.logger import log_first_n
16 | import detectron2.utils.comm as comm
17 |
18 | from detectron2.modeling.anchor_generator import build_anchor_generator
19 | from detectron2.modeling.backbone import build_backbone
20 | from detectron2.modeling.box_regression import Box2BoxTransform
21 | from detectron2.modeling.matcher import Matcher
22 | from detectron2.modeling.postprocessing import detector_postprocess
23 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
24 |
25 | from torch.cuda import Event
26 | from utils.loop_matcher import LoopMatcher
27 |
28 |
29 | __all__ = ["RetinaNet_D2"]
30 |
31 |
32 | def permute_to_N_HWA_K(tensor, K):
33 | """
34 | Transpose/reshape a tensor from (N, (A x K), H, W) to (N, (HxWxA), K)
35 | """
36 | assert tensor.dim() == 4, tensor.shape
37 | N, _, H, W = tensor.shape
38 | tensor = tensor.view(N, -1, K, H, W)
39 | tensor = tensor.permute(0, 3, 4, 1, 2)
40 | tensor = tensor.reshape(N, -1, K) # Size=(N,HWA,K)
41 | return tensor
42 |
43 |
44 | def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta, num_classes=80):
45 | """
46 | Rearrange the tensor layout from the network output, i.e.:
47 | list[Tensor]: #lvl tensors of shape (N, A x K, Hi, Wi)
48 | to per-image predictions, i.e.:
49 | Tensor: of shape (N x sum(Hi x Wi x A), K)
50 | """
51 | # for each feature level, permute the outputs to make them be in the
52 | # same format as the labels. Note that the labels are computed for
53 | # all feature levels concatenated, so we keep the same representation
54 | # for the objectness and the box_delta
55 | box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
56 | box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
57 | # concatenate on the first dimension (representing the feature levels), to
58 | # take into account the way the labels were generated (with all feature maps
59 | # being concatenated as well)
60 | box_cls = cat(box_cls_flattened, dim=1).view(-1, num_classes)
61 | box_delta = cat(box_delta_flattened, dim=1).view(-1, 4)
62 | return box_cls, box_delta
63 |
64 |
65 | def permute_all_to_NHWA_K_not_concat(box_cls, box_delta, num_classes=80):
66 | box_cls_flattened = [permute_to_N_HWA_K(x, num_classes).view(-1, num_classes) for x in box_cls]
67 | box_delta_flattened = [permute_to_N_HWA_K(x, 4).view(-1, 4) for x in box_delta]
68 | return box_cls_flattened, box_delta_flattened
69 |
70 | @META_ARCH_REGISTRY.register()
71 | class RetinaNet_D2(nn.Module):
72 | """
73 | Implement RetinaNet in :paper:`RetinaNet`.
74 | """
75 |
76 | def __init__(self, cfg):
77 | super().__init__()
78 |
79 | # fmt: off
80 | self.num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
81 | self.in_features = cfg.MODEL.RETINANET.IN_FEATURES
82 | # Loss parameters:
83 | self.focal_loss_alpha = cfg.MODEL.CUSTOM.FOCAL_LOSS_ALPHAS
84 | self.focal_loss_gamma = cfg.MODEL.CUSTOM.FOCAL_LOSS_GAMMAS
85 | self.cls_weights = cfg.MODEL.CUSTOM.CLS_WEIGHTS
86 | self.reg_weights = cfg.MODEL.CUSTOM.REG_WEIGHTS
87 | self.smooth_l1_loss_beta = cfg.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA
88 | # Inference parameters:
89 | self.score_threshold = cfg.MODEL.RETINANET.SCORE_THRESH_TEST
90 | self.topk_candidates = cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST
91 | self.nms_threshold = cfg.MODEL.RETINANET.NMS_THRESH_TEST
92 | self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE
93 | # Vis parameters
94 | self.vis_period = cfg.VIS_PERIOD
95 | self.input_format = cfg.INPUT.FORMAT
96 | self.scale_factor = 1
97 | # fmt: on
98 |
99 | self.backbone = build_backbone(cfg)
100 |
101 | backbone_shape = self.backbone.output_shape()
102 | feature_shapes = [backbone_shape[f] for f in self.in_features]
103 | self.head = RetinaNetHead(cfg, feature_shapes)
104 | self.anchor_generator = build_anchor_generator(cfg, feature_shapes)
105 |
106 | # Matching and loss
107 | self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
108 | if cfg.MODEL.CUSTOM.USE_LOOP_MATCHER:
109 | self.matcher = LoopMatcher(
110 | cfg.MODEL.RETINANET.IOU_THRESHOLDS,
111 | cfg.MODEL.RETINANET.IOU_LABELS,
112 | allow_low_quality_matches=True,
113 | )
114 | else:
115 | self.matcher = Matcher(
116 | cfg.MODEL.RETINANET.IOU_THRESHOLDS,
117 | cfg.MODEL.RETINANET.IOU_LABELS,
118 | allow_low_quality_matches=True,
119 | )
120 |
121 | self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
122 | self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))
123 |
124 | """
125 | In Detectron1, loss is normalized by number of foreground samples in the batch.
126 | When batch size is 1 per GPU, #foreground has a large variance and
127 | using it leads to lower performance. Here we maintain an EMA of #foreground to
128 | stabilize the normalizer.
129 | """
130 | self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
131 | self.loss_normalizer_momentum = 0.9
132 |
133 | self.iter = 0
134 | self.class_stat = [0 for _ in range(10)]
135 |
136 | @property
137 | def device(self):
138 | return self.pixel_mean.device
139 |
140 |
141 | def visualize_training(self, batched_inputs, results):
142 | from detectron2.utils.visualizer import Visualizer
143 |
144 | assert len(batched_inputs) == len(
145 | results
146 | ), "Cannot visualize inputs and results of different sizes"
147 | storage = get_event_storage()
148 | max_boxes = 20
149 |
150 | image_index = 0 # only visualize a single image
151 | img = batched_inputs[image_index]["image"].cpu().numpy()
152 | assert img.shape[0] == 3, "Images should have 3 channels."
153 | if self.input_format == "BGR":
154 | img = img[::-1, :, :]
155 | img = img.transpose(1, 2, 0)
156 | v_gt = Visualizer(img, None)
157 | v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
158 | anno_img = v_gt.get_image()
159 | processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
160 | predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()
161 |
162 | v_pred = Visualizer(img, None)
163 | v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
164 | prop_img = v_pred.get_image()
165 | vis_img = np.vstack((anno_img, prop_img))
166 | vis_img = vis_img.transpose(2, 0, 1)
167 | vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results"
168 | storage.put_image(vis_name, vis_img)
169 |
170 |
171 | def forward(self, batched_inputs):
172 | start_event = Event(enable_timing=True)
173 | end_event = Event(enable_timing=True)
174 |
175 | images = self.preprocess_image(batched_inputs)
176 | if "instances" in batched_inputs[0]:
177 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
178 | elif "targets" in batched_inputs[0]:
179 | log_first_n(
180 | logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
181 | )
182 | gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
183 | else:
184 | gt_instances = None
185 |
186 | start_event.record()
187 |
188 | features = self.backbone(images.tensor)
189 | features = [features[f] for f in self.in_features]
190 | box_cls, box_delta = self.head(features)
191 | anchors = self.anchor_generator(features)
192 |
193 | if self.training:
194 | # torch.cuda.empty_cache()
195 | # gt_classes, gt_anchors_reg_deltas = self.get_ground_truth(anchors, gt_instances)
196 | # losses = self.losses(gt_classes, gt_anchors_reg_deltas, box_cls, box_delta)
197 |
198 | gt_classes, gt_deltas = self.get_det_gt(anchors, gt_instances)
199 | losses = self.det_loss(gt_classes, gt_deltas, box_cls, box_delta, self.focal_loss_alpha, self.focal_loss_gamma, self.cls_weights, self.reg_weights)
200 |
201 |
202 | if self.vis_period > 0:
203 | storage = get_event_storage()
204 | if storage.iter % self.vis_period == 0:
205 | results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
206 | self.visualize_training(batched_inputs, results)
207 |
208 | return losses
209 | else:
210 | results = self.inference(box_cls, box_delta, anchors, images.image_sizes)
211 | end_event.record()
212 | torch.cuda.synchronize()
213 | total_time = start_event.elapsed_time(end_event)
214 | processed_results = []
215 | for results_per_image, input_per_image, image_size in zip(
216 | results, batched_inputs, images.image_sizes
217 | ):
218 | height = input_per_image.get("height", image_size[0])
219 | width = input_per_image.get("width", image_size[1])
220 | r = detector_postprocess(results_per_image, height, width)
221 | processed_results.append({"instances": r, 'time':total_time})
222 | return processed_results
223 |
224 |
225 | @torch.no_grad()
226 | def get_det_gt(self, anchors, targets):
227 | gt_classes = []
228 | gt_anchors_deltas = []
229 | anchor_layers = len(anchors)
230 | anchor_lens = [len(x) for x in anchors]
231 | start_inds = [0] + [sum(anchor_lens[:i]) for i in range(1, len(anchor_lens))]
232 | end_inds = [sum(anchor_lens[:i+1]) for i in range(len(anchor_lens))]
233 | anchors = Boxes.cat(anchors) # Rx4
234 |
235 | for targets_per_image in targets:
236 | if type(self.matcher) == Matcher:
237 | match_quality_matrix = pairwise_iou(targets_per_image.gt_boxes, anchors)
238 | gt_matched_idxs, anchor_labels = self.matcher(match_quality_matrix)
239 | del(match_quality_matrix)
240 | else:
241 | gt_matched_idxs, anchor_labels = self.matcher(targets_per_image.gt_boxes, anchors)
242 |
243 | has_gt = len(targets_per_image) > 0
244 | if has_gt:
245 | # ground truth box regression
246 | matched_gt_boxes = targets_per_image.gt_boxes[gt_matched_idxs]
247 | gt_anchors_reg_deltas_i = self.box2box_transform.get_deltas(
248 | anchors.tensor, matched_gt_boxes.tensor
249 | )
250 |
251 | gt_classes_i = targets_per_image.gt_classes[gt_matched_idxs]
252 | # Anchors with label 0 are treated as background.
253 | gt_classes_i[anchor_labels == 0] = self.num_classes
254 | # Anchors with label -1 are ignored.
255 | gt_classes_i[anchor_labels == -1] = -1
256 |
257 | else:
258 | gt_classes_i = torch.zeros_like(gt_matched_idxs) + self.num_classes
259 | gt_anchors_reg_deltas_i = torch.zeros_like(anchors.tensor)
260 |
261 | gt_classes.append([gt_classes_i[s:e] for s, e in zip(start_inds, end_inds)])
262 | gt_anchors_deltas.append([gt_anchors_reg_deltas_i[s:e] for s, e in zip(start_inds, end_inds)])
263 |
264 | gt_classes = [torch.stack([x[i] for x in gt_classes]) for i in range(anchor_layers)]
265 | gt_anchors_deltas = [torch.stack([x[i] for x in gt_anchors_deltas]) for i in range(anchor_layers)]
266 |
267 | gt_classes = torch.cat([x.flatten() for x in gt_classes])
268 | gt_anchors_deltas = torch.cat([x.reshape(-1, 4) for x in gt_anchors_deltas])
269 |
270 | return gt_classes, gt_anchors_deltas
271 |
272 |
273 | def det_loss(self, gt_classes, gt_anchors_deltas, pred_logits, pred_deltas, alphas, gammas, cls_weights, reg_weights):
274 | def convert_gt_cls(logits, gt_class, f_idxs):
275 | gt_classes_target = torch.zeros_like(logits)
276 | gt_classes_target[f_idxs, gt_class[f_idxs]] = 1
277 | return gt_classes_target
278 |
279 | assert len(cls_weights) == len(pred_logits)
280 | assert len(cls_weights) == len(reg_weights)
281 |
282 | pred_logits, pred_deltas = permute_all_to_NHWA_K_not_concat(pred_logits, pred_deltas, self.num_classes)
283 |
284 | lengths = [x.shape[0] for x in pred_logits]
285 | start_inds = [0] + [sum(lengths[:i]) for i in range(1, len(lengths))]
286 | end_inds = [sum(lengths[:i+1]) for i in range(len(lengths))]
287 |
288 | gt_classes = gt_classes.flatten()
289 | gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)
290 |
291 | valid_idxs = gt_classes >= 0
292 | foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
293 | num_foreground = foreground_idxs.sum().item()
294 | get_event_storage().put_scalar("num_foreground", num_foreground)
295 | self.loss_normalizer = (
296 | self.loss_normalizer_momentum * self.loss_normalizer
297 | + (1 - self.loss_normalizer_momentum) * num_foreground
298 | )
299 |         gt_classes_list = [gt_classes[s:e] for s, e in zip(start_inds, end_inds)]
300 | gt_anchors_deltas_list = [gt_anchors_deltas[s:e] for s, e in zip(start_inds, end_inds)]
301 | valid_idxs_list = [valid_idxs[s:e] for s, e in zip(start_inds, end_inds)]
302 | foreground_idxs_list = [foreground_idxs[s:e] for s, e in zip(start_inds, end_inds)]
303 |
304 | loss_cls = [
305 | w * sigmoid_focal_loss_jit(
306 | x[v],
307 | convert_gt_cls(x, g, f)[v].detach(),
308 | alpha=alpha,
309 | gamma=gamma,
310 | reduction="sum"
311 | )
312 |             for w, x, g, v, f, alpha, gamma in zip(cls_weights, pred_logits, gt_classes_list, valid_idxs_list, foreground_idxs_list, alphas, gammas)
313 | ]
314 |
315 | loss_box_reg = [
316 | w * smooth_l1_loss(
317 | x[f],
318 | g[f].detach(),
319 | beta=self.smooth_l1_loss_beta,
320 | reduction="sum"
321 | )
322 | for w, x, g, f in zip(reg_weights, pred_deltas, gt_anchors_deltas_list, foreground_idxs_list)
323 | ]
324 |
325 | loss_cls = sum(loss_cls) / max(1., self.loss_normalizer)
326 | loss_box_reg = sum(loss_box_reg) / max(1., self.loss_normalizer)
327 | return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
328 |
329 |
330 | def inference(self, box_cls, box_delta, anchors, image_sizes):
331 | """
332 | Arguments:
333 | box_cls, box_delta: Same as the output of :meth:`RetinaNetHead.forward`
334 | anchors (list[Boxes]): A list of #feature level Boxes.
335 | The Boxes contain anchors of this image on the specific feature level.
336 | image_sizes (List[torch.Size]): the input image sizes
337 |
338 | Returns:
339 | results (List[Instances]): a list of #images elements.
340 | """
341 | results = []
342 | times = []
343 |
344 | box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
345 | box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
346 |
347 | for img_idx, image_size in enumerate(image_sizes):
348 | box_cls_per_image = [box_cls_per_level[img_idx] for box_cls_per_level in box_cls]
349 | box_reg_per_image = [box_reg_per_level[img_idx] for box_reg_per_level in box_delta]
350 | results_per_image = self.inference_single_image(
351 | box_cls_per_image, box_reg_per_image, anchors, (image_size[0]*self.scale_factor, image_size[1]*self.scale_factor)
352 | )
353 | results.append(results_per_image)
354 | return results
355 |
356 |
357 | def inference_single_image(self, box_cls, box_delta, anchors, image_size):
358 | """
359 | Single-image inference. Return bounding-box detection results by thresholding
360 | on scores and applying non-maximum suppression (NMS).
361 |
362 | Arguments:
363 | box_cls (list[Tensor]): list of #feature levels. Each entry contains
364 | tensor of size (H x W x A, K)
365 | box_delta (list[Tensor]): Same shape as 'box_cls' except that K becomes 4.
366 | anchors (list[Boxes]): list of #feature levels. Each entry contains
367 | a Boxes object, which contains all the anchors for that
368 | image in that feature level.
369 | image_size (tuple(H, W)): a tuple of the image height and width.
370 |
371 | Returns:
372 | Same as `inference`, but for only one image.
373 | """
374 | boxes_all = []
375 | scores_all = []
376 | class_idxs_all = []
377 |
378 | # Iterate over every feature level
379 | for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors):
380 | # (HxWxAxK,)
381 | box_cls_i = box_cls_i.flatten().sigmoid_()
382 |
383 | # Keep top k top scoring indices only.
384 | num_topk = min(self.topk_candidates, box_reg_i.size(0))
385 | # torch.sort is actually faster than .topk (at least on GPUs)
386 | predicted_prob, topk_idxs = box_cls_i.sort(descending=True)
387 | predicted_prob = predicted_prob[:num_topk]
388 | topk_idxs = topk_idxs[:num_topk]
389 |
390 | # filter out the proposals with low confidence score
391 | keep_idxs = predicted_prob > self.score_threshold
392 | predicted_prob = predicted_prob[keep_idxs]
393 | topk_idxs = topk_idxs[keep_idxs]
394 |
395 | anchor_idxs = topk_idxs // self.num_classes
396 | classes_idxs = topk_idxs % self.num_classes
397 |
398 | box_reg_i = box_reg_i[anchor_idxs]
399 | anchors_i = anchors_i[anchor_idxs]
400 | # predict boxes
401 | predicted_boxes = self.box2box_transform.apply_deltas(box_reg_i, anchors_i.tensor)
402 |
403 | boxes_all.append(predicted_boxes)
404 | scores_all.append(predicted_prob)
405 | class_idxs_all.append(classes_idxs)
406 |
407 | boxes_all, scores_all, class_idxs_all = [
408 | cat(x) for x in [boxes_all, scores_all, class_idxs_all]
409 | ]
410 |
411 | keep = batched_nms(boxes_all, scores_all, class_idxs_all, self.nms_threshold)
412 |
413 | keep = keep[: self.max_detections_per_image]
414 |
415 | result = Instances(image_size)
416 | result.pred_boxes = Boxes(boxes_all[keep])
417 | result.scores = scores_all[keep]
418 | result.pred_classes = class_idxs_all[keep]
419 | return result
420 |
421 |
422 | def preprocess_image(self, batched_inputs):
423 | """
424 | Normalize, pad and batch the input images.
425 | """
426 | images = [x["image"].to(self.device) for x in batched_inputs]
427 | images = [(x - self.pixel_mean) / self.pixel_std for x in images]
428 | images = ImageList.from_tensors(images, self.backbone.size_divisibility)
429 | return images
430 |
431 |
432 | class RetinaNetHead(nn.Module):
433 | """
434 | The head used in RetinaNet for object classification and box regression.
435 | It has two subnets for the two tasks, with a common structure but separate parameters.
436 | """
437 |
438 | def __init__(self, cfg, input_shape: List[ShapeSpec]):
439 | super().__init__()
440 | # fmt: off
441 | in_channels = input_shape[0].channels
442 | num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
443 | num_convs = cfg.MODEL.RETINANET.NUM_CONVS
444 | prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
445 | num_anchors = build_anchor_generator(cfg, input_shape).num_cell_anchors
446 | # fmt: on
447 | assert (
448 | len(set(num_anchors)) == 1
449 | ), "Using different number of anchors between levels is not currently supported!"
450 | num_anchors = num_anchors[0]
451 |
452 | cls_subnet = []
453 | bbox_subnet = []
454 | for _ in range(num_convs):
455 | cls_subnet.append(
456 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
457 | )
458 | cls_subnet.append(nn.ReLU())
459 | bbox_subnet.append(
460 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
461 | )
462 | bbox_subnet.append(nn.ReLU())
463 |
464 | self.cls_subnet = nn.Sequential(*cls_subnet)
465 | self.bbox_subnet = nn.Sequential(*bbox_subnet)
466 | self.cls_score = nn.Conv2d(
467 | in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
468 | )
469 | self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
470 |
471 | # Initialization
472 | for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]:
473 | for layer in modules.modules():
474 | if isinstance(layer, nn.Conv2d):
475 | #torch.nn.init.xavier_normal_(layer.weight)
476 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
477 | torch.nn.init.constant_(layer.bias, 0)
478 |
479 | # Use prior in model initialization to improve stability
480 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
481 | torch.nn.init.constant_(self.cls_score.bias, bias_value)
482 |
483 | def forward(self, features):
484 | """
485 | Arguments:
486 | features (list[Tensor]): FPN feature map tensors in high to low resolution.
487 |             Each tensor in the list corresponds to a different feature level.
488 |
489 | Returns:
490 | logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi).
491 | The tensor predicts the classification probability
492 | at each spatial position for each of the A anchors and K object
493 | classes.
494 | bbox_reg (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi).
495 | The tensor predicts 4-vector (dx,dy,dw,dh) box
496 | regression values for every anchor. These values are the
497 | relative offset between the anchor and the ground truth box.
498 | """
499 | logits = []
500 | bbox_reg = []
501 | for feature in features:
502 | logits.append(self.cls_score(self.cls_subnet(feature)))
503 | bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
504 | return logits, bbox_reg
505 |
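506 | if __name__ == "__main__":
507 |     # A minimal sketch, not part of the original file: a quick shape check for
508 |     # `permute_to_N_HWA_K` above. The values of N, A, K, H, W are illustrative.
509 |     N, A, K, H, W = 2, 9, 80, 16, 24
510 |     head_out = torch.randn(N, A * K, H, W)     # raw head output, shape (N, AxK, Hi, Wi)
511 |     flat = permute_to_N_HWA_K(head_out, K)
512 |     assert flat.shape == (N, H * W * A, K)     # one row per anchor, K class scores each
513 |     print(flat.shape)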
--------------------------------------------------------------------------------
/train_coco.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from detectron2.engine import launch
4 | from train_tools.coco_train import default_argument_parser, start_train
5 |
6 | from models.retinanet.retinanet import RetinaNet_D2
7 | from models.querydet.detector import RetinaNetQueryDet
8 |
9 |
10 | if __name__ == '__main__':
11 | args = default_argument_parser().parse_args()
12 | print("Command Line Args:", args)
13 | launch(
14 | start_train,
15 | args.num_gpus,
16 | num_machines=args.num_machines,
17 | machine_rank=args.machine_rank,
18 | dist_url=args.dist_url,
19 | args=(args,),
20 | )
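21 |
22 | # Example invocation (an assumption based on the argument parser defined in
23 | # train_tools/coco_train.py; see the project README for the exact commands):
24 | #
25 | #   python train_coco.py --num-gpus 8 \
26 | #       --config-file configs/coco/querydet_train.yaml
27 | #
28 | # Any trailing `KEY VALUE` pairs are forwarded to `cfg.merge_from_list`, e.g.
29 | # `MODEL.WEIGHTS /path/to/weight.pth` or `OUTPUT_DIR work_dirs/querydet_coco`.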
--------------------------------------------------------------------------------
/train_tools/coco_infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | """
4 | Detection Training Script.
5 |
6 | This script reads a given config file and runs the training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend you to use detectron2 as a library and take
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import sys
21 | import os
22 | from collections import OrderedDict
23 | import torch
24 | import argparse
25 | from torch.nn.parallel import DistributedDataParallel
26 |
27 | import detectron2.utils.comm as comm
28 | from detectron2.checkpoint import DetectionCheckpointer
29 | from detectron2.config import get_cfg
30 | from detectron2.data import MetadataCatalog, build_detection_test_loader
31 | from detectron2.engine import DefaultTrainer, default_setup, hooks, launch
32 | from detectron2.evaluation import (
33 | CityscapesInstanceEvaluator,
34 | CityscapesSemSegEvaluator,
35 | COCOEvaluator,
36 | COCOPanopticEvaluator,
37 | DatasetEvaluators,
38 | LVISEvaluator,
39 | PascalVOCDetectionEvaluator,
40 | SemSegEvaluator,
41 | verify_results,
42 | )
43 | from detectron2.modeling import GeneralizedRCNNWithTTA
44 | from detectron2.checkpoint import DetectionCheckpointer
45 | from detectron2_backbone.config import add_backbone_config
46 | from detectron2_backbone import mobilenet
47 |
48 | from utils.val_mapper_with_ann import ValMapper
49 | from utils.time_evaluator import GPUTimeEvaluator
50 | from utils.coco_eval_fpn import COCOEvaluatorFPN
51 | from utils.anchor_gen import AnchorGeneratorWithCenter
52 | from configs.custom_config import add_custom_config
53 |
54 |
55 |
56 | class Trainer(DefaultTrainer):
57 | @classmethod
58 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
59 | if output_folder is None:
60 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
61 | evaluator_list = []
62 | if cfg.META_INFO.EVAL_AP:
63 | evaluator_list.append(COCOEvaluatorFPN(dataset_name, cfg, True, os.path.join(cfg.OUTPUT_DIR)))
64 | if cfg.META_INFO.EVAL_GPU_TIME:
65 | evaluator_list.append(GPUTimeEvaluator(True, 'minisecond'))
66 | return DatasetEvaluators(evaluator_list)
67 |
68 | def default_argument_parser(epilog=None):
69 | """
70 | Create a parser with some common arguments used by detectron2 users.
71 |
72 | Args:
73 | epilog (str): epilog passed to ArgumentParser describing the usage.
74 |
75 | Returns:
76 | argparse.ArgumentParser:
77 | """
78 | parser = argparse.ArgumentParser(
79 | epilog=epilog
80 | or f"""
81 | Examples:
82 |
83 | Run on single machine:
84 | $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth
85 |
86 | Run on multiple machines:
87 | (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
88 | (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
89 | """,
90 | formatter_class=argparse.RawDescriptionHelpFormatter,
91 | )
92 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
93 | parser.add_argument(
94 | "--resume",
95 | action="store_true",
96 | help="whether to attempt to resume from the checkpoint directory",
97 | )
98 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
99 | parser.add_argument("--no-pretrain", action="store_true", help="whether to load pretrained model")
100 | parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
101 | parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
102 | parser.add_argument(
103 | "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
104 | )
105 |
106 |
107 | # PyTorch still may leave orphan processes in multi-gpu training.
108 | # Therefore we use a deterministic way to obtain port,
109 | # so that users are aware of orphan processes by seeing the port occupied.
110 | port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
111 | parser.add_argument(
112 | "--dist-url",
113 | default="tcp://127.0.0.1:{}".format(port),
114 | help="initialization URL for pytorch distributed backend. See "
115 | "https://pytorch.org/docs/stable/distributed.html for details.",
116 | )
117 | parser.add_argument(
118 | "opts",
119 | help="Modify config options using the command-line",
120 | default=None,
121 | nargs=argparse.REMAINDER,
122 | )
123 | return parser
124 |
125 | def setup(args):
126 | """
127 | Create configs and perform basic setups.
128 | """
129 | cfg = get_cfg()
130 | add_custom_config(cfg)
131 | add_backbone_config(cfg)
132 | cfg.merge_from_file(args.config_file)
133 | cfg.merge_from_list(args.opts)
134 | cfg.freeze()
135 | default_setup(cfg, args)
136 | return cfg
137 |
138 |
139 | def start_train(args):
140 | cfg = setup(args)
141 |
142 | if args.eval_only:
143 | model = Trainer.build_model(cfg)
144 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
145 | cfg.MODEL.WEIGHTS, resume=args.resume
146 | )
147 | res = Trainer.test(cfg, model)
148 | if comm.is_main_process():
149 | verify_results(cfg, res)
150 | return res
151 |
152 | """
153 | If you'd like to do anything fancier than the standard training logic,
154 | consider writing your own training loop (see plain_train_net.py) or
155 | subclassing the trainer.
156 | """
157 | trainer = Trainer(cfg)
158 | if not args.no_pretrain:
159 | trainer.resume_or_load(resume=args.resume)
160 | return trainer.train()
161 |
162 |
163 |
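164 | # A short worked example, not part of the original script, for the port formula in
165 | # `default_argument_parser` above: 2**15 + 2**14 = 49152 and `hash(uid) % 2**14`
166 | # adds an offset in [0, 16383], so the default --dist-url always falls in the
167 | # dynamic port range 49152..65535 and stays the same for a given user.
168 | if __name__ == "__main__":
169 |     _uid = 1000                                             # illustrative uid
170 |     _port = 2 ** 15 + 2 ** 14 + hash(_uid) % 2 ** 14        # = 49152 + 1000 = 50152 here
171 |     assert 49152 <= _port <= 65535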
--------------------------------------------------------------------------------
/train_tools/coco_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | """
4 | Detection Training Script.
5 |
6 | This script reads a given config file and runs the training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend you to use detectron2 as a library and take
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import sys
21 | import os
22 | from collections import OrderedDict
23 | import torch
24 | import argparse
25 | from torch.nn.parallel import DistributedDataParallel
26 |
27 | import detectron2.utils.comm as comm
28 | from detectron2.checkpoint import DetectionCheckpointer
29 | from detectron2.config import get_cfg
30 | from detectron2.data import MetadataCatalog, build_detection_test_loader
31 | from detectron2.engine import DefaultTrainer, default_setup, hooks, launch
32 | from detectron2.evaluation import (
33 | CityscapesInstanceEvaluator,
34 | CityscapesSemSegEvaluator,
35 | COCOEvaluator,
36 | COCOPanopticEvaluator,
37 | DatasetEvaluators,
38 | LVISEvaluator,
39 | PascalVOCDetectionEvaluator,
40 | SemSegEvaluator,
41 | verify_results,
42 | )
43 | from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
44 | from detectron2.modeling import GeneralizedRCNNWithTTA
45 | from detectron2.checkpoint import DetectionCheckpointer
46 | from detectron2.engine.train_loop import AMPTrainer, SimpleTrainer
47 | from detectron2.engine.defaults import DefaultTrainer
48 |
49 | from utils.val_mapper_with_ann import ValMapper
50 | from utils.anchor_gen import AnchorGeneratorWithCenter
51 | from utils.coco_eval_fpn import COCOEvaluatorFPN
52 |
53 | from configs.custom_config import add_custom_config
54 |
55 | # from detectron2_backbone.config import add_backbone_config
56 | # import detectron2_backbone.backbone.mobilenet
57 |
58 |
59 | class Trainer(DefaultTrainer):
60 | def __init__(self, cfg, resume=False, reuse_ckpt=False):
61 | """
62 | Args:
63 | cfg (CfgNode):
64 | """
65 | super(DefaultTrainer, self).__init__()
66 |
67 | logger = logging.getLogger("detectron2")
68 | if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
69 |             from detectron2.utils.logger import setup_logger; setup_logger()  # setup_logger is not imported at module level
70 | cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
71 |
72 | # Assume these objects must be constructed in this order.
73 | model = self.build_model(cfg)
74 |
75 | ckpt = DetectionCheckpointer(model)
76 | self.start_iter = 0
77 | self.start_iter = ckpt.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
78 |         self.iter = self.start_iter
79 |
80 | optimizer = self.build_optimizer(cfg, model)
81 | data_loader = self.build_train_loader(cfg)
82 |
83 | # For training, wrap with DDP. But don't need this for inference.
84 | if comm.get_world_size() > 1:
85 | model = DistributedDataParallel(
86 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
87 | )
88 | self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
89 | model, data_loader, optimizer
90 | )
91 |
92 | self.scheduler = self.build_lr_scheduler(cfg, optimizer)
93 | self.checkpointer = DetectionCheckpointer(
94 | model,
95 | cfg.OUTPUT_DIR,
96 | optimizer=optimizer,
97 | scheduler=self.scheduler,
98 | )
99 |         self.start_iter = 0  # note: this overrides the iteration loaded by resume_or_load above
100 | self.max_iter = cfg.SOLVER.MAX_ITER
101 | self.cfg = cfg
102 | self.register_hooks(self.build_hooks())
103 |
104 | @classmethod
105 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
106 | if output_folder is None:
107 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
108 | evaluator_list = []
109 | if cfg.META_INFO.EVAL_AP:
110 | evaluator_list.append(COCOEvaluatorFPN(dataset_name, cfg, True, output_folder))
111 | return DatasetEvaluators(evaluator_list)
112 |
113 | @classmethod
114 | def build_test_loader(cls, cfg, dataset_name):
115 | return build_detection_test_loader(cfg, dataset_name, ValMapper(cfg))
116 |
117 |
118 | def default_argument_parser(epilog=None):
119 | """
120 | Create a parser with some common arguments used by detectron2 users.
121 |
122 | Args:
123 | epilog (str): epilog passed to ArgumentParser describing the usage.
124 |
125 | Returns:
126 | argparse.ArgumentParser:
127 | """
128 | parser = argparse.ArgumentParser(
129 | epilog=epilog
130 | or f"""
131 | Examples:
132 |
133 | Run on single machine:
134 | $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth
135 |
136 | Run on multiple machines:
137 | (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
138 | (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
139 | """,
140 | formatter_class=argparse.RawDescriptionHelpFormatter,
141 | )
142 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
143 | parser.add_argument(
144 | "--resume",
145 | action="store_true",
146 | help="whether to attempt to resume from the checkpoint directory",
147 | )
148 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
149 | parser.add_argument("--no-pretrain", action="store_true", help="whether to load pretrained model")
150 | parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
151 | parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
152 | parser.add_argument(
153 | "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
154 | )
155 |
156 | # PyTorch still may leave orphan processes in multi-gpu training.
157 | # Therefore we use a deterministic way to obtain port,
158 | # so that users are aware of orphan processes by seeing the port occupied.
159 | port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
160 | parser.add_argument(
161 | "--dist-url",
162 | default="tcp://127.0.0.1:{}".format(port),
163 | help="initialization URL for pytorch distributed backend. See "
164 | "https://pytorch.org/docs/stable/distributed.html for details.",
165 | )
166 | parser.add_argument(
167 | "opts",
168 | help="Modify config options using the command-line",
169 | default=None,
170 | nargs=argparse.REMAINDER,
171 | )
172 | return parser
173 |
174 |
175 | def setup(args):
176 | """
177 | Create configs and perform basic setups.
178 | """
179 | cfg = get_cfg()
180 | add_custom_config(cfg)
181 | # add_backbone_config(cfg)
182 | cfg.merge_from_file(args.config_file)
183 | cfg.merge_from_list(args.opts)
184 | cfg.freeze()
185 | default_setup(cfg, args)
186 | return cfg
187 |
188 |
189 | def start_train(args):
190 | cfg = setup(args)
191 | if args.eval_only:
192 | model = Trainer.build_model(cfg)
193 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
194 | cfg.MODEL.WEIGHTS, resume=args.resume
195 | )
196 | res = Trainer.test(cfg, model)
197 | if cfg.TEST.AUG.ENABLED:
198 | res.update(Trainer.test_with_TTA(cfg, model))
199 | if comm.is_main_process():
200 | verify_results(cfg, res)
201 | return res
202 | trainer = Trainer(cfg, resume=args.resume, reuse_ckpt=args.no_pretrain)
203 | return trainer.train()
204 |
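205 | # Note, not part of the original file: the Trainer above wraps
206 | # `AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer`, so mixed-precision
207 | # training can be toggled from the command line without editing the YAML,
208 | # e.g. (illustrative):
209 | #
210 | #   python train_coco.py --config-file configs/coco/querydet_train.yaml \
211 | #       --num-gpus 8 SOLVER.AMP.ENABLED True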
--------------------------------------------------------------------------------
/train_tools/visdrone_infer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import os
4 | from collections import OrderedDict
5 | import torch
6 | import argparse
7 |
8 | import detectron2.utils.comm as comm
9 | from detectron2.checkpoint import DetectionCheckpointer
10 | from detectron2.config import get_cfg
11 | from detectron2.data import MetadataCatalog, build_detection_test_loader
12 | from detectron2.engine import DefaultTrainer, default_setup, hooks, launch
13 | from detectron2.evaluation import (
14 | CityscapesInstanceEvaluator,
15 | CityscapesSemSegEvaluator,
16 | COCOEvaluator,
17 | COCOPanopticEvaluator,
18 | DatasetEvaluators,
19 | LVISEvaluator,
20 | PascalVOCDetectionEvaluator,
21 | SemSegEvaluator,
22 | verify_results,
23 | )
24 | from detectron2.evaluation import (
25 | DatasetEvaluator,
26 | inference_on_dataset,
27 | print_csv_format,
28 | verify_results,
29 | )
30 | from detectron2.modeling import GeneralizedRCNNWithTTA
31 | from detectron2.checkpoint import DetectionCheckpointer
32 | from detectron2.evaluation.evaluator import inference_on_dataset
33 |
34 |
35 | from utils.val_mapper_with_ann import ValMapper
36 | from utils.anchor_gen import AnchorGeneratorWithCenter
37 | from utils.coco_eval_fpn import COCOEvaluatorFPN
38 | from utils.json_evaluator import JsonEvaluator
39 | from utils.time_evaluator import GPUTimeEvaluator
40 |
41 | from visdrone.dataloader import build_train_loader, build_test_loader
42 |
43 | # from models.backbone import build
44 | from configs.custom_config import add_custom_config
45 |
46 | from models.retinanet.retinanet import RetinaNet_D2
47 | from models.querydet.detector import RetinaNetQueryDet
48 |
49 |
50 |
51 | class Trainer(DefaultTrainer):
52 | @classmethod
53 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
54 | if output_folder is None:
55 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
56 | evaluator_list = []
57 | evaluator_list.append(JsonEvaluator(os.path.join(cfg.OUTPUT_DIR, 'visdrone_infer.json'), class_add_1=True))
58 | if cfg.META_INFO.EVAL_GPU_TIME:
59 | evaluator_list.append(GPUTimeEvaluator(True, 'minisecond'))
60 | return DatasetEvaluators(evaluator_list)
61 |
62 | @classmethod
63 | def build_train_loader(cls, cfg):
64 | return build_train_loader(cfg)
65 |
66 | @classmethod
67 | def build_test_loader(cls, cfg, dataset_name):
68 | return build_test_loader(cfg)
69 |
70 | @classmethod
71 | def test(cls, cfg, model, evaluators=None):
72 | logger = logging.getLogger(__name__)
73 | dataset_name = 'VisDrone2018'
74 |
75 | data_loader = cls.build_test_loader(cfg, dataset_name)
76 | evaluator = cls.build_evaluator(cfg, dataset_name)
77 | result = inference_on_dataset(model, data_loader, evaluator)
78 | if comm.is_main_process():
79 | assert isinstance(
80 | result, dict
81 | ), "Evaluator must return a dict on the main process. Got {} instead.".format(
82 | result
83 | )
84 | logger.info("Evaluation results for {} in csv format:".format(dataset_name))
85 | print_csv_format(result)
86 |
87 | if len(result) == 1:
88 | result = list(result.values())[0]
89 | return result
90 |
91 |
92 | def default_argument_parser(epilog=None):
93 | """
94 | Create a parser with some common arguments used by detectron2 users.
95 |
96 | Args:
97 | epilog (str): epilog passed to ArgumentParser describing the usage.
98 |
99 | Returns:
100 | argparse.ArgumentParser:
101 | """
102 | parser = argparse.ArgumentParser(
103 | epilog=epilog
104 | or f"""
105 | Examples:
106 |
107 | Run on single machine:
108 | $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth
109 |
110 | Run on multiple machines:
111 | (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
112 | (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
113 | """,
114 | formatter_class=argparse.RawDescriptionHelpFormatter,
115 | )
116 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
117 | parser.add_argument(
118 | "--resume",
119 | action="store_true",
120 | help="whether to attempt to resume from the checkpoint directory",
121 | )
122 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
123 | parser.add_argument("--no-pretrain", action="store_true", help="whether to load pretrained model")
124 | parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
125 | parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
126 | parser.add_argument(
127 | "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
128 | )
129 |
130 | # PyTorch still may leave orphan processes in multi-gpu training.
131 | # Therefore we use a deterministic way to obtain port,
132 | # so that users are aware of orphan processes by seeing the port occupied.
133 | port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
134 | parser.add_argument(
135 | "--dist-url",
136 | default="tcp://127.0.0.1:{}".format(port),
137 | help="initialization URL for pytorch distributed backend. See "
138 | "https://pytorch.org/docs/stable/distributed.html for details.",
139 | )
140 | parser.add_argument(
141 | "opts",
142 | help="Modify config options using the command-line",
143 | default=None,
144 | nargs=argparse.REMAINDER,
145 | )
146 | return parser
147 |
148 |
149 | def setup(args):
150 | """
151 | Create configs and perform basic setups.
152 | """
153 | cfg = get_cfg()
154 | add_custom_config(cfg)
155 | cfg.merge_from_file(args.config_file)
156 | cfg.merge_from_list(args.opts)
157 | cfg.freeze()
158 | default_setup(cfg, args)
159 | return cfg
160 |
161 |
162 | def start_train(args):
163 | cfg = setup(args)
164 |
165 | if args.eval_only:
166 | model = Trainer.build_model(cfg)
167 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
168 | cfg.MODEL.WEIGHTS, resume=args.resume
169 | )
170 | res = Trainer.test(cfg, model)
171 | if comm.is_main_process():
172 | verify_results(cfg, res)
173 | return res
174 |
175 | trainer = Trainer(cfg)
176 | if not args.no_pretrain:
177 | trainer.resume_or_load(resume=args.resume)
178 | return trainer.train()
179 |
--------------------------------------------------------------------------------
/train_tools/visdrone_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | """
4 | Detection Training Script.
5 |
6 | This script reads a given config file and runs training or evaluation.
7 | It is an entry point that is made to train standard models in detectron2.
8 |
9 | In order to let one script support training of many models,
10 | this script contains logic that is specific to these built-in models and therefore
11 | may not be suitable for your own project.
12 | For example, your research project perhaps only needs a single "evaluator".
13 |
14 | Therefore, we recommend you use detectron2 as a library and take
15 | this file as an example of how to use the library.
16 | You may want to write your own script with your datasets and other customizations.
17 | """
18 |
19 | import logging
20 | import sys
21 | import os
22 | from collections import OrderedDict
23 | import torch
24 | import time
25 | import argparse
26 |
27 | from torch.nn.parallel import DistributedDataParallel
28 |
29 | import detectron2.utils.comm as comm
30 | from detectron2.checkpoint import DetectionCheckpointer
31 | from detectron2.config import get_cfg
32 | from detectron2.data import MetadataCatalog, build_detection_test_loader
33 | from detectron2.engine import DefaultTrainer, default_setup, hooks, launch
34 | from detectron2.evaluation import (
35 | CityscapesInstanceEvaluator,
36 | CityscapesSemSegEvaluator,
37 | COCOEvaluator,
38 | COCOPanopticEvaluator,
39 | DatasetEvaluators,
40 | LVISEvaluator,
41 | PascalVOCDetectionEvaluator,
42 | SemSegEvaluator,
43 | verify_results,
44 | )
45 | from detectron2.modeling import GeneralizedRCNNWithTTA
46 | from detectron2.checkpoint import DetectionCheckpointer
47 | from detectron2.evaluation.evaluator import inference_on_dataset
48 | from detectron2.utils.events import JSONWriter, TensorboardXWriter
49 | from detectron2.engine.train_loop import AMPTrainer, SimpleTrainer
50 | from detectron2.engine.defaults import DefaultTrainer
51 |
52 | from utils.val_mapper_with_ann import ValMapper
53 | from utils.anchor_gen import AnchorGeneratorWithCenter
54 | from utils.coco_eval_fpn import COCOEvaluatorFPN
55 | from utils.json_evaluator import JsonEvaluator
56 | from utils.time_evaluator import GPUTimeEvaluator
57 |
58 | from visdrone.dataloader import build_train_loader, build_test_loader
59 |
60 | from configs.custom_config import add_custom_config
61 |
62 |
63 | class Trainer(DefaultTrainer):
64 | def __init__(self, cfg, resume=False, reuse_ckpt=False):
65 | """
66 | Args:
67 | cfg (CfgNode):
68 | """
69 | super(DefaultTrainer, self).__init__()
70 |
71 | logger = logging.getLogger("detectron2")
72 | if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
73 | from detectron2.utils.logger import setup_logger; setup_logger()
74 | cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
75 |
76 | # Assume these objects must be constructed in this order.
77 | model = self.build_model(cfg)
78 |
79 | ckpt = DetectionCheckpointer(model)
80 | self.start_iter = 0
81 | self.start_iter = ckpt.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
82 | self.iter = self.start_iter
83 |
84 | optimizer = self.build_optimizer(cfg, model)
85 | data_loader = self.build_train_loader(cfg)
86 |
87 | # For training, wrap with DDP. But don't need this for inference.
88 | if comm.get_world_size() > 1:
89 | model = DistributedDataParallel(
90 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
91 | )
92 | self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
93 | model, data_loader, optimizer
94 | )
95 |
96 | self.scheduler = self.build_lr_scheduler(cfg, optimizer)
97 | self.checkpointer = DetectionCheckpointer(
98 | model,
99 | cfg.OUTPUT_DIR,
100 | optimizer=optimizer,
101 | scheduler=self.scheduler,
102 | )
103 | self.start_iter = 0
104 | self.max_iter = cfg.SOLVER.MAX_ITER
105 | self.cfg = cfg
106 |
107 | self.register_hooks(self.build_hooks())
108 |
109 | def resume_or_load(self, resume=True):
110 | """
111 | If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
112 | a `last_checkpoint` file), resume from the file. Resuming means loading all
113 | available states (e.g. optimizer and scheduler) and updating the iteration counter
114 | from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
115 | Otherwise, this is considered an independent training run. The method will load model
116 | weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
117 | from iteration 0.
118 | Args:
119 | resume (bool): whether to do resume or not
120 | """
121 | checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
122 | from detectron2.utils.env import TORCH_VERSION  # needed for the version check below
123 |
124 | if resume and self.checkpointer.has_checkpoint():
125 | self.start_iter = checkpoint.get("iteration", -1) + 1
126 | # The checkpoint stores the training iteration that just finished, thus we start
127 | # at the next iteration (or iter zero if there's no checkpoint).
128 | if isinstance(self.model, DistributedDataParallel):
129 | # broadcast loaded data/model from the first rank, because other
130 | # machines may not have access to the checkpoint file
131 | if TORCH_VERSION >= (1, 7):
132 | self.model._sync_params_and_buffers()
133 | self.start_iter = comm.all_gather(self.start_iter)[0]
134 |
135 | @classmethod
136 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
137 | if output_folder is None:
138 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
139 | evaluator_list = []
140 | evaluator_list.append(JsonEvaluator(os.path.join(cfg.OUTPUT_DIR, 'visdrone_infer.json')))
141 | if cfg.META_INFO.EVAL_GPU_TIME:
142 | evaluator_list.append(GPUTimeEvaluator(True, 'minisecond'))
143 | return DatasetEvaluators(evaluator_list)
144 |
145 | @classmethod
146 | def build_train_loader(cls, cfg):
147 | return build_train_loader(cfg)
148 |
149 | @classmethod
150 | def build_test_loader(cls, cfg, dataset_name):
151 | return build_test_loader(cfg)
152 |
153 | @classmethod
154 | def test(cls, cfg, model, evaluators=None):
155 | logger = logging.getLogger(__name__)
156 | dataset_name = 'VisDrone2018'
157 |
158 | data_loader = cls.build_test_loader(cfg, dataset_name)
159 | evaluator = cls.build_evaluator(cfg, dataset_name)
160 | result = inference_on_dataset(model, data_loader, evaluator)
161 | return []
162 |
163 |
164 |
165 |
166 | def default_argument_parser(epilog=None):
167 | """
168 | Create a parser with some common arguments used by detectron2 users.
169 |
170 | Args:
171 | epilog (str): epilog passed to ArgumentParser describing the usage.
172 |
173 | Returns:
174 | argparse.ArgumentParser:
175 | """
176 | parser = argparse.ArgumentParser(
177 | epilog=epilog
178 | or f"""
179 | Examples:
180 |
181 | Run on single machine:
182 | $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth
183 |
184 | Run on multiple machines:
185 | (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url <URL> [--other-flags]
186 | (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url <URL> [--other-flags]
187 | """,
188 | formatter_class=argparse.RawDescriptionHelpFormatter,
189 | )
190 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
191 | parser.add_argument(
192 | "--resume",
193 | action="store_true",
194 | help="whether to attempt to resume from the checkpoint directory",
195 | )
196 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
197 | parser.add_argument("--no-pretrain", action="store_true", help="whether to load pretrained model")
198 | parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
199 | parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
200 | parser.add_argument(
201 | "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
202 | )
203 |
204 | # PyTorch still may leave orphan processes in multi-gpu training.
205 | # Therefore we use a deterministic way to obtain port,
206 | # so that users are aware of orphan processes by seeing the port occupied.
207 | port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
208 | parser.add_argument(
209 | "--dist-url",
210 | default="tcp://127.0.0.1:{}".format(port),
211 | help="initialization URL for pytorch distributed backend. See "
212 | "https://pytorch.org/docs/stable/distributed.html for details.",
213 | )
214 | parser.add_argument(
215 | "opts",
216 | help="Modify config options using the command-line",
217 | default=None,
218 | nargs=argparse.REMAINDER,
219 | )
220 | return parser
221 |
222 |
223 | def setup(args):
224 | """
225 | Create configs and perform basic setups.
226 | """
227 | cfg = get_cfg()
228 | add_custom_config(cfg)
229 | cfg.merge_from_file(args.config_file)
230 | cfg.merge_from_list(args.opts)
231 | cfg.freeze()
232 | default_setup(cfg, args)
233 | return cfg
234 |
235 |
236 | def start_train(args):
237 | cfg = setup(args)
238 |
239 | if args.eval_only:
240 | model = Trainer.build_model(cfg)
241 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
242 | cfg.MODEL.WEIGHTS, resume=args.resume
243 | )
244 | res = Trainer.test(cfg, model)
245 | if comm.is_main_process():
246 | verify_results(cfg, res)
247 | return res
248 |
249 | trainer = Trainer(cfg, resume=args.resume, reuse_ckpt=args.no_pretrain)
250 | return trainer.train()
251 |
--------------------------------------------------------------------------------
/train_visdrone.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from detectron2.engine import launch
4 | from train_tools.visdrone_train import default_argument_parser, start_train
5 |
6 | from models.retinanet.retinanet import RetinaNet_D2
7 | from models.querydet.detector import RetinaNetQueryDet
8 |
9 | if __name__ == '__main__':
10 | args = default_argument_parser().parse_args()
11 | print("Command Line Args:", args)
12 | launch(
13 | start_train,
14 | args.num_gpus,
15 | num_machines=args.num_machines,
16 | machine_rank=args.machine_rank,
17 | dist_url=args.dist_url,
18 | args=(args,),
19 | )
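20 |
21 | # Usage sketch (paths and GPU counts below are placeholders; adjust them to your setup):
22 | #   train:    python train_visdrone.py --num-gpus 8 --config-file configs/visdrone/querydet_train.yaml
23 | #   evaluate: python train_visdrone.py --num-gpus 1 --eval-only \
24 | #             --config-file configs/visdrone/querydet_test.yaml MODEL.WEIGHTS /path/to/weight.pth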
--------------------------------------------------------------------------------
/utils/anchor_gen.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from detectron2.modeling.anchor_generator import DefaultAnchorGenerator, _create_grid_offsets
4 | from detectron2.modeling import ANCHOR_GENERATOR_REGISTRY
5 | from detectron2.structures import Boxes
6 | import math
7 | import detectron2.utils.comm as comm
8 |
9 |
10 | @ANCHOR_GENERATOR_REGISTRY.register()
11 | class AnchorGeneratorWithCenter(DefaultAnchorGenerator):
12 |
13 | def _grid_anchors(self, grid_sizes):
14 | anchors = []
15 | centers = []
16 | for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
17 | shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
18 | shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
19 | center = torch.stack((shift_x, shift_y), dim=1)
20 |
21 | anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
22 | centers.append(center.view(-1, 2))
23 | return anchors, centers
24 |
25 | def forward(self, features):
26 | grid_sizes = [feature_map.shape[-2:] for feature_map in features]
27 | anchors_over_all_feature_maps, centers_over_all_feature_maps = self._grid_anchors(grid_sizes)
28 | anchor_boxes = [Boxes(x) for x in anchors_over_all_feature_maps]
29 |
30 | return anchor_boxes, centers_over_all_feature_maps
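31 |
32 | # Usage sketch (illustrative only; the cfg values and the dummy feature map below are assumptions):
33 | # the generator is normally built from the config via detectron2's build_anchor_generator and
34 | # returns per-level anchor Boxes together with the anchor center coordinates.
35 | #
36 | #   from detectron2.config import get_cfg
37 | #   from detectron2.layers import ShapeSpec
38 | #   from detectron2.modeling.anchor_generator import build_anchor_generator
39 | #
40 | #   cfg = get_cfg()
41 | #   cfg.MODEL.ANCHOR_GENERATOR.NAME = "AnchorGeneratorWithCenter"
42 | #   anchor_gen = build_anchor_generator(cfg, [ShapeSpec(channels=256, stride=8)])
43 | #   feats = [torch.zeros(1, 256, 100, 152)]
44 | #   boxes, centers = anchor_gen(feats)   # (list[Boxes], list[Tensor[N, 2]])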
--------------------------------------------------------------------------------
/utils/coco_eval_fpn.py:
--------------------------------------------------------------------------------
1 | from detectron2.evaluation import COCOEvaluator
2 | from detectron2.structures import Boxes, BoxMode, pairwise_iou
3 | import numpy as np, pycocotools.mask as mask_util  # used by the optional mask branch below
4 | def _instances_to_coco_json(instances, img_id):
5 | """
6 | Dump an "Instances" object to a COCO-format json that's used for evaluation.
7 |
8 | Args:
9 | instances (Instances):
10 | img_id (int): the image id
11 |
12 | Returns:
13 | list[dict]: list of json annotations in COCO format.
14 | """
15 | num_instance = len(instances)
16 | if num_instance == 0:
17 | return []
18 |
19 | boxes = instances.pred_boxes.tensor.numpy()
20 | boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
21 | boxes = boxes.tolist()
22 | scores = instances.scores.tolist()
23 | classes = instances.pred_classes.tolist()
24 |
25 | has_fpn_layer = instances.has("fpn_layers")
26 | if has_fpn_layer:
27 | fpn_layers = instances.fpn_layers.tolist()
28 |
29 | has_mask = instances.has("pred_masks")
30 | if has_mask:
31 | rles = [
32 | mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
33 | for mask in instances.pred_masks
34 | ]
35 | for rle in rles:
36 | rle["counts"] = rle["counts"].decode("utf-8")
37 |
38 | has_keypoints = instances.has("pred_keypoints")
39 | if has_keypoints:
40 | keypoints = instances.pred_keypoints
41 |
42 | results = []
43 | for k in range(num_instance):
44 | result = {
45 | "image_id": img_id,
46 | "category_id": classes[k],
47 | "bbox": boxes[k],
48 | "score": scores[k],
49 | }
50 | if has_fpn_layer:
51 | result["fpn_layer"] = fpn_layers[k]
52 | if has_mask:
53 | result["segmentation"] = rles[k]
54 | if has_keypoints:
55 | keypoints[k][:, :2] -= 0.5
56 | result["keypoints"] = keypoints[k].flatten().tolist()
57 | results.append(result)
58 | return results
59 |
60 |
61 |
62 | class COCOEvaluatorFPN(COCOEvaluator):
63 |
64 | def process(self, inputs, outputs):
65 | """
66 | Args:
67 | inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
68 | It is a list of dict. Each dict corresponds to an image and
69 | contains keys like "height", "width", "file_name", "image_id".
70 | outputs: the outputs of a COCO model. It is a list of dicts with key
71 | "instances" that contains :class:`Instances`.
72 | """
73 | for input, output in zip(inputs, outputs):
74 | prediction = {"image_id": input["image_id"]}
75 |
76 | # TODO this is ugly
77 | if "instances" in output:
78 | instances = output["instances"].to(self._cpu_device)
79 | prediction["instances"] = _instances_to_coco_json(instances, input["image_id"])
80 | if "proposals" in output:
81 | prediction["proposals"] = output["proposals"].to(self._cpu_device)
82 | self._predictions.append(prediction)
--------------------------------------------------------------------------------
/utils/gradient_checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | from typing import Any, Iterable, List, Tuple
4 |
5 |
6 | def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
7 | if isinstance(inputs, tuple):
8 | out = []
9 | for inp in inputs:
10 | if not isinstance(inp, torch.Tensor):
11 | out.append(inp)
12 | continue
13 |
14 | x = inp.detach()
15 | x.requires_grad = inp.requires_grad
16 | out.append(x)
17 | return tuple(out)
18 | else:
19 | raise RuntimeError(
20 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
21 |
22 |
23 | def check_backward_validity(inputs: Iterable[Any]) -> None:
24 | if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
25 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
26 |
27 |
28 | # We can't know if the run_fn will internally move some args to different devices,
29 | # which would require logic to preserve rng states for those devices as well.
30 | # We could paranoically stash and restore ALL the rng states for all visible devices,
31 | # but that seems very wasteful for most cases. Compromise: Stash the RNG state for
32 | # the device of all Tensor args.
33 | #
34 | # To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
35 | def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
36 | # This will not error out if "arg" is a CPU tensor or a non-tensor type because
37 | # the conditionals short-circuit.
38 | fwd_gpu_devices = list(set(arg.get_device() for arg in args
39 | if isinstance(arg, torch.Tensor) and arg.is_cuda))
40 |
41 | fwd_gpu_states = []
42 | for device in fwd_gpu_devices:
43 | with torch.cuda.device(device):
44 | fwd_gpu_states.append(torch.cuda.get_rng_state())
45 |
46 | return fwd_gpu_devices, fwd_gpu_states
47 |
48 |
49 | def set_device_states(devices, states) -> None:
50 | for device, state in zip(devices, states):
51 | with torch.cuda.device(device):
52 | torch.cuda.set_rng_state(state)
53 |
54 |
55 | class CheckpointFunction(torch.autograd.Function):
56 |
57 | @staticmethod
58 | def forward(ctx, run_function, preserve_rng_state, *args):
59 | check_backward_validity(args)
60 | ctx.run_function = run_function
61 | ctx.preserve_rng_state = preserve_rng_state
62 | ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
63 | if preserve_rng_state:
64 | ctx.fwd_cpu_state = torch.get_rng_state()
65 | # Don't eagerly initialize the cuda context by accident.
66 | # (If the user intends that the context is initialized later, within their
67 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
68 | # we have no way to anticipate this will happen before we run the function.)
69 | ctx.had_cuda_in_fwd = False
70 | if torch.cuda._initialized:
71 | ctx.had_cuda_in_fwd = True
72 | ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
73 | ctx.save_for_backward(*args)
74 | with torch.no_grad():
75 | outputs = run_function(*args)
76 | return outputs
77 |
78 | @staticmethod
79 | def backward(ctx, *args):
80 | if not torch.autograd._is_checkpoint_valid():
81 | raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
82 | inputs = ctx.saved_tensors
83 | # Stash the surrounding rng state, and mimic the state that was
84 | # present at this time during forward. Restore the surrounding state
85 | # when we're done.
86 | rng_devices = []
87 | if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
88 | rng_devices = ctx.fwd_gpu_devices
89 | with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
90 | if ctx.preserve_rng_state:
91 | torch.set_rng_state(ctx.fwd_cpu_state)
92 | if ctx.had_cuda_in_fwd:
93 | set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
94 | detached_inputs = detach_variable(inputs)
95 | with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
96 | outputs = ctx.run_function(*detached_inputs)
97 |
98 | if isinstance(outputs, torch.Tensor):
99 | outputs = (outputs,)
100 | torch.autograd.backward(outputs, args)
101 | grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
102 | for inp in detached_inputs)
103 | return (None, None) + grads
104 |
105 |
106 | def checkpoint(function, *args, **kwargs):
107 | r"""Checkpoint a model or part of the model
108 |
109 | Checkpointing works by trading compute for memory. Rather than storing all
110 | intermediate activations of the entire computation graph for computing
111 | backward, the checkpointed part does **not** save intermediate activations,
112 | and instead recomputes them in backward pass. It can be applied on any part
113 | of a model.
114 |
115 | Specifically, in the forward pass, :attr:`function` will run in
116 | :func:`torch.no_grad` manner, i.e., not storing the intermediate
117 | activations. Instead, the forward pass saves the inputs tuple and the
118 | :attr:`function` parameter. In the backwards pass, the saved inputs and
119 | :attr:`function` is retrieved, and the forward pass is computed on
120 | :attr:`function` again, now tracking the intermediate activations, and then
121 | the gradients are calculated using these activation values.
122 |
123 | .. warning::
124 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
125 | with :func:`torch.autograd.backward`.
126 |
127 | .. warning::
128 | If :attr:`function` invocation during backward does anything different
129 | than the one during forward, e.g., due to some global variable, the
130 | checkpointed version won't be equivalent, and unfortunately it can't be
131 | detected.
132 |
133 | .. warning::
134 | If checkpointed segment contains tensors detached from the computational
135 | graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
136 | error. This is because `checkpoint` makes all the outputs require
137 | gradients which causes issues when a tensor is defined to have no
138 | gradient in the model. To circumvent this, detach the tensors outside of
139 | the `checkpoint` function.
140 |
141 | .. warning:
142 | At least one of the inputs needs to have :code:`requires_grad=True` if
143 | grads are needed for model inputs, otherwise the checkpointed part of the
144 | model won't have gradients.
145 |
146 | Args:
147 | function: describes what to run in the forward pass of the model or
148 | part of the model. It should also know how to handle the inputs
149 | passed as the tuple. For example, in LSTM, if user passes
150 | ``(activation, hidden)``, :attr:`function` should correctly use the
151 | first input as ``activation`` and the second input as ``hidden``
152 | preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
153 | the RNG state during each checkpoint.
154 | args: tuple containing inputs to the :attr:`function`
155 |
156 | Returns:
157 | Output of running :attr:`function` on :attr:`*args`
158 | """
159 | # Hack to mix *args with **kwargs in a python 2.7-compliant way
160 | preserve = kwargs.pop('preserve_rng_state', True)
161 | if kwargs:
162 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
163 |
164 | return CheckpointFunction.apply(function, preserve, *args)
165 |
166 |
167 | def checkpoint_sequential(functions, segments, input, **kwargs):
168 | r"""A helper function for checkpointing sequential models.
169 |
170 | Sequential models execute a list of modules/functions in order
171 | (sequentially). Therefore, we can divide such a model in various segments
172 | and checkpoint each segment. All segments except the last will run in
173 | :func:`torch.no_grad` manner, i.e., not storing the intermediate
174 | activations. The inputs of each checkpointed segment will be saved for
175 | re-running the segment in the backward pass.
176 |
177 | See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
178 |
179 | .. warning::
180 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
181 | with :func:`torch.autograd.backward`.
182 |
183 | .. warning:
184 | At least one of the inputs needs to have :code:`requires_grad=True` if
185 | grads are needed for model inputs, otherwise the checkpointed part of the
186 | model won't have gradients.
187 |
188 | .. warning:
189 | Since PyTorch 1.4, it allows only one Tensor as the input and
190 | intermediate outputs, just like :class:`torch.nn.Sequential`.
191 |
192 | Args:
193 | functions: A :class:`torch.nn.Sequential` or the list of modules or
194 | functions (comprising the model) to run sequentially.
195 | segments: Number of chunks to create in the model
196 | input: A Tensor that is input to :attr:`functions`
197 | preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
198 | the RNG state during each checkpoint.
199 |
200 | Returns:
201 | Output of running :attr:`functions` sequentially on :attr:`*inputs`
202 |
203 | Example:
204 | >>> model = nn.Sequential(...)
205 | >>> input_var = checkpoint_sequential(model, chunks, input_var)
206 | """
207 | # Hack for keyword-only parameter in a python 2.7-compliant way
208 | preserve = kwargs.pop('preserve_rng_state', True)
209 | if kwargs:
210 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
211 |
212 | def run_function(start, end, functions):
213 | def forward(input):
214 | for j in range(start, end + 1):
215 | input = functions[j](input)
216 | return input
217 | return forward
218 |
219 | if isinstance(functions, torch.nn.Sequential):
220 | functions = list(functions.children())
221 |
222 | segment_size = len(functions) // segments
223 | # the last chunk has to be non-volatile
224 | end = -1
225 | for start in range(0, segment_size * (segments - 1), segment_size):
226 | end = start + segment_size - 1
227 | input = checkpoint(run_function(start, end, functions), input,
228 | preserve_rng_state=preserve)
229 | return run_function(end + 1, len(functions) - 1, functions)(input)
230 |
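231 | # Usage sketch (mirrors the docstring example above; module sizes are arbitrary):
232 | #
233 | #   import torch.nn as nn
234 | #   model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
235 | #   x = torch.randn(2, 16, requires_grad=True)
236 | #   out = checkpoint_sequential(model, 2, x)   # activations are recomputed during backward
237 | #   out.sum().backward()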
--------------------------------------------------------------------------------
/utils/json_evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import json
4 | import torch
5 | import logging
6 | import itertools
7 | import numpy as np
8 |
9 | from detectron2.evaluation.evaluator import DatasetEvaluator
10 | import detectron2.utils.comm as comm
11 | import itertools
12 | from collections import OrderedDict
13 | from detectron2.evaluation.coco_evaluation import instances_to_coco_json
14 |
15 | import numpy as np
16 |
17 | class JsonEvaluator(DatasetEvaluator):
18 | def __init__(self, out_json, distributed=True, class_add_1=True):
19 | self._out_json = out_json
20 | self.class_add_1 = class_add_1
21 |
22 | self._distributed = distributed
23 | self._cpu_device = torch.device("cpu")
24 | self._logger = logging.getLogger(__name__)
25 | self._predictions = []
26 |
27 | self.reset()
28 |
29 |
30 | def reset(self):
31 | self._predictions = []
32 |
33 |
34 | def process(self, inputs, outputs):
35 | for input, output in zip(inputs, outputs):
36 | img_name = os.path.split(input['file_name'])[-1].split('.')[0]
37 | if "instances" in output:
38 | prediction = {"img_name": img_name}
39 | instances = output["instances"].to(self._cpu_device)
40 | if self.class_add_1:
41 | instances.pred_classes += 1
42 | prediction["instances"] = instances_to_coco_json(instances, input['image_id'])
43 | self._predictions.append(prediction)
44 |
45 | def evaluate(self):
46 | if self._distributed:
47 | comm.synchronize()
48 | predictions = comm.gather(self._predictions, dst=0)
49 | predictions = list(itertools.chain(*predictions))
50 | if not comm.is_main_process():
51 | return {}
52 | else:
53 | predictions = self._predictions
54 |
55 | if len(predictions) == 0:
56 | return {}
57 |
58 | det_preds = []
59 | for pred in predictions:
60 | det_preds = det_preds + pred['instances']
61 |
62 | with open(self._out_json, "w") as f:
63 | f.write(json.dumps(det_preds))
64 | f.flush()
65 |
66 | return {}
67 |
68 |
69 |
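70 | # Usage sketch: the evaluator simply dumps COCO-style detections to `out_json`; it is normally
71 | # wrapped in DatasetEvaluators and driven by inference_on_dataset (see the trainers above).
72 | # The output path is a placeholder:
73 | #
74 | #   evaluator = JsonEvaluator('work_dirs/visdrone_infer.json', distributed=False, class_add_1=True)
75 | #   # results = inference_on_dataset(model, test_loader, evaluator)  # writes the json in evaluate()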
--------------------------------------------------------------------------------
/utils/loop_matcher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from typing import List
3 | import torch
4 |
5 | # useful when there is a huge number of gt boxes
6 | class LoopMatcher(object):
7 | def __init__(
8 | self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
9 | ):
10 | thresholds = thresholds[:]
11 | assert thresholds[0] > 0
12 | thresholds.insert(0, -float("inf"))
13 | thresholds.append(float("inf"))
14 | assert all(low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:]))
15 | assert all(l in [-1, 0, 1] for l in labels)
16 | assert len(labels) == len(thresholds) - 1
17 |
18 | self.low_quality_threshold = 0.3
19 | self.thresholds = thresholds
20 | self.labels = labels
21 | self.allow_low_quality_matches = allow_low_quality_matches
22 |
23 |
24 | def _iou(self, boxes, box):
25 | iw = torch.clamp(boxes[:, 2], max=box[2]) - torch.clamp(boxes[:, 0], min=box[0])
26 | ih = torch.clamp(boxes[:, 3], max=box[3]) - torch.clamp(boxes[:, 1], min=box[1])
27 |
28 | inter = torch.clamp(iw, min=0) * torch.clamp(ih, min=0)
29 |
30 | areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
31 | area = (box[2] - box[0]) * (box[3] - box[1])
32 |
33 | iou = inter / (areas + area - inter)
34 | return iou
35 |
36 | def __call__(self, gt_boxes, anchors):
37 | if len(gt_boxes) == 0:
38 | default_matches = torch.zeros((len(anchors)), dtype=torch.int64).to(anchors.tensor.device)
39 | default_match_labels = torch.zeros((len(anchors)), dtype=torch.int8).to(anchors.tensor.device) + self.labels[0]
40 | return default_matches, default_match_labels
41 |
42 | gt_boxes_tensor = gt_boxes.tensor
43 | anchors_tensor = anchors.tensor
44 |
45 | max_ious = torch.zeros((len(anchors))).to(anchors_tensor.device)
46 | matched_inds = torch.zeros((len(anchors)), dtype=torch.long).to(anchors_tensor.device)
47 | gt_ious = torch.zeros((len(gt_boxes))).to(anchors_tensor.device)
48 |
49 | for i in range(len(gt_boxes)):
50 | ious = self._iou(anchors_tensor, gt_boxes_tensor[i])
51 | gt_ious[i] = ious.max()
52 | matched_inds = torch.where(ious > max_ious, torch.zeros(1, dtype=torch.long, device=matched_inds.device)+i, matched_inds)
53 | max_ious = torch.max(ious, max_ious)
54 | del ious
55 |
56 | matched_vals = max_ious
57 | matches = matched_inds
58 |
59 | match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
60 |
61 | for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
62 | low_high = (matched_vals >= low) & (matched_vals < high)
63 | match_labels[low_high] = l
64 |
65 | if self.allow_low_quality_matches:
66 | self.set_low_quality_matches_(match_labels, matched_vals, matches, gt_ious)
67 |
68 | return matches, match_labels
69 |
70 | def set_low_quality_matches_(self, match_labels, matched_vals, matches, gt_ious):
71 | for i in range(len(gt_ious)):
72 | match_labels[(matched_vals==gt_ious[i]) & (matches==i)] = 1
73 |
74 |
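75 | # Usage sketch (same threshold/label convention as detectron2's Matcher; the boxes are made up):
76 | #
77 | #   from detectron2.structures import Boxes
78 | #   matcher = LoopMatcher([0.4, 0.5], [0, -1, 1], allow_low_quality_matches=True)
79 | #   gt = Boxes(torch.tensor([[10., 10., 50., 50.]]))
80 | #   anchors = Boxes(torch.tensor([[8., 8., 48., 48.], [100., 100., 140., 140.]]))
81 | #   matched_idxs, matched_labels = matcher(gt, anchors)   # labels: 1 = fg, 0 = bg, -1 = ignore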
--------------------------------------------------------------------------------
/utils/merged_sync_bn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import logging
3 | import torch
4 | import torch.distributed as dist
5 | from torch import nn
6 | from torch.autograd.function import Function
7 | from torch.nn import functional as F
8 | from torch.cuda.amp import autocast
9 |
10 | from detectron2.utils import comm, env
11 | from detectron2.layers.wrappers import BatchNorm2d
12 |
13 | class AllReduce(Function):
14 | @staticmethod
15 | def forward(ctx, input):
16 | input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
17 | # Use allgather instead of allreduce since I don't trust in-place operations ..
18 | dist.all_gather(input_list, input, async_op=False)
19 | inputs = torch.stack(input_list, dim=0)
20 | return torch.sum(inputs, dim=0)
21 |
22 | @staticmethod
23 | def backward(ctx, grad_output):
24 | dist.all_reduce(grad_output, async_op=False)
25 | return grad_output
26 |
27 | class MergedSyncBatchNorm(BatchNorm2d):
28 | """
29 | In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
30 | when the batch size on each worker is different.
31 | (e.g., when scale augmentation is used, or when it is applied to mask head).
32 |
33 | This is a slower but correct alternative to `nn.SyncBatchNorm`.
34 |
35 | Note:
36 | There isn't a single definition of Sync BatchNorm.
37 |
38 | When ``stats_mode==""``, this module computes overall statistics by using
39 | statistics of each worker with equal weight. The result is true statistics
40 | of all samples (as if they are all on one worker) only when all workers
41 | have the same (N, H, W). This mode does not support inputs with zero batch size.
42 |
43 | When ``stats_mode=="N"``, this module computes overall statistics by weighting
44 | the statistics of each worker by their ``N``. The result is true statistics
45 | of all samples (as if they are all on one worker) only when all workers
46 | have the same (H, W). It is slower than ``stats_mode==""``.
47 |
48 | Even though the result of this module may not be the true statistics of all samples,
49 | it may still be reasonable because it might be preferable to assign equal weights
50 | to all workers, regardless of their (H, W) dimension, instead of putting larger weight
51 | on larger images. From preliminary experiments, little difference is found between such
52 | a simplified implementation and an accurate computation of overall mean & variance.
53 | """
54 |
55 | def __init__(self, *args, stats_mode="", **kwargs):
56 | super().__init__(*args, **kwargs)
57 | assert stats_mode in ["", "N"]
58 | self._stats_mode = stats_mode
59 | self._batch_mean = None # for precise BN
60 | self._batch_meansqr = None # for precise BN
61 |
62 | def _eval_forward(self, inputs):
63 | scale = self.weight * torch.rsqrt(self.running_var + self.eps)
64 | bias = self.bias - self.running_mean * scale
65 | scale = scale.view(1, -1, 1, 1)
66 | bias = bias.view(1, -1, 1, 1)
67 | return [(x * scale + bias) for x in inputs]
68 |
69 |
70 | # @float_function
71 | def forward(self, inputs):
72 | with autocast(False):
73 | if comm.get_world_size() == 1 or not self.training:
74 | return self._eval_forward(inputs)
75 |
76 | B, C = inputs[0].shape[0], inputs[0].shape[1]
77 |
78 | mean = sum([torch.mean(input, dim=[0, 2, 3]) for input in inputs]) / len(inputs)
79 | meansqr = sum([torch.mean(input * input, dim=[0, 2, 3]) for input in inputs]) / len(inputs)
80 |
81 | if self._stats_mode == "":
82 | assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
83 | vec = torch.cat([mean, meansqr], dim=0)
84 | vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
85 | mean, meansqr = torch.split(vec, C)
86 | momentum = self.momentum
87 | else:
88 | if B == 0:
89 | vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
90 | vec = vec + sum(inp.sum() for inp in inputs) # make sure there is gradient w.r.t. the inputs
91 | else:
92 | vec = torch.cat(
93 | [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
94 | )
95 | vec = AllReduce.apply(vec * B)
96 |
97 | total_batch = vec[-1].detach()
98 | momentum = total_batch.clamp(max=1) * self.momentum # no update if total_batch is 0
99 | total_batch = torch.max(total_batch, torch.ones_like(total_batch)) # avoid div-by-zero
100 | mean, meansqr, _ = torch.split(vec / total_batch, C)
101 |
102 | var = meansqr - mean * mean
103 | invstd = torch.rsqrt(var + self.eps)
104 | scale = self.weight * invstd
105 | bias = self.bias - mean * scale
106 | scale = scale.reshape(1, -1, 1, 1)
107 | bias = bias.reshape(1, -1, 1, 1)
108 |
109 | self.running_mean += momentum * (mean.detach() - self.running_mean)
110 | self.running_var += momentum * (var.detach() - self.running_var)
111 |
112 | self._batch_mean = mean
113 | self._batch_meansqr = meansqr
114 |
115 | outputs = [(input * scale + bias) for input in inputs]
116 | return outputs
117 |
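118 | # Usage sketch: unlike nn.BatchNorm2d, forward takes a *list* of feature maps that share one set
119 | # of statistics (e.g. the per-level outputs of a head shared across FPN levels). Shapes are arbitrary:
120 | #
121 | #   bn = MergedSyncBatchNorm(256).eval()   # single-worker / eval path
122 | #   feats = [torch.randn(2, 256, 64, 64), torch.randn(2, 256, 32, 32)]
123 | #   outs = bn(feats)                       # list of normalized tensors with the same shapes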
--------------------------------------------------------------------------------
/utils/soft_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from detectron2.structures import Boxes, RotatedBoxes, pairwise_iou, pairwise_iou_rotated
4 |
5 |
6 | def soft_nms(boxes, scores, method, gaussian_sigma, linear_threshold, prune_threshold):
7 | """
8 | Performs soft non-maximum suppression algorithm on axis aligned boxes
9 |
10 | Args:
11 | boxes (Tensor[N, 5]):
12 | boxes where NMS will be performed. They
13 | are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
14 | scores (Tensor[N]):
15 | scores for each one of the boxes
16 | method (str):
17 | one of ['gaussian', 'linear', 'hard']
18 | see paper for details. users encouraged not to use "hard", as this is the
19 | same nms available elsewhere in detectron2
20 | gaussian_sigma (float):
21 | parameter for Gaussian penalty function
22 | linear_threshold (float):
23 | iou threshold for applying linear decay. Nt from the paper
24 | re-used as threshold for standard "hard" nms
25 | prune_threshold (float):
26 | boxes with scores below this threshold are pruned at each iteration.
27 | Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
28 |
29 | Returns:
30 | tuple(Tensor, Tensor):
31 | [0]: int64 tensor with the indices of the elements that have been kept
32 | by Soft NMS, sorted in decreasing order of scores
33 | [1]: float tensor with the re-scored scores of the elements that were kept
34 | """
35 | return _soft_nms(
36 | Boxes,
37 | pairwise_iou,
38 | boxes,
39 | scores,
40 | method,
41 | gaussian_sigma,
42 | linear_threshold,
43 | prune_threshold,
44 | )
45 |
46 |
47 | def soft_nms_rotated(boxes, scores, method, gaussian_sigma, linear_threshold, prune_threshold):
48 | """
49 | Performs soft non-maximum suppression algorithm on rotated boxes
50 |
51 | Args:
52 | boxes (Tensor[N, 5]):
53 | boxes where NMS will be performed. They
54 | are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
55 | scores (Tensor[N]):
56 | scores for each one of the boxes
57 | method (str):
58 | one of ['gaussian', 'linear', 'hard']
59 | see paper for details. users encouraged not to use "hard", as this is the
60 | same nms available elsewhere in detectron2
61 | gaussian_sigma (float):
62 | parameter for Gaussian penalty function
63 | linear_threshold (float):
64 | iou threshold for applying linear decay. Nt from the paper
65 | re-used as threshold for standard "hard" nms
66 | prune_threshold (float):
67 | boxes with scores below this threshold are pruned at each iteration.
68 | Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
69 |
70 | Returns:
71 | tuple(Tensor, Tensor):
72 | [0]: int64 tensor with the indices of the elements that have been kept
73 | by Soft NMS, sorted in decreasing order of scores
74 | [1]: float tensor with the re-scored scores of the elements that were kept """
75 | return _soft_nms(
76 | RotatedBoxes,
77 | pairwise_iou_rotated,
78 | boxes,
79 | scores,
80 | method,
81 | gaussian_sigma,
82 | linear_threshold,
83 | prune_threshold,
84 | )
85 |
86 |
87 | def batched_soft_nms(
88 | boxes, scores, idxs, method, gaussian_sigma, linear_threshold, prune_threshold
89 | ):
90 | """
91 | Performs soft non-maximum suppression in a batched fashion.
92 |
93 | Each index value correspond to a category, and NMS
94 | will not be applied between elements of different categories.
95 |
96 | Args:
97 | boxes (Tensor[N, 4]):
98 | boxes where NMS will be performed. They
99 | are expected to be in (x1, y1, x2, y2) format
100 | scores (Tensor[N]):
101 | scores for each one of the boxes
102 | idxs (Tensor[N]):
103 | indices of the categories for each one of the boxes.
104 | method (str):
105 | one of ['gaussian', 'linear', 'hard']
106 | see paper for details. users encouraged not to use "hard", as this is the
107 | same nms available elsewhere in detectron2
108 | gaussian_sigma (float):
109 | parameter for Gaussian penalty function
110 | linear_threshold (float):
111 | iou threshold for applying linear decay. Nt from the paper
112 | re-used as threshold for standard "hard" nms
113 | prune_threshold (float):
114 | boxes with scores below this threshold are pruned at each iteration.
115 | Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
116 | Returns:
117 | tuple(Tensor, Tensor):
118 | [0]: int64 tensor with the indices of the elements that have been kept
119 | by Soft NMS, sorted in decreasing order of scores
120 | [1]: float tensor with the re-scored scores of the elements that were kept
121 | """
122 | if boxes.numel() == 0:
123 | return (
124 | torch.empty((0,), dtype=torch.int64, device=boxes.device),
125 | torch.empty((0,), dtype=torch.float32, device=scores.device),
126 | )
127 | # strategy: in order to perform NMS independently per class.
128 | # we add an offset to all the boxes. The offset is dependent
129 | # only on the class idx, and is large enough so that boxes
130 | # from different classes do not overlap
131 | max_coordinate = boxes.max()
132 | offsets = idxs.to(boxes) * (max_coordinate + 1)
133 | boxes_for_nms = boxes + offsets[:, None]
134 | return soft_nms(
135 | boxes_for_nms, scores, method, gaussian_sigma, linear_threshold, prune_threshold
136 | )
137 |
138 |
139 | class SoftNMSer(object):
140 | def __init__(self, method, gaussian_sigma, linear_threshold, prune_threshold):
141 | self.method = method
142 | self.gaussian_sigma = gaussian_sigma
143 | self.linear_threshold = linear_threshold
144 | self.prune_threshold = prune_threshold
145 |
146 | def __call__(self, boxes, scores, class_idxs):
147 | return batched_soft_nms(boxes, scores, class_idxs, self.method, self.gaussian_sigma, self.linear_threshold, self.prune_threshold)
148 |
149 |
150 | def batched_soft_nms_rotated(
151 | boxes, scores, idxs, method, gaussian_sigma, linear_threshold, prune_threshold
152 | ):
153 | """
154 | Performs soft non-maximum suppression in a batched fashion on rotated bounding boxes.
155 |
156 | Each index value correspond to a category, and NMS
157 | will not be applied between elements of different categories.
158 |
159 | Args:
160 | boxes (Tensor[N, 5]):
161 | boxes where NMS will be performed. They
162 | are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
163 | scores (Tensor[N]):
164 | scores for each one of the boxes
165 | idxs (Tensor[N]):
166 | indices of the categories for each one of the boxes.
167 | method (str):
168 | one of ['gaussian', 'linear', 'hard']
169 | see paper for details. users encouraged not to use "hard", as this is the
170 | same nms available elsewhere in detectron2
171 | gaussian_sigma (float):
172 | parameter for Gaussian penalty function
173 | linear_threshold (float):
174 | iou threshold for applying linear decay. Nt from the paper
175 | re-used as threshold for standard "hard" nms
176 | prune_threshold (float):
177 | boxes with scores below this threshold are pruned at each iteration.
178 | Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
179 | Returns:
180 | tuple(Tensor, Tensor):
181 | [0]: int64 tensor with the indices of the elements that have been kept
182 | by Soft NMS, sorted in decreasing order of scores
183 | [1]: float tensor with the re-scored scores of the elements that were kept
184 | """
185 | if boxes.numel() == 0:
186 | return (
187 | torch.empty((0,), dtype=torch.int64, device=boxes.device),
188 | torch.empty((0,), dtype=torch.float32, device=scores.device),
189 | )
190 | # strategy: in order to perform NMS independently per class.
191 | # we add an offset to all the boxes. The offset is dependent
192 | # only on the class idx, and is large enough so that boxes
193 | # from different classes do not overlap
194 | max_coordinate = boxes[:, :2].max() + torch.norm(boxes[:, 2:4], 2, dim=1).max()
195 | offsets = idxs.to(boxes) * (max_coordinate + 1)
196 | boxes_for_nms = boxes.clone()
197 | boxes_for_nms[:, :2] += offsets[:, None]
198 | return soft_nms_rotated(
199 | boxes_for_nms, scores, method, gaussian_sigma, linear_threshold, prune_threshold
200 | )
201 |
202 |
203 | def _soft_nms(
204 | box_class,
205 | pairwise_iou_func,
206 | boxes,
207 | scores,
208 | method,
209 | gaussian_sigma,
210 | linear_threshold,
211 | prune_threshold,
212 | ):
213 | """
214 | Soft non-max suppression algorithm.
215 |
216 | Implementation of [Soft-NMS -- Improving Object Detection With One Line of Code]
217 | (https://arxiv.org/abs/1704.04503)
218 |
219 | Args:
220 | box_class (cls): one of Box, RotatedBoxes
221 | pairwise_iou_func (func): one of pairwise_iou, pairwise_iou_rotated
222 | boxes (Tensor[N, ?]):
223 | boxes where NMS will be performed
224 | if Boxes, in (x1, y1, x2, y2) format
225 | if RotatedBoxes, in (x_ctr, y_ctr, width, height, angle_degrees) format
226 | scores (Tensor[N]):
227 | scores for each one of the boxes
228 | method (str):
229 | one of ['gaussian', 'linear', 'hard']
230 | see paper for details. users encouraged not to use "hard", as this is the
231 | same nms available elsewhere in detectron2
232 | gaussian_sigma (float):
233 | parameter for Gaussian penalty function
234 | linear_threshold (float):
235 | iou threshold for applying linear decay. Nt from the paper
236 | re-used as threshold for standard "hard" nms
237 | prune_threshold (float):
238 | boxes with scores below this threshold are pruned at each iteration.
239 | Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
240 |
241 | Returns:
242 | tuple(Tensor, Tensor):
243 | [0]: int64 tensor with the indices of the elements that have been kept
244 | by Soft NMS, sorted in decreasing order of scores
245 | [1]: float tensor with the re-scored scores of the elements that were kept
246 | """
247 | boxes = boxes.clone()
248 | scores = scores.clone()
249 | idxs = torch.arange(scores.size()[0])
250 |
251 | idxs_out = []
252 | scores_out = []
253 |
254 | while scores.numel() > 0:
255 | top_idx = torch.argmax(scores)
256 | idxs_out.append(idxs[top_idx].item())
257 | scores_out.append(scores[top_idx].item())
258 |
259 | top_box = boxes[top_idx]
260 | ious = pairwise_iou_func(box_class(top_box.unsqueeze(0)), box_class(boxes))[0]
261 |
262 | if method == "linear":
263 | decay = torch.ones_like(ious)
264 | decay_mask = ious > linear_threshold
265 | decay[decay_mask] = 1 - ious[decay_mask]
266 | elif method == "gaussian":
267 | decay = torch.exp(-torch.pow(ious, 2) / gaussian_sigma)
268 | elif method == "hard": # standard NMS
269 | decay = (ious < linear_threshold).float()
270 | else:
271 | raise NotImplementedError("{} soft nms method not implemented.".format(method))
272 |
273 | scores *= decay
274 | keep = scores > prune_threshold
275 | keep[top_idx] = False
276 |
277 | boxes = boxes[keep]
278 | scores = scores[keep]
279 | idxs = idxs[keep]
280 |
281 | return torch.tensor(idxs_out).to(boxes.device), torch.tensor(scores_out).to(scores.device)
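282 |
283 | # Usage sketch (hand-made boxes/scores; `idxs` keeps suppression per category):
284 | #
285 | #   boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
286 | #   scores = torch.tensor([0.9, 0.8, 0.7])
287 | #   idxs = torch.tensor([0, 0, 1])
288 | #   keep, rescored = batched_soft_nms(boxes, scores, idxs, method='gaussian',
289 | #                                     gaussian_sigma=0.5, linear_threshold=0.3,
290 | #                                     prune_threshold=1e-3)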
--------------------------------------------------------------------------------
/utils/time_evaluator.py:
--------------------------------------------------------------------------------
1 | import time
2 | from detectron2.evaluation.evaluator import DatasetEvaluator
3 | import detectron2.utils.comm as comm
4 | import itertools
5 | from collections import OrderedDict
6 |
7 | import numpy as np
8 |
9 | class GPUTimeEvaluator(DatasetEvaluator):
10 | def __init__(self, distributed, unit, out_file=None):
11 | self.distributed = distributed
12 | self.all_time = []
13 | self.unit = unit
14 | self.out_file = out_file
15 | if unit not in {'minisecond', 'second'}:
16 | raise NotImplementedError('Unsupported time unit %s'%unit)
17 | self.reset()
18 |
19 | def reset(self):
20 | self.all_time = []
21 |
22 | def process(self, inputs, outputs):
23 | for output in outputs:
24 | if 'time' in output.keys():
25 | self.all_time.append(output['time'])
26 | return
27 |
28 | def evaluate(self):
29 | if self.distributed:
30 | comm.synchronize()
31 | all_time = comm.gather(self.all_time, dst=0)
32 | all_time = list(itertools.chain(*all_time))
33 |
34 | if not comm.is_main_process():
35 | return {}
36 | else:
37 | all_time = self.all_time
38 |
39 | if len(all_time) == 0:
40 | return {'GPU_Speed': 0}
41 |
42 | all_time = np.array(all_time)
43 |
44 | speeds = 1. / all_time
45 | if self.unit == 'minisecond':
46 | speeds *= 1000
47 |
48 | mean_speed = speeds.mean()
49 | std_speed = speeds.std()
50 | max_speed = speeds.max()
51 | min_speed = speeds.min()
52 | mid_speed = np.median(speeds)
53 |
54 | if self.out_file is not None:
55 | f = open(self.out_file, 'a')
56 | curr_time = time.strftime('%Y/%m/%d,%H:%M:%S', time.localtime())
57 | f.write('%s\t%.2f\n'%(curr_time, mean_speed))
58 | f.close()
59 |
60 | ret_dict = {'Mean_FPS': mean_speed, 'Std_FPS': std_speed, 'Max_FPS': max_speed, 'Min_FPS': min_speed, 'Mid_FPS': mid_speed}
61 |
62 | return {'GPU_Speed': ret_dict}
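63 |
64 | # Usage sketch: only outputs that carry a 'time' entry (per-image latency, expressed in the chosen
65 | # unit) are recorded; the detector is expected to attach it when GPU-time evaluation is enabled
66 | # ('minisecond' is this codebase's spelling of millisecond):
67 | #
68 | #   evaluator = GPUTimeEvaluator(distributed=False, unit='minisecond')
69 | #   evaluator.process(inputs=[{}], outputs=[{'time': 20.0}])   # 20 ms per image
70 | #   print(evaluator.evaluate())   # {'GPU_Speed': {'Mean_FPS': 50.0, ...}}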
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | from detectron2.structures import Boxes
6 |
7 |
8 | def get_box_scales(boxes: Boxes):
9 | return torch.sqrt((boxes.tensor[:, 2] - boxes.tensor[:, 0]) * (boxes.tensor[:, 3] - boxes.tensor[:, 1]))
10 |
11 | def get_anchor_center_min_dis(box_centers: torch.Tensor, anchor_centers: torch.Tensor):
12 | """
13 | Args:
14 | box_centers: [N, 2]
15 | anchor_centers: [M, 2]
16 | Returns:
17 |
18 | """
19 | N, _ = box_centers.size()
20 | M, _ = anchor_centers.size()
21 | if N == 0:
22 | return torch.ones_like(anchor_centers)[:, 0] * 99999, (torch.zeros_like(anchor_centers)[:, 0]).long()
23 | acenters = anchor_centers.view(-1, 1, 2)
24 | acenters = acenters.repeat(1, N, 1)
25 | bcenters = box_centers.view(1, -1, 2)
26 | bcenters = bcenters.repeat(M, 1, 1)
27 |
28 | dis = torch.sqrt(torch.sum((acenters - bcenters)**2, dim=2))
29 |
30 | mindis, minind = torch.min(input=dis, dim=1)
31 |
32 | return mindis, minind
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
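47 | # Usage sketch for get_anchor_center_min_dis (random centers; shapes follow the docstring above):
48 | #
49 | #   box_centers = torch.rand(5, 2) * 100       # [N, 2]
50 | #   anchor_centers = torch.rand(20, 2) * 100   # [M, 2]
51 | #   min_dis, min_ind = get_anchor_center_min_dis(box_centers, anchor_centers)
52 | #   # min_dis[j]: distance from anchor j to its closest box center; min_ind[j]: that box's index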
--------------------------------------------------------------------------------
/utils/val_mapper_with_ann.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import copy
5 | import torch
6 | from fvcore.common.file_io import PathManager
7 |
8 | from detectron2.data import MetadataCatalog
9 | from detectron2.data import detection_utils as utils
10 | from detectron2.data import transforms as T
11 |
12 |
13 |
14 | class ValMapper(object):
15 | """
16 | COCO validation mapper, with annotations
17 | """
18 |
19 | def __init__(self, cfg):
20 | self.is_train = False
21 |
22 | self.tfm_gens = utils.build_transform_gen(cfg, self.is_train)
23 |
24 | self.img_format = cfg.INPUT.FORMAT
25 | assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
26 |
27 |
28 |
29 | def __call__(self, dataset_dict):
30 | """
31 | Args:
32 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
33 |
34 | Returns:
35 | dict: a format that builtin models in detectron2 accept
36 | """
37 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
38 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
39 | utils.check_image_size(dataset_dict, image)
40 |
41 | image, transforms = T.apply_transform_gens(self.tfm_gens, image)
42 | image_shape = image.shape[:2] # h, w
43 | dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
44 |
45 | for anno in dataset_dict["annotations"]:
46 | anno.pop("segmentation", None)
47 | anno.pop("keypoints", None)
48 |
49 | annos = [
50 | utils.transform_instance_annotations(
51 | obj, transforms, image_shape, keypoint_hflip_indices=None
52 | )
53 | for obj in dataset_dict.pop("annotations")
54 | if obj.get("iscrowd", 0) == 0
55 | ]
56 |
57 | instances = utils.annotations_to_instances(annos, image_shape)
58 |
59 | dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
60 | return dataset_dict
61 |
62 |
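63 | # Usage sketch: plug the mapper into a detectron2 test loader so validation batches keep their
64 | # ground-truth instances (the dataset name below is an assumption):
65 | #
66 | #   from detectron2.data import build_detection_test_loader
67 | #   val_loader = build_detection_test_loader(cfg, "visdrone_val", mapper=ValMapper(cfg))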
--------------------------------------------------------------------------------
/visdrone/data_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 | sys.path.append(os.path.abspath(Path(__file__).parent.parent))
5 |
6 | import shutil
7 | import cv2
8 | import json
9 | from visdrone import utils
10 | from tqdm import tqdm
11 |
12 | import argparse
13 |
14 |
15 | def get_save_path(img_path, index):
16 | name = img_path.split('.')[0]
17 | return name + '_' + str(index) + '.jpg'
18 |
19 | def crop_and_save_image(img_root, img_path, new_img_root):
20 | img = cv2.imread(os.path.join(img_root,img_path))
21 | h, w, c = img.shape
22 |
23 | _y = h // 2
24 | _x = w // 2
25 |
26 | img0 = img[:_y, :_x, :]
27 | img1 = img[:_y, _x:, :]
28 | img2 = img[_y:, :_x, :]
29 | img3 = img[_y:, _x:, :]
30 |
31 | cv2.imwrite(os.path.join(new_img_root, get_save_path(img_path, 0)), img0)
32 | cv2.imwrite(os.path.join(new_img_root, get_save_path(img_path, 1)), img1)
33 | cv2.imwrite(os.path.join(new_img_root, get_save_path(img_path, 2)), img2)
34 | cv2.imwrite(os.path.join(new_img_root, get_save_path(img_path, 3)), img3)
35 |
36 | return h, w, _y, _x
37 |
38 |
39 | def copy_image(img_root, img_path, new_img_root):
40 | img = cv2.imread(os.path.join(img_root,img_path))
41 | h, w, c = img.shape
42 | cv2.imwrite(os.path.join(new_img_root, img_path), img)
43 | return h, w
44 |
45 |
46 | def get_new_label(label, img_path, cy, cx, id, img_id_base):
47 | if label['class'] == 0 or label['ignore']:
48 | return None
49 |
50 | x, y, w, h = label['bbox']
51 |
52 | if x < cx and y < cy:
53 | nx = x
54 | ny = y
55 | nw = min(x+w, cx) - x
56 | nh = min(y+h, cy) - y
57 | img_id = img_id_base
58 | elif x < cx and y >= cy:
59 | nx = x
60 | ny = y - cy
61 | nw = min(x+w, cx) - x
62 | nh = h
63 | img_id = img_id_base + 2
64 | elif x >= cx and y < cy:
65 | nx = x - cx
66 | ny = y
67 | nw = w
68 | nh = min(y+h, cy) - y
69 | img_id = img_id_base + 1
70 | else:
71 | nx = x - cx
72 | ny = y - cy
73 | nw = w
74 | nh = h
75 | img_id = img_id_base + 3
76 |
77 | new_label = {'category_id': label['class'], 'id': id, 'iscrowd':0, 'image_id':img_id, 'area':nw*nh, 'segmentation':[], 'bbox':[nx,ny,nw,nh]}
78 | return new_label
79 |
80 |
81 | def label_to_coco(label, id, img_id):
82 | x, y, w, h = label['bbox']
83 | new_label = {'category_id': label['class'], 'id': id, 'iscrowd':0, 'image_id':img_id, 'area':w*h, 'segmentation':[], 'bbox':[x,y,w,h]}
84 | return new_label
85 |
86 |
87 | def make_json(images, annotations, new_label_json):
88 | ann_dict = {}
89 | ann_dict['categories'] = [
90 | {'supercategory': 'things', 'id': 1, 'name': 'pedestrian'},
91 | {'supercategory': 'things', 'id': 2, 'name': 'people'},
92 | {'supercategory': 'things', 'id': 3, 'name': 'bicycle'},
93 | {'supercategory': 'things', 'id': 4, 'name': 'car'},
94 | {'supercategory': 'things', 'id': 5, 'name': 'van'},
95 | {'supercategory': 'things', 'id': 6, 'name': 'truck'},
96 | {'supercategory': 'things', 'id': 7, 'name': 'tricycle'},
97 | {'supercategory': 'things', 'id': 8, 'name': 'awning-tricycle'},
98 | {'supercategory': 'things', 'id': 9, 'name': 'bus'},
99 | {'supercategory': 'things', 'id': 10, 'name': 'motor'}
100 | ]
101 | ann_dict['images'] = images
102 | ann_dict['annotations'] = annotations
103 | with open(new_label_json, 'w') as outfile:
104 | json.dump(ann_dict, outfile)
105 |
106 |
107 | def make_new_train_set(img_root, label_root, new_img_root, new_label_json):
108 | all_labels = utils.read_all_labels(label_root)
109 |
110 | annotations = []
111 | images = []
112 | ann_id = 0
113 | img_id = 0
114 | for filename, labels in tqdm(all_labels.items()):
115 | img_path = filename.replace('txt', 'jpg')
116 | h, w, cy, cx = crop_and_save_image(img_root, img_path, new_img_root)
117 |
118 | images.append({'file_name': get_save_path(img_path, 0), 'height': cy, 'width': cx, 'id': img_id})
119 | images.append({'file_name': get_save_path(img_path, 1), 'height': cy, 'width': w-cx, 'id': img_id+1})
120 | images.append({'file_name': get_save_path(img_path, 2), 'height': h-cy, 'width': cx, 'id':img_id+2})
121 | images.append({'file_name': get_save_path(img_path, 3), 'height': h-cy, 'width': w-cx, 'id':img_id+3})
122 |
123 | for label in labels:
124 | new_label = get_new_label(label, img_path, cy, cx, ann_id, img_id)
125 |             if new_label is not None:
126 | ann_id += 1
127 | annotations.append(new_label)
128 | img_id += 4
129 | make_json(images, annotations, new_label_json)
130 |
131 |
132 | def make_new_test_set(img_root, label_root, new_img_root, new_label_json):
133 | all_labels = utils.read_all_labels(label_root)
134 | annotations = []
135 | images = []
136 | ann_id = 0
137 | img_id = 0
138 |
139 | for filename, labels in tqdm(all_labels.items()):
140 | img_path = filename.replace('txt', 'jpg')
141 | h, w = copy_image(img_root, img_path, new_img_root)
142 | images.append({'file_name': img_path, 'height': h, 'width': w, 'id': img_id})
143 | for label in labels:
144 | coco_label = label_to_coco(label, ann_id, img_id)
145 |             if coco_label is not None:
146 | ann_id += 1
147 | annotations.append(coco_label)
148 | img_id += 1
149 |
150 | make_json(images, annotations, new_label_json)
151 |
152 |
153 |
154 | if __name__ == '__main__':
155 |
156 | parser = argparse.ArgumentParser(description='Data Prepare Arguments')
157 | parser.add_argument('--visdrone-root', required=True, type=str, help='VisDrone dataset root')
158 | args = parser.parse_args()
159 |
160 | if not os.path.isdir(os.path.join(args.visdrone_root, 'coco_format')):
161 | os.mkdir(os.path.join(args.visdrone_root, 'coco_format'))
162 | os.mkdir(os.path.join(args.visdrone_root, 'coco_format/train_images'))
163 | os.mkdir(os.path.join(args.visdrone_root, 'coco_format/val_images'))
164 | os.mkdir(os.path.join(args.visdrone_root, 'coco_format/annotations'))
165 |
166 |
167 | '''
168 | Training
169 | '''
170 | train_img_root = os.path.join(args.visdrone_root, 'VisDrone2019-DET-train/images')
171 | train_label_root = os.path.join(args.visdrone_root, 'VisDrone2019-DET-train/annotations')
172 | train_new_img_root = os.path.join(args.visdrone_root, 'coco_format/train_images')
173 | train_new_label_json = os.path.join(args.visdrone_root, 'coco_format/annotations/train_label.json')
174 | make_new_train_set(train_img_root, train_label_root, train_new_img_root, train_new_label_json)
175 |
176 | '''
177 | Validation
178 | '''
179 | val_img_root = os.path.join(args.visdrone_root, 'VisDrone2019-DET-val/images')
180 | val_label_root = os.path.join(args.visdrone_root, 'VisDrone2019-DET-val/annotations')
181 | val_new_img_root = os.path.join(args.visdrone_root, 'coco_format/val_images')
182 | val_new_label_json = os.path.join(args.visdrone_root, 'coco_format/annotations/val_label.json')
183 | make_new_test_set(val_img_root, val_label_root, val_new_img_root, val_new_label_json)
184 |
185 | '''
186 | Test set: not needed here. You can convert it yourself in the same way as the validation set if you want to.
187 | '''
188 | # img_root = '/path/to/test/images'
189 | # label_root = '/path/to/test/annotations'
190 | # new_img_root = '/path/to/test/images'
191 | # new_label_json = '/path/to/test/label.json'
192 | # make_new_test_set(img_root, label_root, new_img_root, new_label_json)
193 |
194 |
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
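
For reference, a minimal sketch of how `get_new_label` above reassigns a box to one of the four crops. The coordinates are made up, and it assumes the repo root is on `PYTHONPATH` (with `cv2` and `tqdm` installed) so that `visdrone.data_prepare` imports cleanly.

```python
from visdrone.data_prepare import get_new_label

# A 1000x2000 image is split at (cy, cx) = (500, 1000) into four crops whose
# image ids are img_id_base .. img_id_base + 3 (top-left, top-right,
# bottom-left, bottom-right).
label = {'class': 4, 'ignore': 0, 'bbox': (1200, 700, 60, 40)}  # x, y, w, h

new = get_new_label(label, 'demo.jpg', cy=500, cx=1000, id=0, img_id_base=100)

# The box lies in the bottom-right quadrant, so it is shifted by (-cx, -cy)
# and assigned to image id img_id_base + 3.
assert new['image_id'] == 103
assert new['bbox'] == [200, 200, 60, 40]
```
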
/visdrone/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import bisect
4 | import copy
5 | import itertools
6 | import logging
7 | import numpy as np
8 | import operator
9 | import pickle
10 | import torch.utils.data
11 | from fvcore.common.file_io import PathManager
12 | from tabulate import tabulate
13 | from termcolor import colored
14 |
15 | from detectron2.structures import BoxMode
16 | from detectron2.utils.comm import get_world_size
17 | from detectron2.utils.env import seed_all_rng
18 | from detectron2.utils.logger import log_first_n
19 |
20 | from detectron2.structures.boxes import BoxMode
21 | from detectron2.data import samplers
22 | from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
23 | from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
24 | from detectron2.data.dataset_mapper import DatasetMapper
25 | from detectron2.data.detection_utils import check_metadata_consistency
26 |
27 |
28 | from visdrone.mapper import Mapper
29 |
30 |
31 | def get_train_data_dicts(json_file, img_root, filter_empty=False):
32 | data = json.load(open(json_file))
33 |
34 | images = {x['id']: {'file': x['file_name'], 'height':x['height'], 'width':x['width']} for x in data['images']}
35 |
36 | annotations = {}
37 | for ann in data['annotations']:
38 | img_id = ann['image_id']
39 | if img_id not in annotations.keys():
40 | annotations[img_id] = []
41 | annotations[img_id].append({'bbox': ann['bbox'], 'category_id': ann['category_id']-1, 'iscrowd': ann['iscrowd'], 'area': ann['area']})
42 |
43 | for img_id in images.keys():
44 | if img_id not in annotations.keys():
45 | annotations[img_id] = []
46 |
47 | data_dicts = []
48 | for img_id in images.keys():
49 | if filter_empty and len(annotations[img_id]) == 0:
50 | continue
51 | data_dict = {}
52 | data_dict['file_name'] = str(os.path.join(img_root, images[img_id]['file']))
53 | data_dict['height'] = images[img_id]['height']
54 | data_dict['width'] = images[img_id]['width']
55 | data_dict['image_id'] = img_id
56 | data_dict['annotations'] = []
57 | for ann in annotations[img_id]:
58 | data_dict['annotations'].append({'bbox': ann['bbox'], 'iscrowd': ann['iscrowd'], 'category_id': ann['category_id'], 'bbox_mode': BoxMode.XYWH_ABS})
59 | data_dicts.append(data_dict)
60 | return data_dicts
61 |
62 |
63 | def get_test_data_dicts(json_file, img_root):
64 | data = json.load(open(json_file))
65 | images = {x['id']: {'file': x['file_name'], 'height':x['height'], 'width':x['width']} for x in data['images']}
66 |
67 | data_dicts = []
68 | for img_id in images.keys():
69 | data_dict = {}
70 | data_dict['file_name'] = str(os.path.join(img_root, images[img_id]['file']))
71 | data_dict['height'] = images[img_id]['height']
72 | data_dict['width'] = images[img_id]['width']
73 | data_dict['image_id'] = img_id
74 | data_dict['annotations'] = []
75 | data_dicts.append(data_dict)
76 | return data_dicts
77 |
78 |
79 | def build_train_loader(cfg):
80 | num_workers = get_world_size()
81 | images_per_batch = cfg.SOLVER.IMS_PER_BATCH
82 |
83 | assert (
84 | images_per_batch % num_workers == 0
85 | ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
86 | images_per_batch, num_workers
87 | )
88 | assert (
89 | images_per_batch >= num_workers
90 | ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
91 | images_per_batch, num_workers
92 | )
93 | images_per_worker = images_per_batch // num_workers
94 |
95 | dataset_dicts = get_train_data_dicts(cfg.VISDRONE.TRAIN_JSON, cfg.VISDRONE.TRING_IMG_ROOT)
96 | dataset = DatasetFromList(dataset_dicts, copy=False)
97 | mapper = Mapper(cfg, True)
98 | dataset = MapDataset(dataset, mapper)
99 |
100 | sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
101 | logger = logging.getLogger(__name__)
102 | logger.info("Using training sampler {}".format(sampler_name))
103 |
104 | if sampler_name == "TrainingSampler":
105 | sampler = samplers.TrainingSampler(len(dataset))
106 | elif sampler_name == "RepeatFactorTrainingSampler":
107 | sampler = samplers.RepeatFactorTrainingSampler(
108 | dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
109 | )
110 | else:
111 | raise ValueError("Unknown training sampler: {}".format(sampler_name))
112 |
113 | batch_sampler = torch.utils.data.sampler.BatchSampler(
114 | sampler, images_per_worker, drop_last=True
115 | )
116 |     # drop_last so the batch always has the same size
117 | data_loader = torch.utils.data.DataLoader(
118 | dataset,
119 | num_workers=cfg.DATALOADER.NUM_WORKERS,
120 | batch_sampler=batch_sampler,
121 | collate_fn=trivial_batch_collator,
122 | worker_init_fn=worker_init_reset_seed,
123 | )
124 | return data_loader
125 |
126 |
127 | def build_test_loader(cfg):
128 |
129 | dataset_dicts = get_test_data_dicts(cfg.VISDRONE.TEST_JSON, cfg.VISDRONE.TEST_IMG_ROOT)
130 |
131 | dataset = DatasetFromList(dataset_dicts)
132 | mapper = Mapper(cfg, False)
133 | dataset = MapDataset(dataset, mapper)
134 |
135 | sampler = samplers.InferenceSampler(len(dataset))
136 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
137 |
138 | data_loader = torch.utils.data.DataLoader(
139 | dataset,
140 | num_workers=cfg.DATALOADER.NUM_WORKERS,
141 | batch_sampler=batch_sampler,
142 | collate_fn=trivial_batch_collator,
143 | )
144 | return data_loader
145 |
146 |
147 | def worker_init_reset_seed(worker_id):
148 | seed_all_rng(np.random.randint(2 ** 31) + worker_id)
149 |
150 |
151 | def trivial_batch_collator(batch):
152 | return batch
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
--------------------------------------------------------------------------------
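
A rough usage sketch for the two loader builders above, assuming a detectron2 `CfgNode` that already carries the `VISDRONE` options read here (`TRAIN_JSON`, `TRING_IMG_ROOT`, `TEST_JSON`, `TEST_IMG_ROOT`) plus the usual `SOLVER`/`DATALOADER` settings; `demo` is a hypothetical helper, not part of the repo.

```python
from visdrone.dataloader import build_train_loader, build_test_loader

def demo(cfg):
    train_loader = build_train_loader(cfg)  # infinite, batched training stream
    test_loader = build_test_loader(cfg)    # one image per batch, in dataset order

    batch = next(iter(train_loader))        # a list of dicts (trivial_batch_collator)
    sample = batch[0]
    print(sample["image"].shape)            # CHW image tensor produced by Mapper
    print(sample["instances"])              # detectron2 Instances with gt_boxes / gt_classes
```
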
/visdrone/json_to_txt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import json
4 |
5 |
6 | import argparse
7 |
8 |
9 | class Json2Txt(object):
10 |
11 | def __init__(self, gt_json, det_json, out_dir):
12 | gt_data = json.load(open(gt_json))
13 | self.images = {x['id']: {'file': x['file_name'], 'height':x['height'], 'width':x['width']} for x in gt_data['images']}
14 |
15 | det_data = json.load(open(det_json))
16 |
17 | self.results = {}
18 | for result in det_data:
19 | if result['image_id'] not in self.results.keys():
20 | self.results[result['image_id']] = []
21 | self.results[result['image_id']].append({'box': result['bbox'], 'category': result['category_id'], 'score': result['score']})
22 |
23 | self.out_dir = out_dir
24 |
25 | def to_txt(self):
26 | for img_id in tqdm.tqdm(self.images.keys()):
27 | file_name = self.images[img_id]['file'].replace('jpg', 'txt')
28 | with open(os.path.join(self.out_dir, file_name), 'w') as fw:
29 |                 for pred in self.results.get(img_id, []):  # images with no detections produce an empty txt file
30 | row = '%.2f,%.2f,%.2f,%.2f,%.8f,%d,-1,-1'%(pred['box'][0],pred['box'][1],pred['box'][2],pred['box'][3],pred['score'],pred['category'])
31 | fw.write(row+'\n')
32 |
33 | if __name__ == '__main__':
34 |
35 | parser = argparse.ArgumentParser(description='Arguments')
36 | parser.add_argument('--out', required=True, type=str, help='output txt dir')
37 | parser.add_argument('--gt-json', required=False, type=str, default='visdrone_data/annotations/val_label', help='Ground truth info JSON')
38 | parser.add_argument('--det-json', required=True, type=str, help='COCO style result JSON')
39 | args = parser.parse_args()
40 |
41 | gt_json = args.gt_json
42 | det_json = args.det_json
43 | outdir = args.out
44 |
45 | if not os.path.isdir(outdir):
46 | os.mkdir(outdir)
47 |
48 | print('Json to txt:', outdir)
49 | tool = Json2Txt(gt_json, det_json, outdir)
50 | tool.to_txt()
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
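
For reference, the row layout written by `to_txt` above is `x,y,w,h,score,category,-1,-1`, where the trailing `-1,-1` stand for the unused truncation/occlusion fields of the VisDrone result format. A quick illustration with made-up values:

```python
pred = {'box': [10.0, 20.0, 30.5, 40.5], 'score': 0.91234567, 'category': 4}
row = '%.2f,%.2f,%.2f,%.2f,%.8f,%d,-1,-1' % (
    pred['box'][0], pred['box'][1], pred['box'][2], pred['box'][3],
    pred['score'], pred['category'])
print(row)  # 10.00,20.00,30.50,40.50,0.91234567,4,-1,-1
```
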
/visdrone/mapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import copy
3 | import logging
4 | import numpy as np
5 | import torch
6 | from fvcore.common.file_io import PathManager
7 | from PIL import Image
8 |
9 | from detectron2.data import detection_utils as utils
10 | from detectron2.data import transforms as T
11 |
12 | """
13 | This file contains the default mapping that's applied to "dataset dicts".
14 | """
15 |
16 | __all__ = ["Mapper"]
17 |
18 |
19 | class Mapper:
20 |     """
21 |     A callable which takes a dataset dict in Detectron2 Dataset format,
22 |     and maps it into a format used by the model.
23 |
24 |     This is the default callable used to map a dataset dict into training data.
25 |     You may need to adapt it to implement your own logic,
26 |     such as a different way to read or transform images.
27 |     See :doc:`/tutorials/data_loading` for details.
28 |
29 |     The callable currently does the following:
30 |
31 |     1. Reads the image from "file_name"
32 |     2. Applies cropping/geometric transforms to the image and annotations
33 |     3. Converts the data and annotations to Tensors and :class:`Instances`
34 |     """
35 |
36 | def __init__(self, cfg, is_train=True):
37 |
38 | self.tfm_gens = build_transform_gen(cfg, is_train)
39 | # fmt: off
40 | self.img_format = cfg.INPUT.FORMAT
41 | self.mask_on = False
42 | self.mask_format = cfg.INPUT.MASK_FORMAT
43 | self.keypoint_on = False
44 | self.load_proposals = False
45 | self.keypoint_hflip_indices = None
46 | # fmt: on
47 |
48 | self.is_train = is_train
49 |
50 | def __call__(self, dataset_dict):
51 |
52 | dataset_dict = copy.deepcopy(dataset_dict)
53 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
54 | utils.check_image_size(dataset_dict, image)
55 |
56 | image, transforms = T.apply_transform_gens(self.tfm_gens, image)
57 | image_shape = image.shape[:2] # h, w
58 |
59 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
60 |
61 |
62 | if not self.is_train:
63 | dataset_dict.pop("annotations", None)
64 | return dataset_dict
65 |
66 | if "annotations" in dataset_dict:
67 | # USER: Modify this if you want to keep them for some reason.
68 | for anno in dataset_dict["annotations"]:
69 | anno.pop("segmentation", None)
70 | anno.pop("keypoints", None)
71 |
72 | # USER: Implement additional transformations if you have other types of data
73 | annos = [
74 | utils.transform_instance_annotations(
75 | obj, transforms, image_shape
76 | )
77 | for obj in dataset_dict.pop("annotations")
78 | if obj.get("iscrowd", 0) == 0
79 | ]
80 | instances = utils.annotations_to_instances(
81 | annos, image_shape, mask_format=self.mask_format
82 | )
83 | dataset_dict["instances"] = utils.filter_empty_instances(instances)
84 | return dataset_dict
85 |
86 |
87 | def build_transform_gen(cfg, is_train):
88 | if is_train:
89 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
90 | else:
91 | sample_style = 'choice'
92 |
93 | logger = logging.getLogger(__name__)
94 | tfm_gens = []
95 | if is_train:
96 | tfm_gens.append(T.RandomFlip(horizontal=True, vertical=False))
97 | tfm_gens.append(T.ResizeShortestEdge(short_edge_length=cfg.VISDRONE.SHORT_LENGTH, max_size=cfg.VISDRONE.MAX_LENGTH, sample_style=sample_style))
98 | else:
99 | tfm_gens.append(T.ResizeShortestEdge(short_edge_length=[cfg.VISDRONE.TEST_LENGTH], max_size=cfg.VISDRONE.TEST_LENGTH, sample_style=sample_style))
100 |
101 | return tfm_gens
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
/visdrone/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 |
5 |
6 |
7 | def read_label_txt(txt_file):
8 | f = open(txt_file, 'r')
9 | lines = f.readlines()
10 |
11 | labels = []
12 | for line in lines:
13 | line = line.strip().split(',')
14 |
15 | x, y, w, h, not_ignore, cate, trun, occ = line[:8]
16 |
17 | labels.append(
18 | {'bbox': (int(x),int(y),int(w),int(h)),
19 | 'ignore': 0 if int(not_ignore) else 1,
20 | 'class': int(cate),
21 | 'truncate': int(trun),
22 | 'occlusion': int(occ)}
23 | )
24 | return labels
25 |
26 |
27 | def read_all_labels(ann_root):
28 | ann_list = os.listdir(ann_root)
29 | all_labels = {}
30 | for ann_file in ann_list:
31 | if not ann_file.endswith('txt'):
32 | continue
33 | ann_labels = read_label_txt(os.path.join(ann_root, ann_file))
34 | all_labels[ann_file] = ann_labels
35 | return all_labels
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
--------------------------------------------------------------------------------
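
A small self-contained check of `read_label_txt` above, writing two made-up VisDrone-style annotation rows (`x,y,w,h,score_flag,category,truncation,occlusion`) to a temporary file; it assumes the repo root is importable.

```python
import os
import tempfile

from visdrone.utils import read_label_txt

rows = [
    "684,8,273,116,0,0,0,0",   # flag 0 and category 0 -> ignored region
    "406,119,39,113,1,1,0,0",  # a pedestrian box that should be kept
]
path = os.path.join(tempfile.gettempdir(), "visdrone_ann_demo.txt")
with open(path, "w") as f:
    f.write("\n".join(rows) + "\n")

labels = read_label_txt(path)
print(labels[0]["ignore"], labels[0]["class"])  # 1 0  (ignored region)
print(labels[1]["bbox"])                        # (406, 119, 39, 113)
```
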
/visdrone_eval/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 同济大学智能汽车研究所综合感知研究组 ( Comprehensive Perception Research Group under Institute of Intelligent Vehicles, School of Automotive Studies, Tongji University)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/visdrone_eval/README.md:
--------------------------------------------------------------------------------
1 | # visdrone-det-toolkit-python
2 |
3 | Python implementation of evaluation utilities of **[VisDrone2018-DET-toolkit](https://github.com/VisDrone/VisDrone2018-DET-toolkit)**.
4 |
5 | ### Run Evaluation
6 |
7 | Pass the dataset split directory and the directory of result txt files (one per image) to `evaluate.py`:
8 |
9 | ```shell
10 | python evaluate.py --dataset-dir <path/to/VisDrone-split> --res-dir <path/to/result/txts>
11 | ```
12 |
13 | ### Installation and Usage
14 |
15 | Installation:
16 |
17 | ```bash
18 | pip install -e .
19 | ```
20 |
21 | An example of using the function `eval_det` is given below:
22 |
23 | ```python
24 | from viseval import eval_det
25 | ...
26 | ap_all, ap_50, ap_75, ar_1, ar_10, ar_100, ar_500 = eval_det(
27 | annotations, results, heights, widths)
28 | ...
29 | ```
30 |
31 | Reference: https://github.com/tjiiv-cprg/visdrone-det-toolkit-python.git
32 |
33 |
--------------------------------------------------------------------------------
/visdrone_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | # reference: https://github.com/tjiiv-cprg/visdrone-det-toolkit-python
2 |
3 |
4 | import os.path as osp
5 | import os
6 | import numpy as np
7 | import cv2
8 | from viseval.eval_det import eval_det
9 |
10 | import argparse
11 | parser = argparse.ArgumentParser(description='Arguments')
12 | parser.add_argument('--dataset-dir', required=True, type=str, help='VisDrone split root containing images/ and annotations/')
13 | parser.add_argument('--res-dir', required=True, type=str, help='directory of per-image detection result txt files')
14 | args = parser.parse_args()
15 |
16 | def open_label_file(path, dtype=np.float32):
17 | label = np.loadtxt(path, delimiter=',', dtype=dtype,
18 | ndmin=2, usecols=range(8))
19 | if not len(label):
20 | label = label.reshape(0, 8)
21 | return label
22 |
23 |
24 | def main():
25 | dataset_dir = args.dataset_dir
26 | res_dir = args.res_dir
27 |
28 | gt_dir = osp.join(dataset_dir, 'annotations')
29 | img_dir = osp.join(dataset_dir, 'images')
30 |
31 | all_gt = []
32 | all_det = []
33 | allheight = []
34 | allwidth = []
35 |
36 | data_list_path = os.listdir(img_dir)
37 |
38 | for filename in data_list_path:
39 | filename = filename.strip().split('.')[0]
40 | img_path = osp.join(img_dir, filename + '.jpg')
41 | img = cv2.imread(img_path)
42 | height, width = img.shape[:2]
43 |
44 | allheight.append(height)
45 | allwidth.append(width)
46 |
47 | label = open_label_file(
48 | osp.join(gt_dir, filename + '.txt'), dtype=np.int32)
49 | all_gt.append(label)
50 |
51 | det = open_label_file(
52 | osp.join(res_dir, filename + '.txt'))
53 | all_det.append(det)
54 |
55 | ap_all, ap_50, ap_75, ar_1, ar_10, ar_100, ar_500, ap_classwise = eval_det(
56 | all_gt, all_det, allheight, allwidth, per_class=True)
57 |
58 | print('Average Precision (AP) @[ IoU=0.50:0.95 | maxDets=500 ] = {}%.'.format(ap_all))
59 | print('Average Precision (AP) @[ IoU=0.50 | maxDets=500 ] = {}%.'.format(ap_50))
60 | print('Average Precision (AP) @[ IoU=0.75 | maxDets=500 ] = {}%.'.format(ap_75))
61 | print('Average Recall (AR) @[ IoU=0.50:0.95 | maxDets= 1 ] = {}%.'.format(ar_1))
62 | print('Average Recall (AR) @[ IoU=0.50:0.95 | maxDets= 10 ] = {}%.'.format(ar_10))
63 | print('Average Recall (AR) @[ IoU=0.50:0.95 | maxDets=100 ] = {}%.'.format(ar_100))
64 | print('Average Recall (AR) @[ IoU=0.50:0.95 | maxDets=500 ] = {}%.'.format(ar_500))
65 |
66 | for i, ap in enumerate(ap_classwise):
67 | print('Class {} AP = {}%'.format(i, ap))
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
/visdrone_eval/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.18.5
2 | setuptools~=46.4.0
3 | opencv-python~=4.2.0.34
--------------------------------------------------------------------------------
/visdrone_eval/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import find_packages, setup
3 |
4 |
5 | def parse_requirements(fname='requirements.txt', with_version=True):
6 | """
7 |     Parse the package dependencies listed in a requirements file, optionally
8 |     stripping the specific versioning information.
9 |
10 | Args:
11 | fname (str): path to requirements file
12 |         with_version (bool, default=True): if True include version specs
13 |
14 | Returns:
15 | List[str]: list of requirements items
16 |
17 | CommandLine:
18 | python -c "import setup; print(setup.parse_requirements())"
19 | """
20 | import sys
21 | from os.path import exists
22 | import re
23 | require_fpath = fname
24 |
25 | def parse_line(line):
26 | """
27 | Parse information from a line in a requirements text file
28 | """
29 | if line.startswith('-r '):
30 | # Allow specifying requirements in other files
31 | target = line.split(' ')[1]
32 | for info in parse_require_file(target):
33 | yield info
34 | else:
35 | info = {'line': line}
36 | if line.startswith('-e '):
37 | info['package'] = line.split('#egg=')[1]
38 | else:
39 | # Remove versioning from the package
40 | pat = '(' + '|'.join(['>=', '==', '>']) + ')'
41 | parts = re.split(pat, line, maxsplit=1)
42 | parts = [p.strip() for p in parts]
43 |
44 | info['package'] = parts[0]
45 | if len(parts) > 1:
46 | op, rest = parts[1:]
47 | if ';' in rest:
48 | # Handle platform specific dependencies
49 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
50 | version, platform_deps = map(str.strip,
51 | rest.split(';'))
52 | info['platform_deps'] = platform_deps
53 | else:
54 | version = rest # NOQA
55 | info['version'] = (op, version)
56 | yield info
57 |
58 | def parse_require_file(fpath):
59 | with open(fpath, 'r') as f:
60 | for line in f.readlines():
61 | line = line.strip()
62 | if line and not line.startswith('#'):
63 | for info in parse_line(line):
64 | yield info
65 |
66 | def gen_packages_items():
67 | if exists(require_fpath):
68 | for info in parse_require_file(require_fpath):
69 | parts = [info['package']]
70 | if with_version and 'version' in info:
71 | parts.extend(info['version'])
72 | if not sys.version.startswith('3.4'):
73 | # apparently package_deps are broken in 3.4
74 | platform_deps = info.get('platform_deps')
75 | if platform_deps is not None:
76 | parts.append(';' + platform_deps)
77 | item = ''.join(parts)
78 | yield item
79 |
80 | packages = list(gen_packages_items())
81 | return packages
82 |
83 |
84 | if __name__ == '__main__':
85 | setup(
86 | name='visdrone_eval',
87 | version='0.1',
88 | description='Python Implementation of VisDrone Detection Toolbox',
89 | packages=find_packages(exclude=('configs', 'tools', 'demo')),
90 | install_requires=parse_requirements('requirements.txt')
91 | )
92 |
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval_det import eval_det
2 |
3 |
4 | __all__ = ['eval_det']
5 |
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/visdrone_eval/viseval/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__pycache__/bbox_overlaps.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/visdrone_eval/viseval/__pycache__/bbox_overlaps.cpython-37.pyc
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__pycache__/calc_accuracy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/visdrone_eval/viseval/__pycache__/calc_accuracy.cpython-37.pyc
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__pycache__/drop_objects_in_igr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/visdrone_eval/viseval/__pycache__/drop_objects_in_igr.cpython-37.pyc
--------------------------------------------------------------------------------
/visdrone_eval/viseval/__pycache__/eval_det.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhongyiYang/QueryDet-PyTorch/feebf218d53d59ba054132dfa6ef84159f793967/visdrone_eval/viseval/__pycache__/eval_det.cpython-37.pyc
--------------------------------------------------------------------------------
/visdrone_eval/viseval/bbox_overlaps.py:
--------------------------------------------------------------------------------
1 | # from mmdetection
2 |
3 | import numpy as np
4 |
5 |
6 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6):
7 | """Calculate the ious between each bbox of bboxes1 and bboxes2.
8 |
9 | Args:
10 | bboxes1(ndarray): shape (n, 4)
11 | bboxes2(ndarray): shape (k, 4)
12 | mode(str): iou (intersection over union) or iof (intersection
13 | over foreground)
14 |         eps(float): minimum value of the union area, to avoid division by zero
15 |
16 | Returns:
17 | ious(ndarray): shape (n, k)
18 | """
19 |
20 | assert mode in ['iou', 'iof']
21 |
22 | bboxes1 = bboxes1.astype(np.float32)
23 | bboxes2 = bboxes2.astype(np.float32)
24 | rows = bboxes1.shape[0]
25 | cols = bboxes2.shape[0]
26 | ious = np.zeros((rows, cols), dtype=np.float32)
27 | if rows * cols == 0:
28 | return ious
29 | exchange = False
30 | if bboxes1.shape[0] > bboxes2.shape[0]:
31 | bboxes1, bboxes2 = bboxes2, bboxes1
32 | ious = np.zeros((cols, rows), dtype=np.float32)
33 | exchange = True
34 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
35 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
36 | for i in range(bboxes1.shape[0]):
37 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
38 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
39 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
40 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
41 | overlap = np.maximum(x_end - x_start, 0) * np.maximum(
42 | y_end - y_start, 0)
43 | if mode == 'iou':
44 | union = area1[i] + area2 - overlap
45 | else:
46 | union = area1[i] if not exchange else area2
47 | union = np.maximum(union, eps)
48 | ious[i, :] = overlap / union
49 | if exchange:
50 | ious = ious.T
51 | return ious
52 |
--------------------------------------------------------------------------------
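
A tiny numeric check of `bbox_overlaps`, assuming the `viseval` package has been installed (e.g. `pip install -e .` inside `visdrone_eval`). Boxes are in `(x1, y1, x2, y2)` format.

```python
import numpy as np
from viseval.bbox_overlaps import bbox_overlaps

a = np.array([[0, 0, 10, 10]], dtype=np.float32)    # a 10x10 box
b = np.array([[5, 5, 15, 15],                       # overlaps a by a 5x5 patch
              [20, 20, 30, 30]], dtype=np.float32)  # disjoint from a

print(bbox_overlaps(a, b, mode='iou'))  # [[25/175, 0.0]] ~ [[0.1429, 0.0]]
print(bbox_overlaps(a, b, mode='iof'))  # [[25/100, 0.0]] = [[0.25,   0.0]]
```
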
/visdrone_eval/viseval/calc_accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .bbox_overlaps import bbox_overlaps
3 |
4 |
5 | def eval_res(gt0, dt0, thr):
6 | """
7 | :param gt0: np.array[ng, 5], ground truth results [x, y, w, h, ignore]
8 | :param dt0: np.array[nd, 5], detection results [x, y, w, h, score]
9 | :param thr: float, IoU threshold
10 | :return gt1: np.array[ng, 5], gt match types
11 | dt1: np.array[nd, 6], dt match types
12 | """
13 | nd = len(dt0)
14 | ng = len(gt0)
15 |
16 | # sort
17 | dt = dt0[dt0[:, 4].argsort()[::-1]]
18 | gt_ignore_mask = gt0[:, 4] == 1
19 | gt = gt0[np.logical_not(gt_ignore_mask)]
20 | ig = gt0[gt_ignore_mask]
21 | ig[:, 4] = -ig[:, 4] # -1 indicates ignore
22 |
23 | dt_format = dt[:, :4].copy()
24 | gt_format = gt[:, :4].copy()
25 | ig_format = ig[:, :4].copy()
26 | dt_format[:, 2:] += dt_format[:, :2] # [x2, y2] = [w, h] + [x1, y1]
27 | gt_format[:, 2:] += gt_format[:, :2]
28 | ig_format[:, 2:] += ig_format[:, :2]
29 |
30 | iou_dtgt = bbox_overlaps(dt_format, gt_format, mode='iou')
31 |     iof_dtig = bbox_overlaps(dt_format, ig_format, mode='iof')  # IoF against ignored regions, so oa has len(gt) + len(ig) columns
32 | oa = np.concatenate((iou_dtgt, iof_dtig), axis=1)
33 |
34 | # [nd, 6]
35 | dt1 = np.concatenate((dt, np.zeros((nd, 1), dtype=dt.dtype)), axis=1)
36 | # [ng, 5]
37 | gt1 = np.concatenate((gt, ig), axis=0)
38 |
39 | for d in range(nd):
40 | bst_oa = thr
41 | bstg = -1 # index of matched gt
42 | bstm = 0 # best match type
43 | for g in range(ng):
44 | m = gt1[g, 4]
45 | # if gt already matched, continue to next gt
46 | if m == 1:
47 | continue
48 | # if dt already matched, and on ignore gt, nothing more to do
49 | if bstm != 0 and m == -1:
50 | break
51 | # continue to next gt until better match is found
52 | if oa[d, g] < bst_oa:
53 | continue
54 | bst_oa = oa[d, g]
55 | bstg = g
56 | bstm = 1 if m == 0 else -1 # 1: matched to gt, -1: matched to ignore
57 |
58 | # store match type for dt
59 | dt1[d, 5] = bstm
60 | # store match flag for gt
61 | if bstm == 1:
62 | gt1[bstg, 4] = 1
63 |
64 | return gt1, dt1
65 |
66 |
67 | def voc_ap(rec, prec):
68 | mrec = np.concatenate(([0], rec, [1]))
69 | mpre = np.concatenate(([0], prec, [0]))
70 | for i in reversed(range(0, len(mpre)-1)):
71 | mpre[i] = max(mpre[i], mpre[i + 1])
72 | i = np.flatnonzero(mrec[1:] != mrec[:-1]) + 1
73 | ap = np.sum((mrec[i] - mrec[i - 1]) * mpre[i])
74 | return ap
75 |
76 |
77 | def calc_accuracy(num_imgs, all_gt, all_det, per_class=False):
78 | """
79 | :param num_imgs: int
80 | :param all_gt: list of np.array[m, 8], [:, 4] == 1 indicates ignored regions,
81 | which should be dropped before calling this function
82 | :param all_det: list of np.array[m, 6], truncation and occlusion not necessary
83 | :param per_class:
84 | """
85 | assert num_imgs == len(all_gt) == len(all_det)
86 |
87 | ap = np.zeros((10, 10), dtype=np.float32)
88 | ar = np.zeros((10, 10, 4), dtype=np.float32)
89 | eval_class = []
90 |
91 | print('')
92 | for id_class in range(1, 11):
93 | print('evaluating object category {}/10...'.format(id_class))
94 |
95 | for gt in all_gt:
96 | if np.any(gt[:, 5] == id_class):
97 | eval_class.append(id_class - 1)
98 |
99 | x = 0
100 | for thr in np.linspace(0.5, 0.95, num=10):
101 | y = 0
102 | for max_dets in (1, 10, 100, 500):
103 | gt_match = []
104 | det_match = []
105 | for gt, det in zip(all_gt, all_det):
106 | det_limited = det[:min(len(det), max_dets)]
107 | mask_gt_cur_class = gt[:, 5] == id_class
108 | mask_det_cur_class = det_limited[:, 5] == id_class
109 | gt0 = gt[mask_gt_cur_class, :5]
110 | dt0 = det_limited[mask_det_cur_class, :5]
111 | gt1, dt1 = eval_res(gt0, dt0, thr)
112 | # 1: matched, 0: unmatched, -1: ignore
113 | gt_match.append(gt1[:, 4])
114 | # [score, match type]
115 | # 1: matched to gt, 0: unmatched, -1: matched to ignore
116 | det_match.append(dt1[:, 4:6])
117 | gt_match = np.concatenate(gt_match, axis=0)
118 | det_match = np.concatenate(det_match, axis=0)
119 |
120 | idrank = det_match[:, 0].argsort()[::-1]
121 | tp = np.cumsum(det_match[idrank, 1] == 1)
122 | rec = tp / max(1, len(gt_match)) # including ignore (already dropped)
123 | if len(rec):
124 | ar[id_class - 1, x, y] = np.max(rec) * 100
125 |
126 | y += 1
127 |
128 | fp = np.cumsum(det_match[idrank, 1] == 0)
129 | prec = tp / (fp + tp).clip(min=1)
130 | ap[id_class - 1, x] = voc_ap(rec, prec) * 100
131 |
132 | x += 1
133 |
134 | ap_all = np.mean(ap[eval_class, :])
135 | ap_50 = np.mean(ap[eval_class, 0])
136 | ap_75 = np.mean(ap[eval_class, 5])
137 | ar_1 = np.mean(ar[eval_class, :, 0])
138 | ar_10 = np.mean(ar[eval_class, :, 1])
139 | ar_100 = np.mean(ar[eval_class, :, 2])
140 | ar_500 = np.mean(ar[eval_class, :, 3])
141 |
142 | results = (ap_all, ap_50, ap_75, ar_1, ar_10, ar_100, ar_500)
143 |
144 | if per_class:
145 | ap_classwise = np.mean(ap, axis=1)
146 | results += (ap_classwise,)
147 |
148 | print('Evaluation completed. The performance of the detector is presented as follows.')
149 |
150 | return results
151 |
--------------------------------------------------------------------------------
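
A small worked example of `voc_ap` above with made-up recall/precision values: the precision curve is first made monotonically non-increasing, then AP is the area under the interpolated precision-recall curve (again assuming `viseval` is installed).

```python
import numpy as np
from viseval.calc_accuracy import voc_ap

rec = np.array([0.2, 0.4, 0.4, 0.8])
prec = np.array([1.0, 0.5, 0.4, 0.5])

# The interpolated precision envelope is 1.0 on (0, 0.2], 0.5 on (0.2, 0.8]
# and 0 beyond the last recall point, so AP = 0.2 * 1.0 + 0.6 * 0.5 = 0.5.
print(voc_ap(rec, prec))  # 0.5 (up to float rounding)
```
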
/visdrone_eval/viseval/drop_objects_in_igr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def create_int_img(img):
5 | int_img = np.cumsum(img, axis=0)
6 | np.cumsum(int_img, axis=1, out=int_img)
7 | return int_img
8 |
9 |
10 | def drop_objects_in_igr(gt, det, img_height, img_width):
11 | gt_ignore_mask = gt[:, 5] == 0
12 | curgt = gt[np.logical_not(gt_ignore_mask)]
13 | igr_region = gt[gt_ignore_mask, :4].clip(min=1)
14 | if len(igr_region):
15 |         igr_map = np.zeros((img_height, img_width), dtype=np.int32)  # np.int is deprecated in newer NumPy
16 |
17 | for igr in igr_region:
18 | x1 = igr[0]
19 | y1 = igr[1]
20 | x2 = min(x1 + igr[2], img_width)
21 | y2 = min(y1 + igr[3], img_height)
22 | igr_map[y1 - 1:y2, x1 - 1:x2] = 1
23 | int_igr_map = create_int_img(igr_map)
24 | idx_left_gt = []
25 |
26 | for i, gtbox in enumerate(curgt):
27 | pos = np.round(gtbox[:4]).astype(np.int32).clip(min=1)
28 | x = max(1, min(img_width - 1, pos[0]))
29 | y = max(1, min(img_height - 1, pos[1]))
30 | w = pos[2]
31 | h = pos[3]
32 | tl = int_igr_map[y - 1, x - 1]
33 | tr = int_igr_map[y - 1, min(img_width, x + w) - 1]
34 | bl = int_igr_map[max(1, min(img_height, y + h)) - 1, x - 1]
35 | br = int_igr_map[max(1, min(img_height, y + h)) - 1,
36 | min(img_width, x + w) - 1]
37 | igr_val = tl + br - tr - bl
38 | if igr_val / (h * w) < 0.5:
39 | idx_left_gt.append(i)
40 |
41 | curgt = curgt[idx_left_gt]
42 |
43 | idx_left_det = []
44 | for i, dtbox in enumerate(det):
45 | pos = np.round(dtbox[:4]).astype(np.int32).clip(min=1)
46 | x = max(1, min(img_width - 1, pos[0]))
47 | y = max(1, min(img_height - 1, pos[1]))
48 | w = pos[2]
49 | h = pos[3]
50 | tl = int_igr_map[y - 1, x - 1]
51 | tr = int_igr_map[y - 1, min(img_width, x + w) - 1]
52 | bl = int_igr_map[max(1, min(img_height, y + h)) - 1, x - 1]
53 | br = int_igr_map[max(1, min(img_height, y + h)) - 1,
54 | min(img_width, x + w) - 1]
55 | igr_val = tl + br - tr - bl
56 | if igr_val / (h * w) < 0.5:
57 | idx_left_det.append(i)
58 |
59 | det = det[idx_left_det]
60 |
61 | return curgt, det
62 |
--------------------------------------------------------------------------------
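
A short sketch of the integral-image trick used above: after `create_int_img`, the sum of the ignore mask inside any axis-aligned rectangle can be read from four corner values (the mask below is made up).

```python
import numpy as np
from viseval.drop_objects_in_igr import create_int_img

mask = np.zeros((6, 8), dtype=np.int32)
mask[2:5, 3:7] = 1                       # a 3x4 block of ones
int_img = create_int_img(mask)

# Sum over rows 2..4 and cols 3..6 via the corner identity tl + br - tr - bl.
y1, x1, y2, x2 = 2, 3, 4, 6
tl = int_img[y1 - 1, x1 - 1]
tr = int_img[y1 - 1, x2]
bl = int_img[y2, x1 - 1]
br = int_img[y2, x2]
print(tl + br - tr - bl)                 # 12, the whole 3x4 block
print(mask[y1:y2 + 1, x1:x2 + 1].sum())  # 12, direct check
```
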
/visdrone_eval/viseval/eval_det.py:
--------------------------------------------------------------------------------
1 | from .calc_accuracy import calc_accuracy
2 | from .drop_objects_in_igr import drop_objects_in_igr
3 |
4 |
5 | def eval_det(all_gt, all_det, allheight, allwidth, per_class=False):
6 | """
7 | :param all_gt: list of np.array[m, 8]
8 | :param all_det: list of np.array[m, 6], truncation and occlusion not necessary
9 | :param allheight:
10 | :param allwidth:
11 | :param per_class:
12 | """
13 | all_gt_ = []
14 | all_det_ = []
15 | num_imgs = len(all_gt)
16 | for gt, det, height, width in zip(all_gt, all_det, allheight, allwidth):
17 | gt, det = drop_objects_in_igr(gt, det, height, width)
18 | gt[:, 4] = 1 - gt[:, 4] # set ignore flag
19 | all_gt_.append(gt)
20 | all_det_.append(det)
21 | return calc_accuracy(num_imgs, all_gt_, all_det_, per_class)
22 |
--------------------------------------------------------------------------------
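
Finally, a shape/layout sketch for `eval_det`, mirroring the arrays that `evaluate.py` builds; all numbers are made up. Each ground-truth row is `[x, y, w, h, score_flag, category, truncation, occlusion]` (flag 0 marks ignored regions) and each detection row is `[x, y, w, h, score, category]`.

```python
import numpy as np
from viseval import eval_det

gt = np.array([[100, 100, 50, 80, 1, 1, 0, 0],   # a kept pedestrian box
               [  0,   0, 40, 40, 0, 0, 0, 0]],  # an ignored region
              dtype=np.int32)
det = np.array([[102, 98, 48, 84, 0.9, 1]], dtype=np.float32)

ap_all, ap_50, ap_75, ar_1, ar_10, ar_100, ar_500 = eval_det(
    [gt], [det], allheight=[756], allwidth=[1344])
print(ap_50, ar_1)  # 100.0 90.0: IoU ~ 0.92, so the match holds at 9 of the 10 thresholds
```
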