├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── figs └── grouprcnn.png ├── projects ├── api │ └── test.py ├── configs │ ├── 10_coco │ │ ├── 01_student_fcos.py │ │ ├── base.py │ │ ├── group_rcnn_24e_10_percent_coco_detr_augmentation.py │ │ └── group_rcnn_50e_10_percent_coco_detr_augmentation.py │ └── _base_ │ │ ├── datasets │ │ └── coco_detection.py │ │ ├── default_runtime.py │ │ └── schedules │ │ └── schedule_1x.py ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── point_coco.py │ └── transform.py └── models │ ├── __init__.py │ ├── group_rcnn.py │ └── group_roi_head.py └── tools ├── dist_test.sh ├── dist_train.sh ├── generate_anns.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/ 107 | data 108 | .vscode 109 | .idea 110 | .DS_Store 111 | 112 | # custom 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | docs/modelzoo_statistics.md 117 | mmdet/.mim 118 | work_dirs/ 119 | 120 | # Pytorch 121 | *.pth 122 | *.py~ 123 | *.sh~ 124 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party =mmcv,mmdet,numpy,pycocotools,terminaltables,torch 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitlab.com/pycqa/flake8.git 3 | rev: 3.8.3 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/asottile/seed-isort-config 7 | rev: v2.2.0 8 | hooks: 9 | - id: seed-isort-config 10 | - repo: https://github.com/timothycrosley/isort 11 | rev: 4.3.21 12 | hooks: 13 | - id: isort 14 | - repo: https://github.com/pre-commit/mirrors-yapf 15 | rev: v0.30.0 16 | hooks: 17 | - id: yapf 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v3.1.0 20 | hooks: 21 | - id: trailing-whitespace 22 | - id: check-yaml 23 | - id: end-of-file-fixer 24 | - id: requirements-txt-fixer 25 | - id: double-quote-string-fixer 26 | - id: check-merge-conflict 27 | - id: fix-encoding-pragma 28 | args: ["--remove"] 29 | - id: mixed-line-ending 30 | args: ["--fix=lf"] 31 | - repo: https://github.com/markdownlint/markdownlint 32 | rev: v0.11.0 33 | hooks: 34 | - id: markdownlint 35 | args: ["-r", "~MD002,~MD013,~MD029,~MD033,~MD034", 36 | "-t", "allow_different_nesting"] 37 | - repo: https://github.com/codespell-project/codespell 38 | rev: v2.1.0 39 | hooks: 40 | - id: codespell 41 | - repo: https://github.com/myint/docformatter 42 | rev: v1.3.1 43 | hooks: 44 | - id: docformatter 45 | args: ["--in-place", "--wrap-descriptions", "79"] 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [Group R-CNN for Point-based Weakly Semi-supervised Object Detection](https://arxiv.org/abs/2205.05920) (CVPR2022) 2 | 3 | By Shilong Zhang*, Zhuoran Yu*, Liyang Liu*, Xinjiang Wang, Aojun Zhou, Kai Chen 4 | ### Abstract: 5 | We study the problem of weakly semi-supervised object 6 | detection with points (WSSOD-P), where the training data 7 | is combined by a small set of fully annotated images with 8 | bounding boxes and a large set of weakly-labeled images 9 | with only a single point annotated for each instance. The 10 | core of this task is to train a point-to-box regressor on well 11 | labeled images that can be used to predict credible bounding boxes for each point annotation. 12 | Group R-CNN significantly outperforms the prior 13 | method Point DETR by 3.9 mAP with 5% well-labeled images, which is the most challenging scenario. 14 | 15 | 16 | ![](./figs/grouprcnn.png) 17 | 18 | 19 | ### Install 20 | The project has been fully tested under [MMDetection V2.22.0](https://github.com/open-mmlab/mmdetection/releases/tag/v2.22.0) and [MMCV V1.4.6](https://github.com/open-mmlab/mmcv/releases/tag/v1.4.6), other versions may not be compatible. so you have to install mmcv and mmdetection firstly. 21 | You can refer to [Installation of MMCV](https://github.com/open-mmlab/mmcv) & [Installation of MMDetection](https://mmdetection.readthedocs.io/en/v2.18.1/get_started.html#installation) 22 | 23 | ### Prepare the dataset 24 | 25 | ```text 26 | mmdetection 27 | ├── data 28 | │ ├── coco 29 | │ │ ├── annotations 30 | │ │ │ ├──instances_train2017.json 31 | │ │ │ ├──instances_val2017.json 32 | │ │ ├── train2017 33 | │ │ ├── val2017 34 | ``` 35 | 36 | You can generate point annotations with the command. It may take you several minutes for `instances_train2017.json` 37 | ```python 38 | python tools/generate_anns.py /data/coco/annotations/instances_train2017.json 39 | python tools/generate_anns.py /data/coco/annotations/instances_val2017.json 40 | ``` 41 | Then you can find a `point_ann` directory, all annotations in the directory contain point annotations. Then you should replace the original annotations in `data/coco/annotations` with generated annotations. 42 | 43 | ### NOTES 44 | Here, we sample a point from the mask for all instances. But we split the images into two divisions in :class:`PointCocoDataset`. 45 | - Images with only bbox annotations(well-labeled images): Only be used in training phase. We sample a point from its bbox 46 | as point annotations each iteration. 47 | - Images with only point annotations(weakly-labeled sets): Only be used to generate bbox annotations from point annotations with trained point to bbox regressor. 48 | ### Train and Test 49 | 8 is the number of gpus. 50 | ##### For slurm 51 | 52 | Train 53 | ```shell 54 | GPUS=8 sh tools/slurm_train.sh partition_name job_name projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py ./exp/group_rcnn 55 | ``` 56 | Evaluate the quality of generated bbox annotations on val dataset with pre-defined point annotations. 57 | ```shell 58 | GPUS=8 sh tools/slurm_test.sh partition_name job_name projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py ./exp/group_rcnn/latest.pth --eval bbox 59 | ``` 60 | Run the inference process on weakly-labeled images with point annotations to get bbox annotations. 61 | ```shell 62 | GPUS=8 sh tools/slurm_test.sh partition_name job_name projects/configs/10_coco/group_rcnn_50e_10_percent_coco_detr_augmentation.py path_to_checkpoint --format-only --options "jsonfile_prefix=./generated" 63 | ``` 64 | ##### For Pytorch distributed 65 | 66 | Train 67 | ```shell 68 | sh tools/dist_train.sh projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py 8 --work-dir ./exp/group_rcnn 69 | ``` 70 | Evaluate the quality of generated bbox annotations on val dataset with pre-defined point annotations. 71 | ```shell 72 | sh tools/dist_test.sh projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py path_to_checkpoint 8 --eval bbox 73 | ``` 74 | 75 | Run the inference process on weakly-labeled images with point annotations to get bbox annotations. 76 | ```shell 77 | sh tools/dist_test.sh projects/configs/10_coco/group_rcnn_50e_10_percent_coco_detr_augmentation.py path_to_checkpoint 8 --format-only --options "jsonfile_prefix=./data/coco/annotations/generated" 78 | ``` 79 | Then you can train the student model focs. 80 | ```shell 81 | sh tools/dist_train.sh projects/configs/10_coco/01_student_fcos.py 8 --work-dir ./exp/01_student_fcos 82 | ``` 83 | 84 | #### Results & Checkpoints 85 | We find that the performance of teacher is unstable under 24e setting and may fluctuate by about 0.2 mAP. We report the average. 86 | 87 | | Model | Backbone | Lr schd | Augmentation | box AP | Config | Model | log |Generated Annotations | 88 | | :----: | :------: | :-----: | :----: | :------: |:------: |:------: |:------: |:------: | 89 | | Teacher(Group R-CNN) | R-50-FPN | 24e | DETR Aug| 39.2 | [config](https://github.com/jshilong/GroupRCNN/tree/main/projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py) | [ckpt](https://drive.google.com/file/d/18czpIJcKOgp8T7wE693WZEj1kbUUsaMA/view?usp=sharing) | [log](https://drive.google.com/file/d/14n09FOv3bSVLf_aYGpucYI4_Q8eJUczP/view?usp=sharing) | - 90 | | Teacher(Group R-CNN) | R-50-FPN | 50e | DETR Aug|39.9| [config](https://github.com/jshilong/GroupRCNN/tree/main/projects/configs/10_coco/group_rcnn_50e_10_percent_coco_detr_augmentation.py) | [ckpt](https://drive.google.com/file/d/1_yQtDBS9MqeCvRXMbAaBKyi5zgR5LP-z/view?usp=sharing) | [log](https://drive.google.com/file/d/1AiqXqbdf425tXdCP0T8pJyym-9Yf7p96/view?usp=sharing) | [generated.bbox.json](https://drive.google.com/file/d/1hyTgWRXuCUCRcPEgsqU-0eGVqDDdw7-b/view?usp=sharing) 91 | | Student(FCOS) | R-50-FPN | 12e |Normal 1x Aug| 33.1| [config](https://github.com/jshilong/GroupRCNN/tree/main/projects/configs/10_coco/01_student_fcos.py) | [ckpt](https://drive.google.com/file/d/1F8vQ7hp69T3xs51lb6dKxaxB8QsG-H5T/view?usp=sharing) | [log](https://drive.google.com/file/d/1LHbp5LBQEQoFtC5z7qwVIhsHHlX6LDlM/view?usp=sharing) | - 92 | 93 | ### Please cite our paper in your publications if it helps your research. 94 | 95 | ``` 96 | @InProceedings{Zhang_2022_CVPR, 97 | author = {Zhang, Shilong and Yu, Zhuoran and Liu, Liyang and Wang, Xinjiang and Zhou, Aojun and Chen, Kai}, 98 | title = {Group R-CNN for Weakly Semi-Supervised Object Detection With Points}, 99 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 100 | month = {June}, 101 | year = {2022}, 102 | pages = {9417-9426} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /figs/grouprcnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jshilong/GroupRCNN/5e9fe03bef7319555088ad6e9b3804b5d4416ec5/figs/grouprcnn.png -------------------------------------------------------------------------------- /projects/api/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | 4 | import mmcv 5 | import torch 6 | import torch.distributed as dist 7 | from mmcv.image import tensor2imgs 8 | from mmcv.runner import EvalHook, get_dist_info 9 | from mmdet.apis.test import collect_results_cpu, collect_results_gpu 10 | from mmdet.core import DistEvalHook 11 | from torch.nn.modules.batchnorm import _BatchNorm 12 | 13 | 14 | def pointdet_single_gpu_test(model, 15 | data_loader, 16 | show=False, 17 | out_dir=None, 18 | show_score_thr=0.3): 19 | model.eval() 20 | results = [] 21 | dataset = data_loader.dataset 22 | prog_bar = mmcv.ProgressBar(len(dataset)) 23 | for i, data in enumerate(data_loader): 24 | with torch.no_grad(): 25 | result = model(return_loss=False, rescale=True, **data) 26 | 27 | batch_size = len(result) 28 | if show or out_dir: 29 | if batch_size == 1 and isinstance(data['img'][0], torch.Tensor): 30 | img_tensor = data['img'][0] 31 | else: 32 | img_tensor = data['img'][0].data[0] 33 | img_metas = data['img_metas'][0].data[0] 34 | imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) 35 | assert len(imgs) == len(img_metas) 36 | 37 | for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): 38 | h, w, _ = img_meta['img_shape'] 39 | img_show = img[:h, :w, :] 40 | 41 | ori_h, ori_w = img_meta['ori_shape'][:-1] 42 | img_show = mmcv.imresize(img_show, (ori_w, ori_h)) 43 | 44 | if out_dir: 45 | out_file = osp.join(out_dir, img_meta['ori_filename']) 46 | else: 47 | out_file = None 48 | 49 | model.module.show_result(img_show, 50 | result[i], 51 | show=show, 52 | out_file=out_file, 53 | score_thr=show_score_thr) 54 | 55 | results.extend(result) 56 | 57 | for _ in range(batch_size): 58 | prog_bar.update() 59 | return results 60 | 61 | 62 | def pointdet_multi_gpu_test(model, 63 | data_loader, 64 | tmpdir=None, 65 | gpu_collect=False): 66 | """Test model with multiple gpus. 67 | 68 | This method tests model with multiple gpus and collects the results 69 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 70 | it encodes results to gpu tensors and use gpu communication for results 71 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 72 | and collects them by the rank 0 worker. 73 | 74 | Args: 75 | model (nn.Module): Model to be tested. 76 | data_loader (nn.Dataloader): Pytorch data loader. 77 | tmpdir (str): Path of directory to save the temporary results from 78 | different gpus under cpu mode. 79 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 80 | 81 | Returns: 82 | list: The prediction results. 83 | """ 84 | model.eval() 85 | results = [] 86 | dataset = data_loader.dataset 87 | rank, world_size = get_dist_info() 88 | if rank == 0: 89 | prog_bar = mmcv.ProgressBar(len(dataset)) 90 | time.sleep(2) # This line can prevent deadlock problem in some cases. 91 | for i, data in enumerate(data_loader): 92 | with torch.no_grad(): 93 | result = model(return_loss=False, rescale=True, **data) 94 | results.extend(result) 95 | 96 | if rank == 0: 97 | batch_size = len(result) 98 | for _ in range(batch_size * world_size): 99 | prog_bar.update() 100 | 101 | # collect results from all ranks 102 | if gpu_collect: 103 | results = collect_results_gpu(results, len(dataset)) 104 | else: 105 | results = collect_results_cpu(results, len(dataset), tmpdir) 106 | return results 107 | 108 | 109 | def _calc_dynamic_intervals(start_interval, dynamic_interval_list): 110 | assert mmcv.is_list_of(dynamic_interval_list, tuple) 111 | 112 | dynamic_milestones = [0] 113 | dynamic_milestones.extend( 114 | [dynamic_interval[0] for dynamic_interval in dynamic_interval_list]) 115 | dynamic_intervals = [start_interval] 116 | dynamic_intervals.extend( 117 | [dynamic_interval[1] for dynamic_interval in dynamic_interval_list]) 118 | return dynamic_milestones, dynamic_intervals 119 | 120 | 121 | class PointdetEvalHook(EvalHook): 122 | def _do_evaluate(self, runner): 123 | """perform evaluation and save ckpt.""" 124 | if not self._should_evaluate(runner): 125 | return 126 | 127 | results = pointdet_single_gpu_test(runner.model, 128 | self.dataloader, 129 | show=False) 130 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 131 | key_score = self.evaluate(runner, results) 132 | if self.save_best: 133 | self._save_ckpt(runner, key_score) 134 | 135 | 136 | class PointdetDistEvalHook(DistEvalHook): 137 | def _do_evaluate(self, runner): 138 | """perform evaluation and save ckpt.""" 139 | # Synchronization of BatchNorm's buffer (running_mean 140 | # and running_var) is not supported in the DDP of pytorch, 141 | # which may cause the inconsistent performance of models in 142 | # different ranks, so we broadcast BatchNorm's buffers 143 | # of rank 0 to other ranks to avoid this. 144 | if self.broadcast_bn_buffer: 145 | model = runner.model 146 | for name, module in model.named_modules(): 147 | if isinstance(module, 148 | _BatchNorm) and module.track_running_stats: 149 | dist.broadcast(module.running_var, 0) 150 | dist.broadcast(module.running_mean, 0) 151 | 152 | if not self._should_evaluate(runner): 153 | return 154 | 155 | tmpdir = self.tmpdir 156 | if tmpdir is None: 157 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 158 | 159 | results = pointdet_multi_gpu_test(runner.model, 160 | self.dataloader, 161 | tmpdir=tmpdir, 162 | gpu_collect=self.gpu_collect) 163 | if runner.rank == 0: 164 | print('\n') 165 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 166 | key_score = self.evaluate(runner, results) 167 | 168 | if self.save_best: 169 | self._save_ckpt(runner, key_score) 170 | -------------------------------------------------------------------------------- /projects/configs/10_coco/01_student_fcos.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/coco_detection.py', 3 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 4 | ] 5 | 6 | full_ann_ratio = 0.1 7 | model = dict( 8 | type='FCOS', 9 | backbone=dict(type='ResNet', 10 | depth=50, 11 | num_stages=4, 12 | out_indices=(0, 1, 2, 3), 13 | frozen_stages=1, 14 | norm_cfg=dict(type='BN', requires_grad=False), 15 | norm_eval=True, 16 | style='caffe', 17 | init_cfg=dict( 18 | type='Pretrained', 19 | checkpoint='open-mmlab://detectron/resnet50_caffe')), 20 | neck=dict( 21 | type='FPN', 22 | in_channels=[256, 512, 1024, 2048], 23 | out_channels=256, 24 | start_level=1, 25 | add_extra_convs='on_output', # use P5 26 | num_outs=5, 27 | relu_before_extra_convs=True), 28 | bbox_head=dict(type='FCOSHead', 29 | num_classes=80, 30 | in_channels=256, 31 | stacked_convs=4, 32 | feat_channels=256, 33 | strides=[8, 16, 32, 64, 128], 34 | loss_cls=dict(type='FocalLoss', 35 | use_sigmoid=True, 36 | gamma=2.0, 37 | alpha=0.25, 38 | loss_weight=1.0), 39 | loss_bbox=dict(type='IoULoss', loss_weight=1.0), 40 | loss_centerness=dict(type='CrossEntropyLoss', 41 | use_sigmoid=True, 42 | loss_weight=1.0)), 43 | # training and testing settings 44 | train_cfg=dict(assigner=dict(type='MaxIoUAssigner', 45 | pos_iou_thr=0.5, 46 | neg_iou_thr=0.4, 47 | min_pos_iou=0, 48 | ignore_iof_thr=-1), 49 | allowed_border=-1, 50 | pos_weight=-1, 51 | debug=False), 52 | test_cfg=dict(nms_pre=1000, 53 | min_bbox_size=0, 54 | score_thr=0.05, 55 | nms=dict(type='nms', iou_threshold=0.5), 56 | max_per_img=100)) 57 | img_norm_cfg = dict(mean=[102.9801, 115.9465, 122.7717], 58 | std=[1.0, 1.0, 1.0], 59 | to_rgb=False) 60 | img_norm_cfg = dict(mean=[102.9801, 115.9465, 122.7717], 61 | std=[1.0, 1.0, 1.0], 62 | to_rgb=False) 63 | train_pipeline = [ 64 | dict(type='LoadImageFromFile'), 65 | dict(type='LoadAnnotations', with_bbox=True), 66 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 67 | dict(type='RandomFlip', flip_ratio=0.5), 68 | dict(type='Normalize', **img_norm_cfg), 69 | dict(type='Pad', size_divisor=32), 70 | dict(type='DefaultFormatBundle'), 71 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 72 | ] 73 | test_pipeline = [ 74 | dict(type='LoadImageFromFile'), 75 | dict(type='MultiScaleFlipAug', 76 | img_scale=(1333, 800), 77 | flip=False, 78 | transforms=[ 79 | dict(type='Resize', keep_ratio=True), 80 | dict(type='RandomFlip'), 81 | dict(type='Normalize', **img_norm_cfg), 82 | dict(type='Pad', size_divisor=32), 83 | dict(type='ImageToTensor', keys=['img']), 84 | dict(type='Collect', keys=['img']), 85 | ]) 86 | ] 87 | data = dict(samples_per_gpu=2, 88 | workers_per_gpu=2, 89 | train=dict( 90 | type='PointCocoDataset', 91 | need_points=False, 92 | predictions_path='./data/coco/annotations/generated.bbox.json', 93 | full_ann_ratio=full_ann_ratio, 94 | pipeline=train_pipeline), 95 | val=dict(type='CocoDataset', pipeline=test_pipeline), 96 | test=dict(type='PointCocoDataset', 97 | need_points=False, 98 | pipeline=test_pipeline)) 99 | # optimizer 100 | optimizer = dict(lr=0.01, 101 | paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.)) 102 | optimizer_config = dict(_delete_=True, 103 | grad_clip=dict(max_norm=35, norm_type=2)) 104 | # learning policy 105 | lr_config = dict(policy='step', 106 | warmup='constant', 107 | warmup_iters=500, 108 | warmup_ratio=1.0 / 3, 109 | step=[8, 11]) 110 | runner = dict(type='EpochBasedRunner', max_epochs=12) 111 | -------------------------------------------------------------------------------- /projects/configs/10_coco/base.py: -------------------------------------------------------------------------------- 1 | samples_per_gpu = 2 2 | workers_per_gpu = 5 3 | eval_interval = 12 4 | 5 | full_ann_ratio = 0.1 6 | 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', with_bbox=True), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict( 12 | type='AutoAugment', 13 | policies=[ 14 | [ 15 | dict(type='Resize', 16 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 17 | (576, 1333), (608, 1333), (640, 1333), 18 | (672, 1333), (704, 1333), (736, 1333), 19 | (768, 1333), (800, 1333)], 20 | multiscale_mode='value', 21 | keep_ratio=True) 22 | ], 23 | [ 24 | dict(type='Resize', 25 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 26 | multiscale_mode='value', 27 | keep_ratio=True), 28 | # process points annotation 29 | dict(type='PointRandomCrop', 30 | crop_type='absolute_range', 31 | crop_size=(384, 600), 32 | allow_negative_crop=False), 33 | dict(type='Resize', 34 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 35 | (576, 1333), (608, 1333), (640, 1333), 36 | (672, 1333), (704, 1333), (736, 1333), 37 | (768, 1333), (800, 1333)], 38 | multiscale_mode='value', 39 | override=True, 40 | keep_ratio=True) 41 | ] 42 | ]), 43 | dict(type='Normalize', 44 | mean=[123.675, 116.28, 103.53], 45 | std=[58.395, 57.12, 57.375], 46 | to_rgb=True), 47 | dict(type='Pad', size_divisor=32), 48 | dict(type='DefaultFormatBundle'), 49 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 50 | ] 51 | data = dict( 52 | samples_per_gpu=samples_per_gpu, 53 | workers_per_gpu=workers_per_gpu, 54 | train=dict(type='PointCocoDataset', 55 | full_ann_ratio=0.1, 56 | ann_file='data/coco/annotations/instances_train2017.json', 57 | img_prefix='data/coco/train2017/', 58 | pipeline=train_pipeline), 59 | val=dict(type='PointCocoDataset', 60 | ann_file='data/coco/annotations/instances_val2017.json', 61 | img_prefix='data/coco/val2017/', 62 | pipeline=[ 63 | dict(type='LoadImageFromFile'), 64 | dict(type='LoadAnnotations', with_bbox=True), 65 | dict(type='MultiScaleFlipAug', 66 | img_scale=(1333, 800), 67 | flip=False, 68 | transforms=[ 69 | dict(type='Resize', keep_ratio=True), 70 | dict(type='RandomFlip'), 71 | dict(type='Normalize', 72 | mean=[123.675, 116.28, 103.53], 73 | std=[58.395, 57.12, 57.375], 74 | to_rgb=True), 75 | dict(type='Pad', size_divisor=32), 76 | dict(type='DefaultFormatBundle'), 77 | dict(type='Collect', 78 | keys=['img', 'gt_bboxes', 'gt_labels']) 79 | ]) 80 | ]), 81 | test=dict(type='PointCocoDataset', 82 | ann_file='data/coco/annotations/instances_train2017.json', 83 | img_prefix='data/coco/train2017/', 84 | pipeline=[ 85 | dict(type='LoadImageFromFile'), 86 | dict(type='LoadAnnotations', with_bbox=True), 87 | dict(type='MultiScaleFlipAug', 88 | img_scale=(1333, 800), 89 | flip=False, 90 | transforms=[ 91 | dict(type='Resize', keep_ratio=True), 92 | dict(type='RandomFlip'), 93 | dict(type='Normalize', 94 | mean=[123.675, 116.28, 103.53], 95 | std=[58.395, 57.12, 57.375], 96 | to_rgb=True), 97 | dict(type='Pad', size_divisor=32), 98 | dict(type='DefaultFormatBundle'), 99 | dict(type='Collect', 100 | keys=['img', 'gt_bboxes', 'gt_labels']) 101 | ]) 102 | ])) 103 | evaluation = dict(interval=eval_interval, metric='bbox') 104 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 105 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 106 | lr_config = dict(policy='step', 107 | warmup='linear', 108 | warmup_iters=500, 109 | warmup_ratio=0.001, 110 | step=[16, 22]) 111 | runner = dict(type='EpochBasedRunner', max_epochs=24) 112 | checkpoint_config = dict(interval=24) 113 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 114 | custom_hooks = [dict(type='NumClassCheckHook')] 115 | dist_params = dict(backend='nccl') 116 | log_level = 'INFO' 117 | load_from = None 118 | resume_from = None 119 | workflow = [('train', 1)] 120 | work_dir = './work_dirs/cascade_rcnn_r50_fpn_1x_coco' 121 | gpu_ids = range(0, 1) 122 | -------------------------------------------------------------------------------- /projects/configs/10_coco/group_rcnn_24e_10_percent_coco_detr_augmentation.py: -------------------------------------------------------------------------------- 1 | _base_ = 'base.py' 2 | model = dict( 3 | type='GroupRCNN', 4 | backbone=dict(type='ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | norm_eval=True, 11 | style='pytorch', 12 | init_cfg=dict(type='Pretrained', 13 | checkpoint='torchvision://resnet50')), 14 | neck=dict(type='FPN', 15 | in_channels=[256, 512, 1024, 2048], 16 | out_channels=256, 17 | start_level=1, 18 | add_extra_convs='on_input', 19 | num_outs=5), 20 | rpn_head=dict(type='RetinaHead', 21 | num_classes=80, 22 | in_channels=256, 23 | stacked_convs=4, 24 | feat_channels=256, 25 | anchor_generator=dict(type='AnchorGenerator', 26 | octave_base_scale=4, 27 | scales_per_octave=3, 28 | ratios=[0.5, 1.0, 2.0], 29 | strides=[8, 16, 32, 64, 128]), 30 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 31 | target_means=[.0, .0, .0, .0], 32 | target_stds=[1.0, 1.0, 1.0, 1.0]), 33 | loss_cls=dict(type='FocalLoss', 34 | use_sigmoid=True, 35 | gamma=2.0, 36 | alpha=0.25, 37 | loss_weight=1.0), 38 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 39 | roi_head=dict(type='GroupRoIHead', 40 | num_stages=3, 41 | stage_loss_weights=[1, 0.5, 0.25], 42 | bbox_roi_extractor=dict(type='SingleRoIExtractor', 43 | roi_layer=dict(type='RoIAlign', 44 | output_size=7, 45 | sampling_ratio=0), 46 | out_channels=256, 47 | featmap_strides=[8, 16, 32, 64]), 48 | bbox_head=[ 49 | dict(type='Shared2FCBBoxHead', 50 | in_channels=256, 51 | fc_out_channels=1024, 52 | roi_feat_size=7, 53 | num_classes=80, 54 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 55 | target_means=[0.0, 0.0, 0.0, 0.0], 56 | target_stds=[0.1, 0.1, 0.2, 0.2]), 57 | reg_class_agnostic=True, 58 | loss_cls=dict(type='CrossEntropyLoss', 59 | use_sigmoid=False, 60 | loss_weight=1.0), 61 | loss_bbox=dict(type='SmoothL1Loss', 62 | beta=1.0, 63 | loss_weight=1.0)), 64 | dict(type='Shared2FCBBoxHead', 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 68 | num_classes=80, 69 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 70 | target_means=[0.0, 0.0, 0.0, 0.0], 71 | target_stds=[0.05, 0.05, 0.1, 0.1]), 72 | reg_class_agnostic=True, 73 | loss_cls=dict(type='CrossEntropyLoss', 74 | use_sigmoid=False, 75 | loss_weight=1.0), 76 | loss_bbox=dict(type='SmoothL1Loss', 77 | beta=1.0, 78 | loss_weight=1.0)), 79 | dict(type='Shared2FCBBoxHead', 80 | in_channels=256, 81 | fc_out_channels=1024, 82 | roi_feat_size=7, 83 | num_classes=80, 84 | bbox_coder=dict( 85 | type='DeltaXYWHBBoxCoder', 86 | target_means=[0.0, 0.0, 0.0, 0.0], 87 | target_stds=[0.033, 0.033, 0.067, 0.067]), 88 | reg_class_agnostic=True, 89 | loss_cls=dict(type='CrossEntropyLoss', 90 | use_sigmoid=False, 91 | loss_weight=1.0), 92 | loss_bbox=dict(type='SmoothL1Loss', 93 | beta=1.0, 94 | loss_weight=1.0)) 95 | ]), 96 | train_cfg=dict(rpn=dict(assigner=dict(type='MaxIoUAssigner', 97 | pos_iou_thr=0.5, 98 | neg_iou_thr=0.4, 99 | min_pos_iou=0, 100 | ignore_iof_thr=-1), 101 | allowed_border=-1, 102 | pos_weight=-1, 103 | debug=False), 104 | rpn_proposal=None, 105 | rcnn=None), 106 | # only used to evaluate the rpn results 107 | # do not be adopted when generate the proposals for second stage. 108 | test_cfg=dict(rpn=dict(nms_pre=1000, 109 | max_per_img=100, 110 | nms=dict(type='nms', iou_threshold=0.5), 111 | score_thr=0.05, 112 | min_bbox_size=0), 113 | rcnn=None)) 114 | data = dict(test=dict(type='PointCocoDataset', 115 | ann_file='data/coco/annotations/instances_val2017.json', 116 | img_prefix='data/coco/val2017/', 117 | pipeline=[ 118 | dict(type='LoadImageFromFile'), 119 | dict(type='LoadAnnotations', with_bbox=True), 120 | dict(type='MultiScaleFlipAug', 121 | img_scale=(1333, 800), 122 | flip=False, 123 | transforms=[ 124 | dict(type='Resize', keep_ratio=True), 125 | dict(type='RandomFlip'), 126 | dict(type='Normalize', 127 | mean=[123.675, 116.28, 103.53], 128 | std=[58.395, 57.12, 57.375], 129 | to_rgb=True), 130 | dict(type='Pad', size_divisor=32), 131 | dict(type='DefaultFormatBundle'), 132 | dict(type='Collect', 133 | keys=['img', 'gt_bboxes', 'gt_labels']) 134 | ]) 135 | ])) 136 | -------------------------------------------------------------------------------- /projects/configs/10_coco/group_rcnn_50e_10_percent_coco_detr_augmentation.py: -------------------------------------------------------------------------------- 1 | _base_ = 'base.py' 2 | model = dict( 3 | type='GroupRCNN', 4 | backbone=dict(type='ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | norm_eval=True, 11 | style='pytorch', 12 | init_cfg=dict(type='Pretrained', 13 | checkpoint='torchvision://resnet50')), 14 | neck=dict(type='FPN', 15 | in_channels=[256, 512, 1024, 2048], 16 | out_channels=256, 17 | start_level=1, 18 | add_extra_convs='on_input', 19 | num_outs=5), 20 | rpn_head=dict(type='RetinaHead', 21 | num_classes=80, 22 | in_channels=256, 23 | stacked_convs=4, 24 | feat_channels=256, 25 | anchor_generator=dict(type='AnchorGenerator', 26 | octave_base_scale=4, 27 | scales_per_octave=3, 28 | ratios=[0.5, 1.0, 2.0], 29 | strides=[8, 16, 32, 64, 128]), 30 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 31 | target_means=[.0, .0, .0, .0], 32 | target_stds=[1.0, 1.0, 1.0, 1.0]), 33 | loss_cls=dict(type='FocalLoss', 34 | use_sigmoid=True, 35 | gamma=2.0, 36 | alpha=0.25, 37 | loss_weight=1.0), 38 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 39 | roi_head=dict(type='GroupRoIHead', 40 | num_stages=3, 41 | stage_loss_weights=[1, 0.5, 0.25], 42 | bbox_roi_extractor=dict(type='SingleRoIExtractor', 43 | roi_layer=dict(type='RoIAlign', 44 | output_size=7, 45 | sampling_ratio=0), 46 | out_channels=256, 47 | featmap_strides=[8, 16, 32, 64]), 48 | bbox_head=[ 49 | dict(type='Shared2FCBBoxHead', 50 | in_channels=256, 51 | fc_out_channels=1024, 52 | roi_feat_size=7, 53 | num_classes=80, 54 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 55 | target_means=[0.0, 0.0, 0.0, 0.0], 56 | target_stds=[0.1, 0.1, 0.2, 0.2]), 57 | reg_class_agnostic=True, 58 | loss_cls=dict(type='CrossEntropyLoss', 59 | use_sigmoid=False, 60 | loss_weight=1.0), 61 | loss_bbox=dict(type='SmoothL1Loss', 62 | beta=1.0, 63 | loss_weight=1.0)), 64 | dict(type='Shared2FCBBoxHead', 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 68 | num_classes=80, 69 | bbox_coder=dict(type='DeltaXYWHBBoxCoder', 70 | target_means=[0.0, 0.0, 0.0, 0.0], 71 | target_stds=[0.05, 0.05, 0.1, 0.1]), 72 | reg_class_agnostic=True, 73 | loss_cls=dict(type='CrossEntropyLoss', 74 | use_sigmoid=False, 75 | loss_weight=1.0), 76 | loss_bbox=dict(type='SmoothL1Loss', 77 | beta=1.0, 78 | loss_weight=1.0)), 79 | dict(type='Shared2FCBBoxHead', 80 | in_channels=256, 81 | fc_out_channels=1024, 82 | roi_feat_size=7, 83 | num_classes=80, 84 | bbox_coder=dict( 85 | type='DeltaXYWHBBoxCoder', 86 | target_means=[0.0, 0.0, 0.0, 0.0], 87 | target_stds=[0.033, 0.033, 0.067, 0.067]), 88 | reg_class_agnostic=True, 89 | loss_cls=dict(type='CrossEntropyLoss', 90 | use_sigmoid=False, 91 | loss_weight=1.0), 92 | loss_bbox=dict(type='SmoothL1Loss', 93 | beta=1.0, 94 | loss_weight=1.0)) 95 | ]), 96 | train_cfg=dict(rpn=dict(assigner=dict(type='MaxIoUAssigner', 97 | pos_iou_thr=0.5, 98 | neg_iou_thr=0.4, 99 | min_pos_iou=0, 100 | ignore_iof_thr=-1), 101 | allowed_border=-1, 102 | pos_weight=-1, 103 | debug=False), 104 | rpn_proposal=None, 105 | rcnn=None), 106 | # only used to evaluate the rpn results 107 | # do not be adopted when generate the proposals for second stage. 108 | test_cfg=dict(rpn=dict(nms_pre=1000, 109 | max_per_img=100, 110 | nms=dict(type='nms', iou_threshold=0.5), 111 | score_thr=0.05, 112 | min_bbox_size=0), 113 | rcnn=None)) 114 | 115 | lr_config = dict(step=[30, 40]) 116 | runner = dict(type='EpochBasedRunner', max_epochs=50) 117 | checkpoint_config = dict(interval=5) 118 | evaluation = dict(interval=5, metric='bbox') 119 | -------------------------------------------------------------------------------- /projects/configs/_base_/datasets/coco_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CocoDataset' 3 | data_root = 'data/coco/' 4 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], 5 | std=[58.395, 57.12, 57.375], 6 | to_rgb=True) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', with_bbox=True), 10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict(type='MultiScaleFlipAug', 20 | img_scale=(1333, 800), 21 | flip=False, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ]) 30 | ] 31 | data = dict( 32 | samples_per_gpu=2, 33 | workers_per_gpu=2, 34 | train=dict(type=dataset_type, 35 | ann_file=data_root + 'annotations/instances_train2017.json', 36 | img_prefix=data_root + 'train2017/', 37 | pipeline=train_pipeline), 38 | val=dict(type=dataset_type, 39 | ann_file=data_root + 'annotations/instances_val2017.json', 40 | img_prefix=data_root + 'val2017/', 41 | pipeline=test_pipeline), 42 | test=dict(type=dataset_type, 43 | ann_file=data_root + 'annotations/instances_val2017.json', 44 | img_prefix=data_root + 'val2017/', 45 | pipeline=test_pipeline)) 46 | evaluation = dict(interval=1, metric='bbox') 47 | -------------------------------------------------------------------------------- /projects/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | custom_hooks = [dict(type='NumClassCheckHook')] 11 | 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | -------------------------------------------------------------------------------- /projects/configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict(policy='step', 6 | warmup='linear', 7 | warmup_iters=500, 8 | warmup_ratio=0.001, 9 | step=[8, 11]) 10 | runner = dict(type='EpochBasedRunner', max_epochs=12) 11 | -------------------------------------------------------------------------------- /projects/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | from .point_coco import PointCocoDataset 7 | from .transform import PointRandomCrop 8 | 9 | __all__ = ['PointCocoDataset', 'PointRandomCrop'] 10 | -------------------------------------------------------------------------------- /projects/datasets/builder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | from collections.abc import Mapping, Sequence 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from mmcv.parallel import collate 12 | from mmcv.parallel.data_container import DataContainer 13 | from mmcv.runner import get_dist_info 14 | from mmdet.datasets.builder import worker_init_fn 15 | from mmdet.datasets.samplers import (DistributedGroupSampler, 16 | DistributedSampler, GroupSampler) 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.dataloader import default_collate 19 | 20 | 21 | def multi_collate_fn(batch, samples_per_gpu=1): 22 | """Puts each data field into a tensor/DataContainer with outer dimension 23 | batch size. This is mainly used in query_support dataloader. The main 24 | difference with the :func:`collate_fn` in mmcv is it can process 25 | list[list[DataContainer]]. 26 | 27 | Extend default_collate to add support for 28 | :type:`~mmcv.parallel.DataContainer`. There are 3 cases. 29 | 30 | 1. cpu_only = True, e.g., meta data. 31 | 2. cpu_only = False, stack = True, e.g., images tensors. 32 | 3. cpu_only = False, stack = False, e.g., gt bboxes. 33 | 34 | Args: 35 | batch (list[list[:obj:`mmcv.parallel.DataContainer`]] | 36 | list[:obj:`mmcv.parallel.DataContainer`]): Data of 37 | single batch. 38 | samples_per_gpu (int): The number of samples of single GPU. 39 | """ 40 | 41 | if not isinstance(batch, Sequence): 42 | raise TypeError(f'{batch.dtype} is not supported.') 43 | 44 | # This is usually a case in query_support dataloader, which 45 | # the :func:`__getitem__` of dataset return more than one images. 46 | # Here we process the support batch data in type of 47 | # List: [ List: [ DataContainer]] 48 | if isinstance(batch[0], Sequence): 49 | samples_per_gpu = len(batch[0]) * samples_per_gpu 50 | batch = sum(batch, []) 51 | if isinstance(batch[0], DataContainer): 52 | stacked = [] 53 | if batch[0].cpu_only: 54 | for i in range(0, len(batch), samples_per_gpu): 55 | stacked.append( 56 | [sample.data for sample in batch[i:i + samples_per_gpu]]) 57 | return DataContainer(stacked, 58 | batch[0].stack, 59 | batch[0].padding_value, 60 | cpu_only=True) 61 | elif batch[0].stack: 62 | for i in range(0, len(batch), samples_per_gpu): 63 | assert isinstance(batch[i].data, torch.Tensor) 64 | 65 | if batch[i].pad_dims is not None: 66 | ndim = batch[i].dim() 67 | assert ndim > batch[i].pad_dims 68 | max_shape = [0 for _ in range(batch[i].pad_dims)] 69 | for dim in range(1, batch[i].pad_dims + 1): 70 | max_shape[dim - 1] = batch[i].size(-dim) 71 | for sample in batch[i:i + samples_per_gpu]: 72 | for dim in range(0, ndim - batch[i].pad_dims): 73 | assert batch[i].size(dim) == sample.size(dim) 74 | for dim in range(1, batch[i].pad_dims + 1): 75 | max_shape[dim - 1] = max(max_shape[dim - 1], 76 | sample.size(-dim)) 77 | padded_samples = [] 78 | for sample in batch[i:i + samples_per_gpu]: 79 | pad = [0 for _ in range(batch[i].pad_dims * 2)] 80 | for dim in range(1, batch[i].pad_dims + 1): 81 | pad[2 * dim - 82 | 1] = max_shape[dim - 1] - sample.size(-dim) 83 | padded_samples.append( 84 | F.pad(sample.data, pad, 85 | value=sample.padding_value)) 86 | stacked.append(default_collate(padded_samples)) 87 | elif batch[i].pad_dims is None: 88 | stacked.append( 89 | default_collate([ 90 | sample.data 91 | for sample in batch[i:i + samples_per_gpu] 92 | ])) 93 | else: 94 | raise ValueError( 95 | 'pad_dims should be either None or integers (1-3)') 96 | 97 | else: 98 | for i in range(0, len(batch), samples_per_gpu): 99 | stacked.append( 100 | [sample.data for sample in batch[i:i + samples_per_gpu]]) 101 | return DataContainer(stacked, batch[0].stack, batch[0].padding_value) 102 | elif isinstance(batch[0], Sequence): 103 | transposed = zip(*batch) 104 | return [collate(samples, samples_per_gpu) for samples in transposed] 105 | elif isinstance(batch[0], Mapping): 106 | return { 107 | key: collate([d[key] for d in batch], samples_per_gpu) 108 | for key in batch[0] 109 | } 110 | else: 111 | return default_collate(batch) 112 | 113 | 114 | def build_point_dataloader(dataset, 115 | samples_per_gpu, 116 | workers_per_gpu, 117 | num_gpus=1, 118 | dist=True, 119 | shuffle=True, 120 | seed=None, 121 | **kwargs): 122 | """Build PyTorch DataLoader. 123 | 124 | In distributed training, each GPU/process has a dataloader. 125 | In non-distributed training, there is only one dataloader for all GPUs. 126 | 127 | Args: 128 | dataset (Dataset): A PyTorch dataset. 129 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 130 | batch size of each GPU. 131 | workers_per_gpu (int): How many subprocesses to use for data loading 132 | for each GPU. 133 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 134 | dist (bool): Distributed training/test or not. Default: True. 135 | shuffle (bool): Whether to shuffle the data at every epoch. 136 | Default: True. 137 | kwargs: any keyword argument to be used to initialize DataLoader 138 | 139 | Returns: 140 | DataLoader: A PyTorch dataloader. 141 | """ 142 | rank, world_size = get_dist_info() 143 | if dist: 144 | # DistributedGroupSampler will definitely shuffle the data to satisfy 145 | # that images on each GPU are in the same group 146 | if shuffle: 147 | sampler = DistributedGroupSampler(dataset, 148 | samples_per_gpu, 149 | world_size, 150 | rank, 151 | seed=seed) 152 | else: 153 | sampler = DistributedSampler(dataset, 154 | world_size, 155 | rank, 156 | shuffle=False, 157 | seed=seed) 158 | batch_size = samples_per_gpu 159 | num_workers = workers_per_gpu 160 | else: 161 | sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None 162 | batch_size = num_gpus * samples_per_gpu 163 | num_workers = num_gpus * workers_per_gpu 164 | 165 | init_fn = partial( 166 | worker_init_fn, num_workers=num_workers, rank=rank, 167 | seed=seed) if seed is not None else None 168 | 169 | data_loader = DataLoader(dataset, 170 | batch_size=batch_size, 171 | sampler=sampler, 172 | num_workers=num_workers, 173 | collate_fn=partial( 174 | multi_collate_fn, 175 | samples_per_gpu=samples_per_gpu), 176 | pin_memory=False, 177 | worker_init_fn=init_fn, 178 | **kwargs) 179 | 180 | return data_loader 181 | -------------------------------------------------------------------------------- /projects/datasets/point_coco.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | 7 | import itertools 8 | import json 9 | import logging 10 | import os.path as osp 11 | import warnings 12 | from collections import OrderedDict, defaultdict 13 | from contextlib import contextmanager 14 | 15 | import mmcv 16 | import numpy as np 17 | from mmcv.utils import print_log 18 | from mmdet.datasets.api_wrappers import COCOeval 19 | from mmdet.datasets.builder import DATASETS 20 | from mmdet.datasets.coco import CocoDataset 21 | from mmdet.datasets.pipelines import Compose 22 | from mmdet.utils import get_root_logger 23 | from terminaltables import AsciiTable 24 | 25 | 26 | @contextmanager 27 | def local_numpy_seed(seed=None): 28 | """Run numpy codes with a local random seed. 29 | 30 | If seed is None, the default random state will be used. 31 | """ 32 | state = np.random.get_state() 33 | if seed is not None: 34 | np.random.seed(seed) 35 | try: 36 | yield 37 | finally: 38 | np.random.set_state(state) 39 | 40 | 41 | INF = 1e8 42 | 43 | 44 | @DATASETS.register_module() 45 | class PointCocoDataset(CocoDataset): 46 | def __init__(self, 47 | *args, 48 | seed=0, 49 | full_ann_ratio=0.1, 50 | student_thr=0, 51 | predictions_path=None, 52 | need_points=True, 53 | **kwargs): 54 | """Dataset of COCODataset with point annotations. 55 | 56 | Args: 57 | seed (int): The seed to split the dataset. Default: 0. 58 | full_ann_ratio (float): The ratio of images with bbox annotations 59 | used in traning. Default 0.1. 60 | student_thr (float): The threshold of predictions. Default: 0. 61 | predictions_path (str): Path of json that contains predictions 62 | of Group R-CNN. Defaults to None. 63 | need_points (bool): Control whether return the point annotations. 64 | Defaults to True. 65 | """ 66 | self.student_thr = student_thr 67 | self.seed = seed 68 | 69 | # `predictions_path` Only be used when finetune the student 70 | # model, we replace the 71 | # corresponding bbox annotation with the predictions of 72 | # Group R-CNN 73 | self.predictions_path = predictions_path 74 | 75 | self.need_points = need_points 76 | if self.predictions_path: 77 | with open(self.predictions_path, 'r') as f: 78 | student_dataset = json.load(f) 79 | 80 | self.student_anns = defaultdict(list) 81 | for item in student_dataset: 82 | self.student_anns[item['image_id']].append(item) 83 | 84 | self.full_ann_ratio = full_ann_ratio 85 | 86 | self.super_init(*args, **kwargs) 87 | if self.test_mode: 88 | self.full_ann_ratio = 0 89 | self.split_data(self.data_infos) 90 | 91 | # set group flag for the sampler 92 | if not self.test_mode: 93 | self._set_group_flag() 94 | 95 | def super_init(self, 96 | ann_file, 97 | pipeline, 98 | classes=None, 99 | data_root=None, 100 | img_prefix='', 101 | seg_prefix=None, 102 | proposal_file=None, 103 | test_mode=False, 104 | filter_empty_gt=True, 105 | **kwargs): 106 | self.ann_file = ann_file 107 | self.data_root = data_root 108 | self.img_prefix = img_prefix 109 | self.seg_prefix = seg_prefix 110 | self.proposal_file = proposal_file 111 | self.test_mode = test_mode 112 | self.filter_empty_gt = filter_empty_gt 113 | self.CLASSES = self.get_classes(classes) 114 | 115 | # join paths if data_root is specified 116 | if self.data_root is not None: 117 | if not osp.isabs(self.ann_file): 118 | self.ann_file = osp.join(self.data_root, self.ann_file) 119 | if not (self.img_prefix is None or osp.isabs(self.img_prefix)): 120 | self.img_prefix = osp.join(self.data_root, self.img_prefix) 121 | if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): 122 | self.seg_prefix = osp.join(self.data_root, self.seg_prefix) 123 | if not (self.proposal_file is None 124 | or osp.isabs(self.proposal_file)): 125 | self.proposal_file = osp.join(self.data_root, 126 | self.proposal_file) 127 | # load annotations (and proposals) 128 | self.data_infos = self.load_annotations(self.ann_file) 129 | 130 | if self.proposal_file is not None: 131 | self.proposals = self.load_proposals(self.proposal_file) 132 | else: 133 | self.proposals = None 134 | 135 | # filter images too small and containing no annotations 136 | if not test_mode: 137 | valid_inds = self._filter_imgs() 138 | self.data_infos = [self.data_infos[i] for i in valid_inds] 139 | if self.proposals is not None: 140 | self.proposals = [self.proposals[i] for i in valid_inds] 141 | 142 | # processing pipeline 143 | self.pipeline = Compose(pipeline) 144 | 145 | def split_data(self, data_infos): 146 | """Split the dataset to two part,with bbox annotations and Point 147 | Annotations. 148 | 149 | Args: 150 | data_infos (list[dict]): List of image infos. 151 | """ 152 | self._total_length = len(data_infos) 153 | self._bbox_length = int(self.full_ann_ratio * self._total_length) 154 | self._point_length = self._total_length - self._bbox_length 155 | 156 | for img_info in self.data_infos: 157 | img_info['with_bbox_ann'] = False 158 | 159 | with local_numpy_seed(self.seed): 160 | self.bbox_ann_img_idxs = np.random.choice(self._total_length, 161 | size=self._bbox_length, 162 | replace=False) 163 | for img_id in self.bbox_ann_img_idxs: 164 | img_info = self.data_infos[img_id] 165 | img_info['with_bbox_ann'] = True 166 | 167 | def __len__(self): 168 | if self.test_mode or self.predictions_path is not None: 169 | return self._total_length 170 | else: 171 | return self._bbox_length 172 | 173 | def get_ann_info(self, idx, is_bbox_ann=False): 174 | 175 | img_id = self.data_infos[idx]['id'] 176 | ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) 177 | ann_info = self.coco.load_anns(ann_ids) 178 | parsed_anns = self._parse_ann_info(self.data_infos[idx], 179 | ann_info, 180 | is_bbox_ann=is_bbox_ann) 181 | return parsed_anns 182 | 183 | def get_predictions_ann_info(self, idx): 184 | """Replace real bbox annotations with predictions. Only be used when 185 | finetune a Student. 186 | 187 | Args: 188 | idx (int): the index of data_info/ 189 | 190 | Returns: 191 | dict: A dict contains predictions of Group R-CNN 192 | """ 193 | img_id = self.data_infos[idx]['id'] 194 | anns = self.student_anns[img_id] 195 | 196 | gt_bboxes = [] 197 | gt_labels = [] 198 | gt_bboxes_ignore = [] 199 | gt_labels_ignore = [] 200 | 201 | for ann in anns: 202 | score = ann['score'] 203 | if score > self.student_thr: 204 | gt_bboxes.append(ann['bbox']) 205 | gt_labels.append(self.cat2label[ann['category_id']]) 206 | else: 207 | # TODO point ignore 208 | gt_bboxes_ignore.append(ann['bbox']) 209 | gt_labels_ignore.append(self.cat2label[ann['category_id']]) 210 | 211 | if gt_bboxes: 212 | gt_bboxes = np.array(gt_bboxes, dtype=np.float32) 213 | # ccwh to xxyy 214 | gt_bboxes[:, 2:] = gt_bboxes[:, 2:] + gt_bboxes[:, :2] 215 | gt_labels = np.array(gt_labels, dtype=np.int64) 216 | else: 217 | gt_bboxes = np.zeros((0, 4), dtype=np.float32) 218 | gt_labels = np.array([], dtype=np.int64) 219 | if len(gt_bboxes) == 0: 220 | return None 221 | 222 | ann_dict = dict(bboxes=gt_bboxes, labels=gt_labels) 223 | 224 | return ann_dict 225 | 226 | def prepare_img_with_predictions(self, idx): 227 | """Get annotations of image when finetune a student. We replace 228 | original bbox annotations with predictions of Group R-CNN to if the 229 | images is not in division with bbox annotations. 230 | 231 | Args: 232 | idx (int): The index of corresponding image. 233 | 234 | Returns: 235 | dict: dict contains annotations. 236 | """ 237 | if idx in self.bbox_ann_img_idxs: 238 | img_info = self.data_infos[idx] 239 | ann_info = self.get_ann_info(idx) 240 | results = dict(img_info=img_info, ann_info=ann_info) 241 | self.pre_pipeline(results) 242 | else: 243 | # use predictions 244 | img_info = self.data_infos[idx] 245 | ann_info = self.get_predictions_ann_info(idx) 246 | # empty image 247 | if ann_info is None: 248 | return self.prepare_img_with_predictions((idx + 1) % len(self)) 249 | results = dict(img_info=img_info, ann_info=ann_info) 250 | self.pre_pipeline(results) 251 | 252 | return self.pipeline(results) 253 | 254 | def prepare_train_img(self, idx, is_bbox_ann=False): 255 | img_info = self.data_infos[idx] 256 | ann_info = self.get_ann_info(idx, is_bbox_ann=is_bbox_ann) 257 | results = dict(img_info=img_info, ann_info=ann_info) 258 | self.pre_pipeline(results) 259 | return self.pipeline(results) 260 | 261 | def prepare_test_img(self, idx): 262 | # load point annotation when test 263 | img_info = self.data_infos[idx] 264 | ann_info = self.get_ann_info(idx) 265 | results = dict(img_info=img_info, ann_info=ann_info) 266 | self.pre_pipeline(results) 267 | return self.pipeline(results) 268 | 269 | def _parse_ann_info(self, img_info, ann_info, is_bbox_ann=False): 270 | """Parse the annotations to a dict. We random sample a point from bbox 271 | if the image is in the division with bbox annotations. 272 | 273 | Args: 274 | img_info (dict): Dict contains the image information. 275 | ann_info (dict): Dict contains the annotations. 276 | is_bbox_ann (bool): Whether image is in the division that 277 | is with bbox annotations. 278 | 279 | Returns: 280 | dict: A dict contains point annotations and bbox annotations. 281 | """ 282 | gt_bboxes = [] 283 | gt_labels = [] 284 | gt_bboxes_ignore = [] 285 | gt_masks_ann = [] 286 | gt_points = [] 287 | 288 | for i, ann in enumerate(ann_info): 289 | if ann.get('ignore', False): 290 | continue 291 | x1, y1, w, h = ann['bbox'] 292 | inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) 293 | inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) 294 | if inter_w * inter_h == 0: 295 | continue 296 | if ann['area'] <= 0 or w < 1 or h < 1: 297 | continue 298 | if ann['category_id'] not in self.cat_ids: 299 | continue 300 | bbox = [x1, y1, x1 + w, y1 + h] 301 | if ann.get('iscrowd', False): 302 | gt_bboxes_ignore.append(bbox) 303 | else: 304 | gt_bboxes.append(bbox) 305 | gt_labels.append(self.cat2label[ann['category_id']]) 306 | gt_masks_ann.append(ann.get('segmentation', None)) 307 | 308 | # only return bbox annotation when train the student model 309 | if self.predictions_path is None: 310 | # use points in json 311 | if not is_bbox_ann: 312 | x1, y1 = ann['point'][:2] 313 | gt_points.append([x1, y1, x1 + 1, y1 + 1]) 314 | 315 | # random sample points in bbox 316 | else: 317 | # follow point as query 318 | x1 = np.random.uniform(x1 + 0.01 * w, x1 + 0.99 * w) 319 | y1 = np.random.uniform(y1 + 0.01 * h, y1 + 0.99 * h) 320 | # + 1 for vis, only x1, y1 would be actually used 321 | gt_points.append([x1, y1, x1 + 1, y1 + 1]) 322 | 323 | if gt_bboxes: 324 | gt_bboxes = np.array(gt_bboxes, dtype=np.float32) 325 | if not self.predictions_path: 326 | gt_points = np.array(gt_points, dtype=np.float32) 327 | gt_labels = np.array(gt_labels, dtype=np.int64) 328 | else: 329 | gt_bboxes = np.zeros((0, 4), dtype=np.float32) 330 | if not self.predictions_path: 331 | gt_points = np.zeros((0, 4), dtype=np.float32) 332 | gt_labels = np.array([], dtype=np.int64) 333 | 334 | if gt_bboxes_ignore: 335 | gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) 336 | else: 337 | gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) 338 | 339 | if not self.predictions_path and self.need_points: 340 | gt_bboxes = np.concatenate([gt_bboxes, gt_points], axis=0) 341 | 342 | ann = dict( 343 | bboxes=gt_bboxes, 344 | labels=gt_labels, 345 | bboxes_ignore=gt_bboxes_ignore, 346 | ) 347 | 348 | return ann 349 | 350 | def __getitem__(self, idx): 351 | 352 | # val the model with point ann in val dataset 353 | if self.test_mode: 354 | return self.prepare_test_img(idx) 355 | 356 | # train the student model 357 | if self.predictions_path: 358 | return self.prepare_img_with_predictions(idx) 359 | 360 | while True: 361 | bbox_ann_img_idx = self.bbox_ann_img_idxs[idx] 362 | # random sample points from bbox 363 | bbox_data = self.prepare_train_img(bbox_ann_img_idx, 364 | is_bbox_ann=True) 365 | 366 | if bbox_data is None: 367 | idx = (idx + 1) % len(self) 368 | continue 369 | else: 370 | bbox_data['img_metas'].data['mode'] = 'bbox' 371 | break 372 | # TODO remove this 373 | point_datas = [] 374 | point_datas.append(bbox_data) 375 | 376 | return point_datas 377 | 378 | def evaluate(self, 379 | results, 380 | metric='bbox', 381 | logger=None, 382 | jsonfile_prefix=None, 383 | classwise=False, 384 | proposal_nums=(100, 300, 1000), 385 | iou_thrs=None, 386 | metric_items=None): 387 | 388 | if logger is None: 389 | logger = get_root_logger() 390 | if isinstance(results[0], tuple): 391 | main_results = [item[0] for item in results] 392 | rpn_results = [item[1] for item in results] 393 | semi_results = [item[2] for item in results] 394 | semi = True 395 | else: 396 | main_results = results 397 | semi_results = None 398 | rpn_results = None 399 | semi = False 400 | eval_results = OrderedDict() 401 | prefix = ['rpn', 'nms_rpn_topk', 'rcnn'] 402 | for i, results in enumerate([main_results, rpn_results, semi_results]): 403 | if i == 1: 404 | continue 405 | metrics = metric if isinstance(metric, list) else [metric] 406 | allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] 407 | for metric in metrics: 408 | if metric not in allowed_metrics: 409 | raise KeyError(f'metric {metric} is not supported') 410 | if iou_thrs is None: 411 | iou_thrs = np.linspace(.5, 412 | 0.95, 413 | int(np.round((0.95 - .5) / .05)) + 1, 414 | endpoint=True) 415 | if metric_items is not None: 416 | if not isinstance(metric_items, list): 417 | metric_items = [metric_items] 418 | 419 | result_files, tmp_dir = self.format_results( 420 | results, jsonfile_prefix) 421 | 422 | cocoGt = self.coco 423 | for metric in metrics: 424 | msg = f'Evaluating {metric}...' 425 | if logger is None: 426 | msg = '\n' + msg 427 | print_log(msg, logger=logger) 428 | 429 | if metric == 'proposal_fast': 430 | ar = self.fast_eval_recall(results, 431 | proposal_nums, 432 | iou_thrs, 433 | logger='silent') 434 | log_msg = [] 435 | for i, num in enumerate(proposal_nums): 436 | eval_results[f'AR@{num}'] = ar[i] 437 | log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') 438 | log_msg = ''.join(log_msg) 439 | print_log(log_msg, logger=logger) 440 | continue 441 | 442 | iou_type = 'bbox' if metric == 'proposal' else metric 443 | if metric not in result_files: 444 | raise KeyError(f'{metric} is not in results') 445 | try: 446 | predictions = mmcv.load(result_files[metric]) 447 | if iou_type == 'segm': 448 | 449 | for x in predictions: 450 | x.pop('bbox') 451 | warnings.simplefilter('once') 452 | warnings.warn( 453 | 'The key "bbox" is deleted for more ' 454 | 'accurate mask AP ' 455 | 'of small/medium/large instances ' 456 | 'since v2.12.0. This ' 457 | 'does not change the ' 458 | 'overall mAP calculation.', UserWarning) 459 | cocoDt = cocoGt.loadRes(predictions) 460 | except IndexError: 461 | print_log( 462 | 'The testing results of the whole dataset is empty.', 463 | logger=logger, 464 | level=logging.ERROR) 465 | break 466 | 467 | cocoEval = COCOeval(cocoGt, cocoDt, iou_type) 468 | cocoEval.params.catIds = self.cat_ids 469 | cocoEval.params.imgIds = self.img_ids 470 | cocoEval.params.maxDets = list(proposal_nums) 471 | cocoEval.params.iouThrs = iou_thrs 472 | # mapping of cocoEval.stats 473 | coco_metric_names = { 474 | 'mAP': 0, 475 | 'mAP_50': 1, 476 | 'mAP_75': 2, 477 | 'mAP_s': 3, 478 | 'mAP_m': 4, 479 | 'mAP_l': 5, 480 | 'AR@100': 6, 481 | 'AR@300': 7, 482 | 'AR@1000': 8, 483 | 'AR_s@1000': 9, 484 | 'AR_m@1000': 10, 485 | 'AR_l@1000': 11 486 | } 487 | if metric_items is not None: 488 | for metric_item in metric_items: 489 | if metric_item not in coco_metric_names: 490 | raise KeyError( 491 | f'metric item {metric_item} is not supported') 492 | 493 | if metric == 'proposal': 494 | cocoEval.params.useCats = 0 495 | cocoEval.evaluate() 496 | cocoEval.accumulate() 497 | cocoEval.summarize() 498 | if metric_items is None: 499 | metric_items = [ 500 | 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', 501 | 'AR_m@1000', 'AR_l@1000' 502 | ] 503 | 504 | for item in metric_items: 505 | val = float( 506 | f'{cocoEval.stats[coco_metric_names[item]]:.3f}') 507 | eval_results[item] = val 508 | else: 509 | cocoEval.evaluate() 510 | cocoEval.accumulate() 511 | cocoEval.summarize() 512 | if classwise: # Compute per-category AP 513 | # Compute per-category AP 514 | # from https://github.com/facebookresearch/detectron2/ 515 | precisions = cocoEval.eval['precision'] 516 | # precision: (iou, recall, cls, area range, max dets) 517 | assert len(self.cat_ids) == precisions.shape[2] 518 | 519 | results_per_category = [] 520 | for idx, catId in enumerate(self.cat_ids): 521 | # area range index 0: all area ranges 522 | # max dets index -1: typically 100 per image 523 | nm = self.coco.loadCats(catId)[0] 524 | precision = precisions[:, :, idx, 0, -1] 525 | precision = precision[precision > -1] 526 | if precision.size: 527 | ap = np.mean(precision) 528 | else: 529 | ap = float('nan') 530 | results_per_category.append( 531 | (f'{nm["name"]}', f'{float(ap):0.3f}')) 532 | 533 | num_columns = min(6, len(results_per_category) * 2) 534 | results_flatten = list( 535 | itertools.chain(*results_per_category)) 536 | headers = ['category', 'AP'] * (num_columns // 2) 537 | results_2d = itertools.zip_longest(*[ 538 | results_flatten[i::num_columns] 539 | for i in range(num_columns) 540 | ]) 541 | table_data = [headers] 542 | table_data += [result for result in results_2d] 543 | table = AsciiTable(table_data) 544 | print_log('\n' + table.table, logger=logger) 545 | 546 | if metric_items is None: 547 | metric_items = [ 548 | 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 549 | 'mAP_l' 550 | ] 551 | 552 | for metric_item in metric_items: 553 | key = f'{prefix[i]}_{metric}_{metric_item}' 554 | val = float( 555 | f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' # noqa 556 | ) 557 | eval_results[key] = val 558 | ap = cocoEval.stats[:6] 559 | eval_results[f'{prefix[i]}_{metric}_mAP_copypaste'] = ( 560 | f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' 561 | f'{ap[4]:.3f} {ap[5]:.3f}') 562 | if not semi: 563 | break 564 | if tmp_dir is not None: 565 | tmp_dir.cleanup() 566 | return eval_results 567 | -------------------------------------------------------------------------------- /projects/datasets/transform.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | 7 | import numpy as np 8 | from mmdet.datasets.builder import PIPELINES 9 | from mmdet.datasets.pipelines import RandomCrop 10 | 11 | 12 | @PIPELINES.register_module() 13 | class PointRandomCrop(RandomCrop): 14 | def _crop_data(self, results, crop_size, allow_negative_crop): 15 | """Function to randomly crop images, bounding boxes, masks, semantic 16 | segmentation maps. The difference with :Class:RandomCrop is this class 17 | add the process the point annotation. 18 | 19 | Args: 20 | results (dict): Result dict from loading pipeline. 21 | crop_size (tuple): Expected absolute size after cropping, (h, w). 22 | allow_negative_crop (bool): Whether to allow a crop that does not 23 | contain any bbox area. Default to False. 24 | 25 | Returns: 26 | dict: Randomly cropped results, 'img_shape' key in result dict is 27 | updated according to crop size. 28 | """ 29 | assert crop_size[0] > 0 and crop_size[1] > 0 30 | for key in results.get('img_fields', ['img']): 31 | img = results[key] 32 | margin_h = max(img.shape[0] - crop_size[0], 0) 33 | margin_w = max(img.shape[1] - crop_size[1], 0) 34 | offset_h = np.random.randint(0, margin_h + 1) 35 | offset_w = np.random.randint(0, margin_w + 1) 36 | crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] 37 | crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] 38 | 39 | # crop the image 40 | img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] 41 | img_shape = img.shape 42 | results[key] = img 43 | results['img_shape'] = img_shape 44 | 45 | with_bbox_ann = results['img_info']['with_bbox_ann'] 46 | # crop bboxes accordingly and clip to the image boundary 47 | 48 | num_gt = len(results['gt_labels']) 49 | # TODO check this aug 50 | for key in ['gt_bboxes']: 51 | # e.g. gt_bboxes and gt_bboxes_ignore 52 | bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h], 53 | dtype=np.float32) 54 | temp_bboxes = results[key] - bbox_offset 55 | 56 | if with_bbox_ann: 57 | bboxes = temp_bboxes[:num_gt] 58 | points = temp_bboxes[num_gt:] 59 | if self.bbox_clip_border: 60 | bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) 61 | bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) 62 | points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) 63 | points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) 64 | 65 | valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (bboxes[:, 3] > 66 | bboxes[:, 1]) 67 | # judge is there points in bbox 68 | if len(points): 69 | points = points[valid_inds] 70 | # If the crop does not contain any gt-bbox area and 71 | # allow_negative_crop is False, skip this image. 72 | if (key == 'gt_bboxes' and not valid_inds.any() 73 | and not allow_negative_crop): 74 | return None 75 | results[key] = np.concatenate([bboxes[valid_inds, :], points], 76 | axis=0) 77 | # label fields. e.g. gt_labels and gt_labels_ignore 78 | label_key = self.bbox2label.get(key) 79 | if label_key in results: 80 | results[label_key] = results[label_key][valid_inds] 81 | 82 | else: 83 | bboxes = temp_bboxes[:num_gt] 84 | # point here would be dep 85 | # we will sample point in model instead 86 | # of dataset if there is bbox annotation 87 | points = temp_bboxes[num_gt:] 88 | valid_inds = (0 <= points[:, 0] ) & ( points[:, 0]< img_shape[1]) \ 89 | & (0 <= points[:, 1]) & (points[:, 1]< img_shape[0]) # noqa 90 | if self.bbox_clip_border: 91 | bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) 92 | bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) 93 | points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) 94 | points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) 95 | 96 | points = points[valid_inds] 97 | # If the crop does not contain any gt-bbox area and 98 | # allow_negative_crop is False, skip this image. 99 | if (key == 'gt_bboxes' and not valid_inds.any() 100 | and not allow_negative_crop): 101 | return None 102 | results[key] = np.concatenate([bboxes[valid_inds, :], points], 103 | axis=0) 104 | # label fields. e.g. gt_labels and gt_labels_ignore 105 | label_key = self.bbox2label.get(key) 106 | if label_key in results: 107 | results[label_key] = results[label_key][valid_inds] 108 | 109 | return results 110 | -------------------------------------------------------------------------------- /projects/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | from .group_rcnn import GroupRCNN 7 | from .group_roi_head import GroupRoIHead 8 | 9 | __all__ = ['GroupRCNN', 'GroupRoIHead'] 10 | -------------------------------------------------------------------------------- /projects/models/group_rcnn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mmcv.ops import batched_nms 10 | from mmdet.core import bbox2result, select_single_mlvl 11 | from mmdet.models.builder import DETECTORS 12 | from mmdet.models.detectors.two_stage import TwoStageDetector 13 | 14 | 15 | @DETECTORS.register_module() 16 | class GroupRCNN(TwoStageDetector): 17 | def __init__(self, 18 | *args, 19 | pre_topk=3, 20 | rpn_nms_topk=50, 21 | rpn_iou=0.7, 22 | roi_iou_threshold=0.5, 23 | num_projection_convs=1, 24 | **kwargs): 25 | """Group R-CNN. 26 | 27 | Args: 28 | pre_topk (int): Select nearest `pre_topk` points to form a 29 | group. Default to 3. 30 | rpn_nms_topk (int): The number of bboxes reserved each group. 31 | Default to 50. 32 | rpn_iou (float): The IOU Threshold when do the nms of RPN 33 | proposals. Default to 0.7. 34 | roi_iou_threshold (float): The IOU Threshold when do the nms of RPN 35 | proposals. Default to 0.5. 36 | num_projection_convs (int): The number of projection convs. Default 37 | to 1. 38 | """ 39 | self.pre_topk = pre_topk 40 | self.rpn_iou = rpn_iou 41 | self.rpn_nms_topk = rpn_nms_topk 42 | super(GroupRCNN, self).__init__(*args, **kwargs) 43 | self.roi_head.iou_threshold = roi_iou_threshold 44 | self.num_projection_convs = num_projection_convs 45 | self.projection_convs = nn.ModuleList() 46 | for _ in range(self.num_projection_convs): 47 | self.projection_convs.append( 48 | nn.Conv2d(in_channels=256, 49 | out_channels=256, 50 | kernel_size=3, 51 | padding=1)) 52 | 53 | def splte_preds(self, mlvl_pred, bbox_flag): 54 | point_flag = ~bbox_flag 55 | 56 | mlvl_normanl_pred = [item[bbox_flag] for item in mlvl_pred] 57 | mlvl_semi_pred = [item[point_flag] for item in mlvl_pred] 58 | 59 | return mlvl_normanl_pred, mlvl_semi_pred 60 | 61 | def forward_train(self, 62 | img, 63 | img_metas, 64 | gt_bboxes, 65 | gt_labels, 66 | gt_bboxes_ignore=None, 67 | gt_masks=None, 68 | **kwargs): 69 | gt_bboxes, gt_points = self.process_gts(gt_bboxes, gt_labels) 70 | 71 | x = self.extract_feat(img) 72 | 73 | losses = dict() 74 | feat_sizes = [item.size()[-2:] for item in x] 75 | mlvl_points = self.gen_points(feat_sizes, 76 | dtype=x[0].dtype, 77 | device=x[0].device) 78 | 79 | # used in roi_head 80 | rela_coods_list = self.get_relative_coordinate(mlvl_points, gt_points) 81 | mlti_assign_results = self.point_assign(mlvl_points, gt_points) 82 | rpn_losses, results_list = self.rpn_forward_train( 83 | x, 84 | img_metas, 85 | gt_bboxes, 86 | gt_labels=gt_labels, 87 | gt_bboxes_ignore=gt_bboxes_ignore, 88 | assign_results=mlti_assign_results) 89 | losses.update(rpn_losses) 90 | 91 | pred_bboxes, pred_scores = self._rpn_post_process(results_list, 92 | gt_labels=gt_labels) 93 | 94 | detached_x = [item.detach() for item in x] 95 | for conv in self.projection_convs: 96 | detached_x = [F.relu(conv(item)) for item in detached_x] 97 | 98 | roi_losses = self.roi_head.forward_train( 99 | detached_x, 100 | img_metas, 101 | pred_bboxes, 102 | gt_bboxes, 103 | gt_labels, 104 | gt_bboxes_ignore, 105 | rela_coods_list=rela_coods_list, 106 | gt_points=gt_points, 107 | **kwargs) 108 | losses.update(roi_losses) 109 | return losses 110 | 111 | def _rpn_post_process( 112 | self, 113 | pred_bboxes, 114 | gt_labels, 115 | ): 116 | """Do nms cross group and reserved `pre_topk` for each Group. 117 | 118 | Args: 119 | pred_bboxes (list[Tensor]): Regression predictions 120 | from RPN. Each has shape (group_size, num_gts, 4) 121 | gt_labels: (list[Tensor]): Gt_labels for multiple images. 122 | Each has shape (num_gts,) 123 | 124 | Returns: 125 | tuple[tensor]: 126 | 127 | - temp_all_pred_bboxes (list[tensor]): Bboxes 128 | after post process. Each has shape 129 | (rpn_nms_topk, num_gts, 4) 130 | - all_pred_scores (list[tensor]): Scores of 131 | each bboxes. Each tensor in the list has shape 132 | (rpn_nms_topk, num_gts) 133 | """ 134 | temp_all_pred_bboxes = [] 135 | all_pred_scores = [] 136 | for img_id, (bboxes, scores) in enumerate(pred_bboxes): 137 | bag_size = bboxes.shape[0] 138 | repeat_label = gt_labels[img_id][None].repeat(bag_size, 1) 139 | 140 | scores = torch.gather(scores, 2, repeat_label[..., 141 | None]).squeeze(-1) 142 | num_gt = bboxes.shape[1] 143 | 144 | if num_gt == 0: 145 | temp_all_pred_bboxes.append(scores.new_zeros(bag_size, 0, 5)) 146 | all_pred_scores.append(scores.new_zeros(bag_size, 0)) 147 | continue 148 | dets_with_score, keep = batched_nms( 149 | bboxes.view(-1, 4), scores.view(-1), repeat_label.view(-1), 150 | dict(max_num=1000, iou_threshold=self.rpn_iou)) 151 | 152 | num_pred = len(keep) 153 | gt_index = keep % num_gt 154 | arrange_gt_index = torch.arange(num_gt, device=keep.device)[:, 155 | None] 156 | # num_gt x num_pred 157 | keep_matrix = gt_index == arrange_gt_index 158 | temp_index = torch.arange(-num_pred, 159 | end=0, 160 | step=1, 161 | device=keep.device) 162 | keep_matrix = keep_matrix * temp_index 163 | rpn_nms_topk = min(num_pred, self.rpn_nms_topk) 164 | value_, index = keep_matrix.topk(rpn_nms_topk, 165 | dim=-1, 166 | largest=False) 167 | index = index.view(-1) 168 | dets_with_score = dets_with_score[index] 169 | 170 | num_pad = self.rpn_nms_topk - rpn_nms_topk 171 | padding = dets_with_score.new_zeros(num_gt, num_pad, 5) 172 | 173 | dets_with_score = dets_with_score.view(num_gt, rpn_nms_topk, 5) 174 | dets_with_score = torch.cat([dets_with_score, padding], dim=1) 175 | 176 | dets = dets_with_score[..., :4] 177 | det_scores = dets_with_score[..., 4] 178 | # topk * num_gt * 4 179 | dets = dets.permute(1, 0, 2) 180 | det_scores = det_scores.permute(1, 0) 181 | temp_all_pred_bboxes.append(dets.contiguous()) 182 | all_pred_scores.append(det_scores.contiguous()) 183 | 184 | return temp_all_pred_bboxes, all_pred_scores 185 | 186 | def process_gts(self, gt_bboxes=None, gt_labels=None): 187 | """Split point annotations from `gt_bboxes`. We concatenate point 188 | positions to bbox annotations in dataset. 189 | 190 | Args: 191 | gt_bboxes (list[Tensor]): Bbox annotations 192 | concatenated with point annotations of \ 193 | multiple images. 194 | Each tensor has shape (2 * num_intances, 4) 195 | gt_labels (list[Tensor]): Label of multiple 196 | images. Each has shape (num_instances,) 197 | 198 | Returns: 199 | tuple[List[Tensor]]: Bbox annotations and 200 | point annotations. 201 | """ 202 | gt_points = [] 203 | new_gt_bboxes = [] 204 | for img_id in range(len(gt_bboxes)): 205 | num_gt = len(gt_labels[img_id]) 206 | new_gt_bboxes.append(gt_bboxes[img_id][:num_gt]) 207 | gt_points.append(gt_bboxes[img_id][num_gt:][:, :2]) 208 | return new_gt_bboxes, gt_points 209 | 210 | def get_relative_coordinate(self, mlvl_points, points_list): 211 | """Calculate relative coordinates for each instance. 212 | 213 | Args: 214 | mlvl_points (list[Tensor]): Coordinates for multiple 215 | level of FPN. Each has shape (h, w, 2). 216 | points_list (list[Tensor]): Point annotations for 217 | multiple images. Each has shape (num_instances, 2) 218 | 219 | Returns: 220 | list[Tensor]: Relative coordinates for each instances. 221 | Each has shape (num_instances, h ,w, 2). 222 | """ 223 | real_coord_list = [] 224 | for img_id, single_img_points in enumerate(points_list): 225 | mlvl_real_coord = [] 226 | gt_points = points_list[img_id] 227 | for level in range(len(self.strides)): 228 | feat_points = mlvl_points[level] 229 | real_coods = gt_points[:, None, :] - feat_points 230 | if isinstance(self.strides[level], int): 231 | temp_stride = self.strides[level] 232 | else: 233 | temp_stride = self.strides[level][0] 234 | real_coods = real_coods.float() / temp_stride 235 | # num_gt * h * w * 2 236 | mlvl_real_coord.append(real_coods) 237 | real_coord_list.append(mlvl_real_coord) 238 | return real_coord_list 239 | 240 | def rpn_forward_train(self, 241 | x, 242 | img_metas, 243 | gt_bboxes, 244 | gt_labels=None, 245 | gt_bboxes_ignore=None, 246 | assign_results=None, 247 | **kwargs): 248 | outs = self.rpn_head(x) 249 | loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) 250 | rpn_losses = self.rpn_head.loss(*loss_inputs, 251 | gt_bboxes_ignore=gt_bboxes_ignore) 252 | 253 | results_list = self._rpn_get_bboxes(outs, img_metas, assign_results, 254 | gt_labels) 255 | 256 | return rpn_losses, results_list 257 | 258 | def _rpn_get_bboxes(self, 259 | outs, 260 | img_metas=None, 261 | assign_results=None, 262 | gt_labels=None): 263 | with torch.no_grad(): 264 | if len(outs) == 2: 265 | cls_scores, bbox_preds = outs 266 | with_score_factors = False 267 | else: 268 | cls_scores, bbox_preds, score_factors = outs 269 | with_score_factors = True 270 | assert len(cls_scores) == len(score_factors) 271 | num_levels = len(cls_scores) 272 | results_list = [] 273 | for img_id in range(len(img_metas)): 274 | img_meta = img_metas[img_id] 275 | gt_label = gt_labels[img_id] 276 | assign_result = assign_results[img_id] 277 | 278 | mlvl_cls_score = select_single_mlvl(cls_scores, img_id) 279 | mlvl_bbox_pred = select_single_mlvl(bbox_preds, img_id) 280 | if with_score_factors: 281 | mlvl_score_factor = select_single_mlvl( 282 | score_factors, img_id) 283 | else: 284 | mlvl_score_factor = [None for _ in range(num_levels)] 285 | results = self._get_dummy_bboxes_single( 286 | mlvl_cls_score, 287 | mlvl_bbox_pred, 288 | mlvl_score_factor, 289 | img_meta, 290 | gt_label=gt_label, 291 | assign_result=assign_result, 292 | img_id=img_id, 293 | ) 294 | results_list.append(results) 295 | return results_list 296 | 297 | def repeat_index(self, asssign_results, gt_labels): 298 | """Each point has several anchors. 299 | 300 | We repeat the index to the indices of corresponding bboxes 301 | """ 302 | num_base_priors = self.rpn_head.num_base_priors 303 | reapeated_indexs = [] 304 | if num_base_priors > 1: 305 | # topk * num_gt 306 | for single_lvl_results in asssign_results: 307 | temp_list = [ 308 | single_lvl_results * num_base_priors + i 309 | for i in range(num_base_priors) 310 | ] 311 | repeated_sinlge_lvl_results = torch.cat(temp_list, dim=0) 312 | reapeated_indexs.append(repeated_sinlge_lvl_results.view(-1)) 313 | else: 314 | reapeated_indexs = [item.view(-1) for item in asssign_results] 315 | repeat_labels = gt_labels[None].repeat(self.pre_topk * num_base_priors, 316 | 1) 317 | repeat_labels = repeat_labels.view(-1) 318 | 319 | return reapeated_indexs, repeat_labels 320 | 321 | def _get_dummy_bboxes_single(self, 322 | mlvl_cls_score, 323 | mlvl_bbox_pred, 324 | mlvl_score_factor, 325 | img_meta, 326 | gt_label=None, 327 | assign_result=None, 328 | **kwargs): 329 | img_shape = img_meta['img_shape'] 330 | num_gt = len(gt_label) 331 | assert num_gt == assign_result[0].shape[1] 332 | # avoid empty tensor 333 | shape_0 = self.pre_topk * self.rpn_head.num_base_priors 334 | assign_result, repeat_label = self.repeat_index( 335 | assign_result, gt_label) 336 | mlvl_bboxes = [] 337 | mlvl_scores = [] 338 | mlvl_priors = [] 339 | 340 | for level_idx, (cls_score, bbox_pred, score_factor, 341 | single_lvl_pos_index) in enumerate( 342 | zip(mlvl_cls_score, mlvl_bbox_pred, 343 | mlvl_score_factor, assign_result)): 344 | assert cls_score.size()[-2:] == bbox_pred.size()[-2:] 345 | featmap_size_hw = cls_score.shape[-2:] 346 | cls_score = cls_score.permute(1, 2, 0).reshape( 347 | -1, self.rpn_head.cls_out_channels) 348 | if self.rpn_head.loss_cls.use_sigmoid: 349 | scores = cls_score.sigmoid() 350 | else: 351 | scores = cls_score.softmax(-1) 352 | bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) 353 | priors = self.rpn_head.prior_generator.sparse_priors( 354 | single_lvl_pos_index, featmap_size_hw, level_idx, scores.dtype, 355 | scores.device) 356 | 357 | bbox_pred = bbox_pred[single_lvl_pos_index, :] 358 | scores = scores[single_lvl_pos_index, :] 359 | bboxes = self.rpn_head.bbox_coder.decode(priors, 360 | bbox_pred, 361 | max_shape=img_shape) 362 | mlvl_bboxes.append(bboxes.view(shape_0, num_gt, 4)) 363 | mlvl_priors.append(priors.view(shape_0, num_gt, 4)) 364 | mlvl_scores.append( 365 | scores.view(shape_0, num_gt, self.rpn_head.cls_out_channels)) 366 | 367 | mlgt_bboxes = torch.cat(mlvl_bboxes, dim=0) 368 | mlgt_scores = torch.cat(mlvl_scores, dim=0) 369 | returns_list = [mlgt_bboxes, mlgt_scores] 370 | return returns_list 371 | 372 | def gen_points( 373 | self, 374 | featmap_sizes, 375 | dtype, 376 | device, 377 | flatten=True, 378 | ): 379 | """Generate coordinates for FPN. 380 | 381 | Args: 382 | featmap_sizes (list[tensor]): (h ,w) of each level 383 | of FPN. 384 | dtype (torch.dtype): Type of generated coordinates. 385 | device (torch.device): Device of generated coordinates. 386 | flatten (bool): Whether to flatten the (h, w) dimensions. 387 | 388 | Returns: 389 | list[tensor]: Coordinates of each level. 390 | """ 391 | 392 | self.strides = self.rpn_head.prior_generator.strides 393 | 394 | def _get_points_single(featmap_size, 395 | stride, 396 | dtype, 397 | device, 398 | flatten=False, 399 | offset=0.0): 400 | """Get points of a single scale level.""" 401 | if isinstance(stride, (tuple, list)): 402 | stride = stride[0] 403 | h, w = featmap_size 404 | # First create Range with the default dtype, than convert to 405 | # target `dtype` for onnx exporting. 406 | x_range = torch.arange(w, device=device).to(dtype) 407 | y_range = torch.arange(h, device=device).to(dtype) 408 | y, x = torch.meshgrid(y_range, x_range) 409 | if flatten: 410 | y = y.flatten() 411 | x = x.flatten() 412 | 413 | points = torch.stack( 414 | (x * stride, y * stride), dim=-1) + stride * 0.0 415 | return points 416 | 417 | mlvl_points = [] 418 | for i in range(len(featmap_sizes)): 419 | mlvl_points.append( 420 | _get_points_single(featmap_sizes[i], self.strides[i], dtype, 421 | device, flatten)) 422 | 423 | return mlvl_points 424 | 425 | def point_assign(self, mlvl_points, gt_points): 426 | """Select the nearest `pre_topk` points for each point annotation. 427 | 428 | Args: 429 | mlvl_points (list[tensor]): Coordinates of each FPN 430 | level. 431 | gt_points (list[tensor]): Points annotations for 432 | multiple images. 433 | 434 | Returns: 435 | list[list[tensor]]: The outer list indicate the 436 | images. Inter list indicate the FPN level. 437 | And the tensor id the index of selected 438 | feature points. Each tensor has shape 439 | (pre_topk, num_instances) 440 | """ 441 | def nearest_k(mlvl_points, gt_points, pre_topk=3): 442 | 443 | mlvl_prior_index = [] 444 | for points in mlvl_points: 445 | distances = (points[:, None, :] - 446 | gt_points[None, :, :]).pow(2).sum(-1) 447 | min_pre_topk = min(len(distances), pre_topk) 448 | _, topk_idxs_per_level = distances.topk(min_pre_topk, 449 | dim=0, 450 | largest=False) 451 | mlvl_prior_index.append(topk_idxs_per_level) 452 | return mlvl_prior_index 453 | 454 | mlti_img_assign = [] 455 | for single_img_gt_points in gt_points: 456 | mlti_img_assign.append( 457 | nearest_k(mlvl_points, single_img_gt_points, self.pre_topk)) 458 | return mlti_img_assign 459 | 460 | def simple_test(self, 461 | img, 462 | img_metas, 463 | proposals=None, 464 | rescale=False, 465 | **kwargs): 466 | """Test without augmentation.""" 467 | 468 | assert self.with_bbox, 'Bbox head must be implemented.' 469 | 470 | gt_labels = kwargs.get('gt_labels', []) 471 | # remove aug 472 | gt_labels = [item[0] for item in gt_labels] 473 | gt_bboxes = kwargs.get('gt_bboxes', []) 474 | gt_bboxes = [item[0] for item in gt_bboxes] 475 | gt_bboxes, gt_points = self.process_gts(gt_labels=gt_labels, 476 | gt_bboxes=gt_bboxes) 477 | 478 | x = self.extract_feat(img) 479 | 480 | outs = self.rpn_head.forward(x) 481 | rpn_results_list = self.rpn_head.get_bboxes(*outs, 482 | img_metas=img_metas, 483 | rescale=True) 484 | 485 | feat_sizes = [item.size()[-2:] for item in outs[0]] 486 | mlvl_points = self.gen_points(feat_sizes, 487 | dtype=outs[0][0].dtype, 488 | device=outs[0][0].device) 489 | rela_coods_list = self.get_relative_coordinate(mlvl_points, gt_points) 490 | 491 | mlti_assign_results = self.point_assign(mlvl_points, gt_points) 492 | 493 | all_pred_results = self._rpn_get_bboxes( 494 | outs, 495 | img_metas=img_metas, 496 | assign_results=mlti_assign_results, 497 | gt_labels=gt_labels) 498 | all_pred_bboxes, all_pred_scores = self._rpn_post_process( 499 | all_pred_results, gt_labels=gt_labels) 500 | 501 | group_resutls_list = [] 502 | for pred_bboxes, pred_scores, pred_label, img_meta \ 503 | in zip(all_pred_bboxes, all_pred_scores, gt_labels, img_metas): 504 | pred_bboxes = torch.cat([pred_bboxes, pred_scores[..., None]], 505 | dim=-1) 506 | bag_size = len(pred_bboxes) 507 | pred_label = pred_label[None].repeat(bag_size, 1) 508 | pred_bboxes = pred_bboxes.view(-1, 5) 509 | pred_label = pred_label.view(-1) 510 | scale_factors = pred_bboxes.new_tensor(img_meta['scale_factor']) 511 | pred_bboxes[:, :4] = pred_bboxes[:, :4] / scale_factors 512 | group_resutls_list.append((pred_bboxes, pred_label)) 513 | 514 | if len(gt_labels[0]) > 0: 515 | 516 | for conv in self.projection_convs: 517 | x = [F.relu(conv(item)) for item in x] 518 | 519 | extra_gts, scores, gt_labels = self.roi_head.simple_test( 520 | x, 521 | all_pred_bboxes, 522 | img_metas, 523 | rela_coods_list=rela_coods_list, 524 | labels=gt_labels, 525 | all_pred_scores=all_pred_scores, 526 | gt_points=gt_points, 527 | ) 528 | roi_results_list = [] 529 | for img_id, (bboxes, score, img_meta) in enumerate( 530 | zip(extra_gts, scores, img_metas)): 531 | scale_factors = bboxes.new_tensor(img_meta['scale_factor']) 532 | bboxes = bboxes / scale_factors 533 | roi_results_list.append((torch.cat([bboxes, score[:, None]], 534 | dim=-1), gt_labels[img_id])) 535 | else: 536 | roi_results_list = [(torch.zeros(0, 5), gt_labels[0])] 537 | 538 | return self.encode_results(rpn_results_list, group_resutls_list, 539 | roi_results_list) 540 | 541 | def encode_results(self, rpn_results_list, group_resutls_list, 542 | roi_results_list): 543 | 544 | main_results = [ 545 | bbox2result(det_bboxes, det_labels, 80) 546 | for det_bboxes, det_labels in rpn_results_list 547 | ] 548 | rpn_results = [ 549 | bbox2result(det_bboxes, det_labels, 80) 550 | for det_bboxes, det_labels in group_resutls_list 551 | ] 552 | 553 | semi_results = [ 554 | bbox2result(det_bboxes, det_labels, 80) 555 | for det_bboxes, det_labels in roi_results_list 556 | ] 557 | 558 | results = [(main_results[img_id], rpn_results[img_id], 559 | semi_results[img_id]) 560 | for img_id in range(len(main_results))] 561 | 562 | return results 563 | -------------------------------------------------------------------------------- /projects/models/group_roi_head.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | import copy 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from mmcv.ops.nms import batched_nms 12 | from mmcv.runner import force_fp32 13 | from mmdet.core import bbox2roi, bbox_overlaps 14 | from mmdet.models.builder import HEADS 15 | from mmdet.models.losses import accuracy 16 | from mmdet.models.roi_heads.cascade_roi_head import CascadeRoIHead 17 | 18 | 19 | @HEADS.register_module() 20 | class GroupRoIHead(CascadeRoIHead): 21 | def __init__(self, *args, pos_iou_thrs=[0.5, 0.6, 0.7], **kwargs): 22 | # Used to assign the label 23 | self.pos_iou_thrs = pos_iou_thrs 24 | super(GroupRoIHead, self).__init__(*args, **kwargs) 25 | self._init_dy_groupconv() 26 | 27 | def _init_dy_groupconv(self): 28 | # fusing the relative coordinates feature with norm proposal pooling 29 | # features 30 | self.compress_feat = nn.Conv2d(258, 256, 3, stride=1, padding=1) 31 | self.cls_embedding = nn.Embedding(80, 256) 32 | # 1 is the kernel size 33 | self.generate_params = nn.ModuleList([ 34 | nn.Linear(256, 1 * 1 * 256 * 256) for _ in range(self.num_stages) 35 | ]) 36 | self.avg_pool = nn.AvgPool2d((7, 7)) 37 | # fusing the mean roi features and category embedding 38 | self.compress = nn.Linear(256 + 256, 256) 39 | self.group_norm = nn.GroupNorm(32, 256) 40 | 41 | def _bbox_forward(self, 42 | stage, 43 | x, 44 | rois, 45 | coord_feats, 46 | group_size=None, 47 | rois_per_image=None, 48 | gt_labels=None, 49 | **kwargs): 50 | """Box head forward function used in both training and testing. 51 | 52 | Args: 53 | stage (int): The index of stage. 54 | x (list[Tensor]): FPN Features. 55 | rois (Tensor): Has shape (num_proposals, 5). 56 | coord_feats (Tensor): Pooling feature from relative coordinates 57 | feature map. Has shape (num_proposals, 2, 7, 7). 58 | group_size (list[int]): Size of instance group. Default to None. 59 | rois_per_image (list[int]): Number of proposals of each image. 60 | Default to None. 61 | gt_labels (list[Tensor]): Gt labels of multiple images. 62 | Default to None. 63 | 64 | Returns: 65 | dict: dict of model predictions. 66 | """ 67 | bbox_roi_extractor = self.bbox_roi_extractor[stage] 68 | bbox_head = self.bbox_head[stage] 69 | bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], 70 | rois) 71 | bbox_feats = torch.cat([bbox_feats, coord_feats], dim=1) 72 | # compress from 258 to 256 73 | bbox_feats = self.compress_feat(bbox_feats) 74 | params = \ 75 | self.generate_conv_params(bbox_feats, 76 | rois_per_image, 77 | group_size, 78 | gt_labels, 79 | stage=stage) 80 | bag_size = group_size[0] 81 | num_all_gt = int(bbox_feats.size(0) / bag_size) 82 | bbox_feats = bbox_feats.view(bag_size, num_all_gt * 256, 7, 7) 83 | bbox_feats = F.conv2d(bbox_feats, 84 | params, 85 | stride=1, 86 | padding=0, 87 | groups=num_all_gt) 88 | bbox_feats = bbox_feats.view(bag_size * num_all_gt, 256, 7, 7) 89 | # stable the training 90 | bbox_feats = self.group_norm(bbox_feats) 91 | 92 | bbox_feats = F.relu(bbox_feats) 93 | cls_score, bbox_pred = bbox_head(bbox_feats) 94 | 95 | bbox_results = dict(cls_score=cls_score, 96 | bbox_pred=bbox_pred, 97 | bbox_feats=bbox_feats) 98 | return bbox_results 99 | 100 | def generate_conv_params( 101 | self, 102 | rois_feats, 103 | rois_per_image, 104 | group_size, 105 | labels, 106 | stage=0, 107 | ): 108 | """Generate parameters for dynamic group conv. 109 | 110 | Args: 111 | rois_feats (tensor): Pooling feature from FPN. Has shape 112 | (num_proposals, 256, 7 , 7). 113 | rois_per_image (list[int]): Number of proposals of each 114 | image. 115 | group_size (list[int]): Instance group size of each image. 116 | labels (list[tensor]): Gt labels for multiple images. 117 | stage (int): Index of stage. 118 | 119 | Returns: 120 | Tensor: Parameters for dynamic group conv. Has shape 121 | (num_gts * 256, 256, 1, 1), which arrange as 122 | (C_in, C_out, kernel_size, Kernel_size). 123 | """ 124 | start = 0 125 | param_list = [] 126 | ori_pool_rois_feats = self.avg_pool(rois_feats).squeeze() 127 | 128 | for img_id in range(len(labels)): 129 | num_rois = rois_per_image[img_id] 130 | end = num_rois + start 131 | pool_rois_feats = ori_pool_rois_feats[start:end] 132 | start = end 133 | bag_size = group_size[img_id] 134 | label = labels[img_id] 135 | num_gt = len(label) 136 | bag_embeds = self.cls_embedding.weight[label] 137 | pool_rois_feats = pool_rois_feats.view(bag_size, num_gt, 256) 138 | pool_rois_feats = pool_rois_feats.mean(0) 139 | pool_rois_feats = torch.cat([bag_embeds, pool_rois_feats], dim=-1) 140 | pool_rois_feats = self.compress(pool_rois_feats) 141 | 142 | params = self.generate_params[stage](pool_rois_feats) 143 | # use group conv 144 | conv_weight = params.view(num_gt, 256, 256, 1, 1) 145 | conv_weight = conv_weight.reshape(num_gt * 256, 256, 1, 1) 146 | # num_gt * 147 | param_list.append(conv_weight) 148 | params = torch.cat(param_list, dim=0) 149 | 150 | return params 151 | 152 | def _first_coord_pooling(self, coord_feats, proposal_list): 153 | """Pooling relative coordinates for first stage.""" 154 | rois_list = [] 155 | start_index = 0 156 | 157 | for img_id, bag_bboxes in enumerate(proposal_list): 158 | bag_size, num_gt, _ = bag_bboxes.size() 159 | roi_index = torch.arange(start_index, 160 | start_index + num_gt, 161 | device=bag_bboxes.device) 162 | roi_index = roi_index[None, :, None].repeat(bag_size, 1, 1).float() 163 | bag_bboxes = torch.cat([roi_index, bag_bboxes], dim=-1) 164 | bag_rois = bag_bboxes.view(-1, 5) 165 | rois_list.append(bag_rois) 166 | start_index += num_gt 167 | 168 | rois = torch.cat(rois_list, 0).contiguous() 169 | self.roi_index = rois[:, :1] 170 | # keep same during three stages 171 | self.cood_roi_extractor = copy.deepcopy(self.bbox_roi_extractor[0]) 172 | self.cood_roi_extractor.out_channels = 2 173 | self.coord_feats = coord_feats[:self.cood_roi_extractor.num_inputs] 174 | 175 | coord_feats = self.cood_roi_extractor( 176 | coord_feats[:self.cood_roi_extractor.num_inputs], rois) 177 | 178 | return coord_feats 179 | 180 | def _not_first_coord_pooling(self, rois): 181 | """Pooling relative coordinates for second and third stage.""" 182 | rois = torch.cat([self.roi_index, rois], dim=-1) 183 | coord_feats = self.cood_roi_extractor( 184 | self.coord_feats[:self.cood_roi_extractor.num_inputs], rois) 185 | return coord_feats 186 | 187 | def instance_assign(self, stage, anchors, gt_bboxes, gt_labels, 188 | group_size): 189 | 190 | repeat_gts = [] 191 | repeat_labels = [] 192 | num_gts = 0 193 | group_size_each_gt = [] 194 | for gt, label, bag_size in zip(gt_bboxes, gt_labels, group_size): 195 | num_gts += len(gt) 196 | gt = gt[None, :, :].repeat(bag_size, 1, 1) 197 | label = label[None, :].repeat(bag_size, 1) 198 | repeat_gts.append(gt.view(-1, 4)) 199 | repeat_labels.append(label.view(-1)) 200 | group_size_each_gt.extend(gt.size(1) * [bag_size]) 201 | 202 | repeat_gts = torch.cat(repeat_gts, dim=0) 203 | repeat_labels = torch.cat(repeat_labels, dim=0) 204 | 205 | self.repeat_labels = repeat_labels 206 | 207 | match_quality_matrix = bbox_overlaps(anchors, 208 | repeat_gts, 209 | is_aligned=True) 210 | 211 | pos_mask = match_quality_matrix > self.pos_iou_thrs[stage] 212 | targets_weight = match_quality_matrix.new_ones(len(pos_mask)) 213 | 214 | bbox_targets = self.bbox_head[stage].bbox_coder.encode( 215 | anchors, repeat_gts) 216 | all_labels = torch.ones_like( 217 | repeat_labels) * self.bbox_head[0].num_classes 218 | all_labels[pos_mask] = repeat_labels[pos_mask] 219 | 220 | pos_bbox_targets = bbox_targets[pos_mask] 221 | 222 | return pos_mask, pos_bbox_targets, all_labels, targets_weight 223 | 224 | @force_fp32(apply_to=('cls_score', 'bbox_pred')) 225 | def loss(self, 226 | stage, 227 | cls_score, 228 | bbox_pred, 229 | rois, 230 | bbox_targets, 231 | labels, 232 | pos_mask, 233 | reduction_override=None): 234 | 235 | label_weights = torch.ones_like(labels) 236 | bbox_weights = torch.ones_like(bbox_targets) 237 | losses = dict() 238 | avg_factor = max(pos_mask.sum(), pos_mask.new_ones(1).sum()) 239 | 240 | loss_cls_ = self.bbox_head[stage].loss_cls( 241 | cls_score, 242 | labels, 243 | label_weights, 244 | avg_factor=avg_factor, 245 | reduction_override=reduction_override) 246 | 247 | losses['loss_cls'] = loss_cls_ 248 | losses['acc'] = accuracy(cls_score, labels) 249 | # will be divided by num_gts of single batch outside 250 | losses['avg_pos'] = pos_mask.sum() 251 | pos_inds = pos_mask 252 | # do not perform bounding box regression for BG anymore. 253 | if pos_inds.any(): 254 | if self.bbox_head[stage].reg_decoded_bbox: 255 | # When the regression loss (e.g. `IouLoss`, 256 | # `GIouLoss`, `DIouLoss`) is applied directly on 257 | # the decoded bounding boxes, it decodes the 258 | # already encoded coordinates to absolute format. 259 | bbox_pred = self.bbox_head[stage].bbox_coder.decode( 260 | rois[:, 1:], bbox_pred) 261 | if self.bbox_head[stage].reg_class_agnostic: 262 | pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 263 | 4)[pos_inds.type(torch.bool)] 264 | else: 265 | pos_bbox_pred = bbox_pred.view( 266 | bbox_pred.size(0), -1, 267 | 4)[pos_inds.type(torch.bool), 268 | labels[pos_inds.type(torch.bool)]] 269 | losses['loss_bbox'] = self.bbox_head[stage].loss_bbox( 270 | pos_bbox_pred, 271 | bbox_targets, 272 | bbox_weights, 273 | avg_factor=avg_factor, 274 | reduction_override=reduction_override) 275 | else: 276 | losses['loss_bbox'] = bbox_pred.sum() * 0 277 | 278 | return losses 279 | 280 | def forward_train(self, 281 | x, 282 | img_metas, 283 | proposal_list, 284 | gt_bboxes, 285 | gt_labels, 286 | gt_bboxes_ignore=None, 287 | rela_coods_list=None, 288 | gt_points=None, 289 | **kwargs): 290 | """Get loss of single iter. 291 | 292 | Args: 293 | x (list(Tensor)): FPN Feature maps 294 | img_metas(list[dict]): Meta information for multiple images. 295 | proposal_list (list[Tensor]): Proposals of each instance 296 | group in multiple images. Each has shape 297 | (rpn_nms_topk, num_gts, 4). 298 | gt_bboxes (list[Tensor]): Gt bboxes for multiple images. 299 | Each has shape (num_gts, 4). 300 | gt_labels (list[Tensor]): Gt labels for multiple images. 301 | Each has shape (num_gts,). 302 | rela_coods_list (list[list[tensor]]): Relative coordinates for 303 | FPN in multiple images. Each tensor has shape 304 | (num_instances, h*w, 2). 305 | gt_points (list[Tensor]): Gt points for multiple images. 306 | Each has shape (num_gts, 2). 307 | 308 | Returns: 309 | dict: losses of RPN and RoIHead. 310 | """ 311 | # when nms_topk may has different bag size 312 | group_size = [item.size(0) for item in proposal_list] 313 | wh_each_level = [item.shape[-2:] for item in x] 314 | num_img = len(rela_coods_list) 315 | num_level = len(rela_coods_list[0]) 316 | all_num_gts = 0 317 | format_rela_coods_list = [] 318 | for img_id in range(num_img): 319 | real_coods = rela_coods_list[img_id] 320 | mlvl_coord_list = [] 321 | for level in range(num_level): 322 | format_coords = real_coods[level] 323 | num_gt = format_coords.size(0) 324 | all_num_gts += num_gt 325 | format_coords = format_coords.view(num_gt, 326 | *wh_each_level[level], 327 | 2).permute(0, 3, 1, 2) 328 | mlvl_coord_list.append(format_coords) 329 | format_rela_coods_list.append(mlvl_coord_list) 330 | 331 | mlvl_concate_coods = [] 332 | for level in range(num_level): 333 | mlti_img_cood = [ 334 | format_rela_coods_list[img_id][level] 335 | for img_id in range(num_img) 336 | ] 337 | concat_coods = torch.cat(mlti_img_cood, dim=0).contiguous() 338 | mlvl_concate_coods.append(concat_coods) 339 | 340 | losses = dict() 341 | 342 | rois_per_image = [ 343 | item.size(0) * item.size(1) for item in proposal_list 344 | ] 345 | 346 | rois = None 347 | 348 | for stage in range(self.num_stages): 349 | if stage == 0: 350 | coord_feats = self._first_coord_pooling( 351 | mlvl_concate_coods, proposal_list) 352 | feat_rois = bbox2roi( 353 | [item.view(-1, 4).detach() for item in proposal_list]) 354 | else: 355 | coord_feats = self._not_first_coord_pooling(rois) 356 | feat_rois = rois.split(rois_per_image, dim=0) 357 | feat_rois = bbox2roi( 358 | [item.view(-1, 4).detach() for item in feat_rois]) 359 | 360 | pos_mask, pos_bbox_targets, pos_labels, \ 361 | targets_reweight = self.instance_assign( 362 | stage, feat_rois[:, 1:], gt_bboxes, gt_labels, group_size) 363 | 364 | bbox_results = self._bbox_forward( 365 | stage, 366 | x, 367 | feat_rois, 368 | coord_feats, 369 | group_size, 370 | rois_per_image, 371 | gt_labels=gt_labels, 372 | gt_points=gt_points, 373 | img_metas=img_metas, 374 | ) 375 | 376 | single_stage_loss = self.loss(stage, bbox_results['cls_score'], 377 | bbox_results['bbox_pred'], feat_rois, 378 | pos_bbox_targets, pos_labels, 379 | pos_mask) 380 | single_stage_loss['avg_pos'] = single_stage_loss[ 381 | 'avg_pos'] / float(all_num_gts) * 5 382 | 383 | for name, value in single_stage_loss.items(): 384 | losses[f's{stage}.{name}'] = (value * 385 | self.stage_loss_weights[stage] 386 | if 'loss' in name else value) 387 | 388 | # refine bboxes 389 | if stage < self.num_stages - 1: 390 | with torch.no_grad(): 391 | rois = self.bbox_head[stage].bbox_coder.decode( 392 | feat_rois[:, 1:], 393 | bbox_results['bbox_pred'], 394 | ) 395 | 396 | return losses 397 | 398 | def simple_test(self, 399 | x, 400 | proposal_list, 401 | img_metas, 402 | rela_coods_list=None, 403 | labels=None, 404 | gt_points=None, 405 | **kwargs): 406 | """Get predictions of single iter. 407 | 408 | Args: 409 | x (list(Tensor)): FPN Feature maps 410 | proposal_list (list[Tensor]): Proposals of each instance 411 | group in multiple images. Each has shape 412 | (rpn_nms_topk, num_gts, 4). 413 | img_metas(list[dict]): Meta information for multiple images. 414 | rela_coods_list (list[list[tensor]]): Relative coordinates for 415 | FPN in multiple images. Each tensor has shape 416 | (num_instances, h*w, 2). 417 | labels (list[Tensor]): Gt labels for multiple images. 418 | Each has shape (num_gts,). 419 | gt_points (list[Tensor]): Gt points for multiple images. 420 | Each has shape (num_gts, 2). 421 | 422 | Returns: 423 | Tuple: 424 | 425 | - pred_bboxes (list[tensor]): Bbox prediction 426 | of multiple images. Each has shape (num_instances, 4). 427 | - pred_scores (list[tensor]): Score of each bbox in 428 | multiple images. Each has shape (num_instances,). 429 | - pred_labels (list[tensor]): Label of each bbox in 430 | multiple images. Each has shape (num_instances,). 431 | """ 432 | num_images = len(proposal_list) 433 | group_size = [item.size(0) for item in proposal_list] 434 | ms_scores = [] 435 | wh_each_level = [item.shape[-2:] for item in x] 436 | num_img = len(rela_coods_list) 437 | num_level = len(rela_coods_list[0]) 438 | format_rela_coods_list = [] 439 | 440 | repeat_labels = [] 441 | for label, bag_size in zip(labels, group_size): 442 | label = label[None, :].repeat(bag_size, 1) 443 | repeat_labels.append(label.view(-1)) 444 | # used in post process in Group R-CNN 445 | self.repeat_labels = torch.cat(repeat_labels, dim=0) 446 | 447 | for img_id in range(num_img): 448 | real_coods = rela_coods_list[img_id] 449 | mlvl_coord_list = [] 450 | for level in range(num_level): 451 | format_coords = real_coods[level] 452 | num_gt = format_coords.size(0) 453 | format_coords = format_coords.view(num_gt, 454 | *wh_each_level[level], 455 | 2).permute(0, 3, 1, 2) 456 | mlvl_coord_list.append(format_coords) 457 | format_rela_coods_list.append(mlvl_coord_list) 458 | mlvl_concate_coods = [] 459 | 460 | for level in range(num_level): 461 | mlti_img_cood = [ 462 | format_rela_coods_list[img_id][level] 463 | for img_id in range(num_img) 464 | ] 465 | concat_coods = torch.cat(mlti_img_cood, dim=0).contiguous() 466 | mlvl_concate_coods.append(concat_coods) 467 | 468 | rois_per_image = [ 469 | item.size(0) * item.size(1) for item in proposal_list 470 | ] 471 | for stage in range(self.num_stages): 472 | self.current_stage = stage 473 | if stage == 0: 474 | coord_feats = self._first_coord_pooling( 475 | mlvl_concate_coods, proposal_list) 476 | feat_rois = bbox2roi( 477 | [item.view(-1, 4).detach() for item in proposal_list]) 478 | else: 479 | coord_feats = self._not_first_coord_pooling( 480 | torch.cat(proposal_list, dim=0)) 481 | feat_rois = proposal_list 482 | feat_rois = bbox2roi( 483 | [item.view(-1, 4).detach() for item in feat_rois]) 484 | 485 | bbox_results = self._bbox_forward( 486 | stage, 487 | x, 488 | feat_rois, 489 | coord_feats, 490 | rois_per_image=rois_per_image, 491 | group_size=group_size, 492 | gt_labels=labels, 493 | gt_points=gt_points, 494 | img_metas=img_metas, 495 | ) 496 | 497 | bbox_preds = bbox_results['bbox_pred'] 498 | 499 | if self.bbox_head[-1].loss_cls.use_sigmoid: 500 | cls_score = bbox_results['cls_score'].sigmoid() 501 | num_classes = cls_score.size(-1) 502 | else: 503 | cls_score = bbox_results['cls_score'].softmax(-1) 504 | num_classes = cls_score.size(-1) - 1 505 | cls_score = cls_score[:, :num_classes] 506 | 507 | decode_bboxes = [] 508 | all_scores = [] 509 | 510 | for img_id in range(num_images): 511 | img_shape = img_metas[img_id]['img_shape'] 512 | img_mask = feat_rois[:, 0] == img_id 513 | temp_rois = feat_rois[img_mask] 514 | temp_bbox_pred = bbox_preds[img_mask] 515 | bboxes = self.bbox_head[stage].bbox_coder.decode( 516 | temp_rois[..., 1:], temp_bbox_pred, max_shape=img_shape) 517 | temp_scores = cls_score[img_mask] 518 | bboxes = bboxes.view(group_size[img_id], -1, 4) 519 | temp_scores = temp_scores.view(group_size[img_id], -1, 520 | num_classes) 521 | decode_bboxes.append(bboxes) 522 | all_scores.append(temp_scores) 523 | 524 | ms_scores.append(all_scores) 525 | proposal_list = [item.view(-1, 4) for item in decode_bboxes] 526 | 527 | ms_scores = [ 528 | sum([score[i] for score in ms_scores]) / float(len(ms_scores)) 529 | for i in range(num_images) 530 | ] 531 | 532 | pred_bboxes = [] 533 | pred_scores = [] 534 | pred_labels = [] 535 | # select bbox from each group 536 | for img_id in range(num_images): 537 | all_class_scores = ms_scores[img_id] 538 | if all_class_scores.numel(): 539 | repeat_label = labels[img_id][None].repeat( 540 | group_size[img_id], 1) 541 | scores = torch.gather(all_class_scores, 2, 542 | repeat_label[..., None]).squeeze(-1) 543 | 544 | num_gt = decode_bboxes[img_id].shape[1] 545 | dets, keep = batched_nms( 546 | decode_bboxes[img_id].view(-1, 4), scores.view(-1), 547 | repeat_label.view(-1), 548 | dict(max_num=1000, iou_threshold=self.iou_threshold)) 549 | num_pred = len(keep) 550 | gt_index = keep % num_gt 551 | arrange_gt_index = torch.arange(num_gt, 552 | device=keep.device)[:, None] 553 | # num_gt x num_pred 554 | keep_matrix = gt_index == arrange_gt_index 555 | temp_index = torch.arange(-num_pred, 556 | end=0, 557 | step=1, 558 | device=keep.device) 559 | keep_matrix = keep_matrix * temp_index 560 | 561 | value_, index = keep_matrix.min(dim=-1) 562 | dets = dets[index] 563 | pred_bboxes.append(dets[:, :4]) 564 | pred_scores.append(dets[:, -1]) 565 | pred_labels.append(labels[img_id]) 566 | 567 | else: 568 | pred_bboxes.append(all_class_scores.new_zeros(0, 4)) 569 | pred_scores.append(all_class_scores.new_zeros(0)) 570 | pred_labels.append(all_class_scores.new_zeros(0)) 571 | 572 | return pred_bboxes, pred_scores, pred_labels 573 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | WORK_DIR=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/train.py --seed 0 $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:3} 11 | -------------------------------------------------------------------------------- /tools/generate_anns.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | import json 7 | import sys 8 | 9 | import mmcv 10 | import numpy as np 11 | import pycocotools.mask as maskUtils 12 | from mmdet.datasets.api_wrappers import COCO 13 | 14 | from projects.datasets.point_coco import local_numpy_seed 15 | 16 | sys.path.insert(0, './') 17 | 18 | 19 | class PointGenerator(object): 20 | def __init__(self, ann_file): 21 | self.ann_file = ann_file 22 | self.coco = COCO(ann_file) 23 | self.seed = 0 24 | 25 | def generate_points(self): 26 | 27 | save_json = dict() 28 | save_json['images'] = self.coco.dataset['images'] 29 | save_json['annotations'] = [] 30 | annotations = self.coco.dataset['annotations'] 31 | save_json['categories'] = self.coco.dataset['categories'] 32 | 33 | id_info = dict() 34 | for img_info in self.coco.dataset['images']: 35 | id_info[img_info['id']] = img_info 36 | prog_bar = mmcv.ProgressBar(len(annotations)) 37 | with local_numpy_seed(self.seed): 38 | for ann in annotations: 39 | prog_bar.update() 40 | img_info = id_info[ann['image_id']] 41 | segm = ann.get('segmentation', None) 42 | if isinstance(segm, list): 43 | 44 | rles = maskUtils.frPyObjects(segm, img_info['height'], 45 | img_info['width']) 46 | rle = maskUtils.merge(rles) 47 | elif isinstance(segm['counts'], list): 48 | # uncompressed RLE 49 | rle = maskUtils.frPyObjects(segm, img_info['height'], 50 | img_info['width']) 51 | else: 52 | # rle 53 | rle = segm 54 | mask = maskUtils.decode(rle) 55 | if mask.sum() > 0: 56 | ys, xs = np.nonzero(mask) 57 | point_idx = np.random.randint(len(xs)) 58 | x1 = int(xs[point_idx]) 59 | y1 = int(ys[point_idx]) 60 | ann['point'] = [x1, y1, x1, y1] 61 | else: 62 | x1, y1, w, h = ann['bbox'] 63 | x1 = np.random.uniform(x1, x1 + w) 64 | y1 = np.random.uniform(y1, y1 + h) 65 | ann['point'] = [x1, y1, x1, y1] 66 | 67 | save_json['annotations'].append(ann) 68 | mmcv.mkdir_or_exist('./point_ann/') 69 | ann_name = self.ann_file.split('/')[-1] 70 | with open(f'./point_ann/{ann_name}', 'w') as f: 71 | json.dump(save_json, f) 72 | 73 | 74 | if __name__ == '__main__': 75 | 76 | args = sys.argv 77 | if len(args) > 1: 78 | ann_file = args[1] 79 | 80 | point_generator = PointGenerator(ann_file=ann_file) 81 | point_generator.generate_points() 82 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --seed 0 --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | import argparse 7 | import os 8 | import os.path as osp 9 | import time 10 | import warnings 11 | 12 | import mmcv 13 | import torch 14 | from mmcv import Config, DictAction 15 | from mmcv.cnn import fuse_conv_bn 16 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 17 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 18 | wrap_fp16_model) 19 | from mmdet.datasets import build_dataset, replace_ImageToTensor 20 | from mmdet.models import build_detector 21 | 22 | from projects.api.test import pointdet_multi_gpu_test, pointdet_single_gpu_test 23 | from projects.datasets import * # noqa 24 | from projects.datasets.builder import build_point_dataloader 25 | from projects.models import * # noqa 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser( 30 | description='MMDet test (and eval) a model') 31 | parser.add_argument('config', help='test config file path') 32 | parser.add_argument('checkpoint', help='checkpoint file') 33 | parser.add_argument( 34 | '--ceph', 35 | action='store_true', 36 | help='whether not to evaluate the checkpoint during training') 37 | parser.add_argument( 38 | '--vis', 39 | action='store_true', 40 | help='whether not to evaluate the checkpoint during training') 41 | parser.add_argument( 42 | '--work-dir', 43 | help='the directory to save the file containing evaluation metrics') 44 | parser.add_argument('--out', help='output result file in pickle format') 45 | parser.add_argument( 46 | '--fuse-conv-bn', 47 | action='store_true', 48 | help='Whether to fuse conv and bn, this will slightly increase' 49 | 'the inference speed') 50 | parser.add_argument( 51 | '--format-only', 52 | action='store_true', 53 | help='Format the output results without perform evaluation. It is' 54 | 'useful when you want to format the result to a specific format and ' 55 | 'submit it to the test server') 56 | parser.add_argument( 57 | '--eval', 58 | type=str, 59 | nargs='+', 60 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 61 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') 62 | parser.add_argument('--show', action='store_true', help='show results') 63 | parser.add_argument('--show-dir', 64 | help='directory where painted images will be saved') 65 | parser.add_argument('--show-score-thr', 66 | type=float, 67 | default=0.3, 68 | help='score threshold (default: 0.3)') 69 | parser.add_argument('--gpu-collect', 70 | action='store_true', 71 | help='whether to use gpu to collect results.') 72 | parser.add_argument( 73 | '--tmpdir', 74 | help='tmp directory used for collecting results from multiple ' 75 | 'workers, available when gpu-collect is not specified') 76 | parser.add_argument( 77 | '--cfg-options', 78 | nargs='+', 79 | action=DictAction, 80 | help='override some settings in the used config, the key-value pair ' 81 | 'in xxx=yyy format will be merged into config file. If the value to ' 82 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 83 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 84 | 'Note that the quotation marks are necessary and that no white space ' 85 | 'is allowed.') 86 | parser.add_argument( 87 | '--options', 88 | nargs='+', 89 | action=DictAction, 90 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 91 | 'format will be kwargs for dataset.evaluate() function (deprecate), ' 92 | 'change to --eval-options instead.') 93 | parser.add_argument( 94 | '--eval-options', 95 | nargs='+', 96 | action=DictAction, 97 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 98 | 'format will be kwargs for dataset.evaluate() function') 99 | parser.add_argument('--launcher', 100 | choices=['none', 'pytorch', 'slurm', 'mpi'], 101 | default='none', 102 | help='job launcher') 103 | parser.add_argument('--local_rank', type=int, default=0) 104 | args = parser.parse_args() 105 | if 'LOCAL_RANK' not in os.environ: 106 | os.environ['LOCAL_RANK'] = str(args.local_rank) 107 | 108 | if args.options and args.eval_options: 109 | raise ValueError( 110 | '--options and --eval-options cannot be both ' 111 | 'specified, --options is deprecated in favor of --eval-options') 112 | if args.options: 113 | warnings.warn('--options is deprecated in favor of --eval-options') 114 | args.eval_options = args.options 115 | return args 116 | 117 | 118 | def main(): 119 | args = parse_args() 120 | 121 | assert args.out or args.eval or args.format_only or args.show \ 122 | or args.show_dir, \ 123 | ('Please specify at least one operation (save/eval/format/show the ' 124 | 'results / save the results) with the argument "--out", "--eval"' 125 | ', "--format-only", "--show" or "--show-dir"') 126 | 127 | if args.eval and args.format_only: 128 | raise ValueError('--eval and --format_only cannot be both specified') 129 | 130 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 131 | raise ValueError('The output file must be a pkl file.') 132 | 133 | cfg = Config.fromfile(args.config) 134 | 135 | if args.cfg_options is not None: 136 | for k, v in args.cfg_options.items(): 137 | args.cfg_options[k] = eval(v) 138 | cfg.merge_from_dict(args.cfg_options) 139 | if args.vis: 140 | cfg.model.vis = True 141 | # import modules from string list. 142 | if cfg.get('custom_imports', None): 143 | from mmcv.utils import import_modules_from_strings 144 | import_modules_from_strings(**cfg['custom_imports']) 145 | # set cudnn_benchmark 146 | if cfg.get('cudnn_benchmark', False): 147 | torch.backends.cudnn.benchmark = True 148 | 149 | cfg.model.pretrained = None 150 | if cfg.model.get('neck'): 151 | if isinstance(cfg.model.neck, list): 152 | for neck_cfg in cfg.model.neck: 153 | if neck_cfg.get('rfp_backbone'): 154 | if neck_cfg.rfp_backbone.get('pretrained'): 155 | neck_cfg.rfp_backbone.pretrained = None 156 | elif cfg.model.neck.get('rfp_backbone'): 157 | if cfg.model.neck.rfp_backbone.get('pretrained'): 158 | cfg.model.neck.rfp_backbone.pretrained = None 159 | 160 | # in case the test dataset is concatenated 161 | samples_per_gpu = 1 162 | if isinstance(cfg.data.test, dict): 163 | cfg.data.test.test_mode = True 164 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) 165 | if samples_per_gpu > 1: 166 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 167 | cfg.data.test.pipeline = replace_ImageToTensor( 168 | cfg.data.test.pipeline) 169 | elif isinstance(cfg.data.test, list): 170 | for ds_cfg in cfg.data.test: 171 | ds_cfg.test_mode = True 172 | samples_per_gpu = max( 173 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) 174 | if samples_per_gpu > 1: 175 | for ds_cfg in cfg.data.test: 176 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) 177 | 178 | # init distributed env first, since logger depends on the dist info. 179 | if args.launcher == 'none': 180 | distributed = False 181 | else: 182 | distributed = True 183 | init_dist(args.launcher, **cfg.dist_params) 184 | 185 | rank, _ = get_dist_info() 186 | # allows not to create 187 | if args.work_dir is not None and rank == 0: 188 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 189 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 190 | json_file = osp.join(args.work_dir, f'eval_{timestamp}.json') 191 | 192 | # build the dataloader 193 | dataset = build_dataset(cfg.data.test) 194 | data_loader = build_point_dataloader( 195 | dataset, 196 | samples_per_gpu=samples_per_gpu, 197 | workers_per_gpu=cfg.data.workers_per_gpu, 198 | dist=distributed, 199 | shuffle=False) 200 | 201 | # build the model and load checkpoint 202 | cfg.model.train_cfg = None 203 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) 204 | fp16_cfg = cfg.get('fp16', None) 205 | if fp16_cfg is not None: 206 | wrap_fp16_model(model) 207 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 208 | if args.fuse_conv_bn: 209 | model = fuse_conv_bn(model) 210 | # old versions did not save class info in checkpoints, this walkaround is 211 | # for backward compatibility 212 | if 'CLASSES' in checkpoint.get('meta', {}): 213 | model.CLASSES = checkpoint['meta']['CLASSES'] 214 | else: 215 | model.CLASSES = dataset.CLASSES 216 | 217 | if not distributed: 218 | model = MMDataParallel(model, device_ids=[0]) 219 | outputs = pointdet_single_gpu_test(model, data_loader, args.show, 220 | args.show_dir, args.show_score_thr) 221 | else: 222 | model = MMDistributedDataParallel( 223 | model.cuda(), 224 | device_ids=[torch.cuda.current_device()], 225 | broadcast_buffers=False) 226 | outputs = pointdet_multi_gpu_test(model, data_loader, args.tmpdir, 227 | args.gpu_collect) 228 | 229 | rank, _ = get_dist_info() 230 | if rank == 0: 231 | if args.out: 232 | print(f'\nwriting results to {args.out}') 233 | mmcv.dump(outputs, args.out) 234 | kwargs = {} if args.eval_options is None else args.eval_options 235 | if args.format_only: 236 | outputs = [item[2] for item in outputs] 237 | dataset.format_results(outputs, **kwargs) 238 | if args.eval: 239 | eval_kwargs = cfg.get('evaluation', {}).copy() 240 | # hard-code way to remove EvalHook args 241 | for key in [ 242 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 243 | 'rule' 244 | ]: 245 | eval_kwargs.pop(key, None) 246 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 247 | metric = dataset.evaluate(outputs, **eval_kwargs) 248 | print(metric) 249 | metric_dict = dict(config=args.config, metric=metric) 250 | if args.work_dir is not None and rank == 0: 251 | mmcv.dump(metric_dict, json_file) 252 | 253 | 254 | if __name__ == '__main__': 255 | main() 256 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Group R-CNN 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | # Written by Shilong Zhang 5 | # -------------------------------------------------------- 6 | import argparse 7 | import copy 8 | import os 9 | import os.path as osp 10 | import time 11 | import warnings 12 | 13 | import mmcv 14 | import torch 15 | from mmcv import Config, DictAction 16 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 17 | from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, 18 | Fp16OptimizerHook, OptimizerHook, build_optimizer, 19 | build_runner, get_dist_info, init_dist) 20 | from mmcv.utils import build_from_cfg, get_git_hash 21 | from mmdet import __version__ 22 | from mmdet.apis import set_random_seed 23 | from mmdet.datasets import (build_dataloader, build_dataset, 24 | replace_ImageToTensor) 25 | from mmdet.models import build_detector 26 | from mmdet.utils import collect_env, get_root_logger 27 | 28 | from projects.api.test import PointdetDistEvalHook, PointdetEvalHook 29 | from projects.datasets import * # noqa 30 | from projects.datasets.builder import build_point_dataloader 31 | from projects.models import * # noqa 32 | 33 | 34 | def train_detector(model, 35 | dataset, 36 | cfg, 37 | distributed=False, 38 | validate=False, 39 | timestamp=None, 40 | meta=None): 41 | logger = get_root_logger(log_level=cfg.log_level) 42 | # prepare data loaders 43 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 44 | if 'imgs_per_gpu' in cfg.data: 45 | logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. ' 46 | 'Please use "samples_per_gpu" instead') 47 | if 'samples_per_gpu' in cfg.data: 48 | logger.warning( 49 | f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' 50 | f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' 51 | f'={cfg.data.imgs_per_gpu} is used in this experiments') 52 | else: 53 | logger.warning( 54 | 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' 55 | f'{cfg.data.imgs_per_gpu} in this experiments') 56 | cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu 57 | 58 | data_loaders = [ 59 | build_point_dataloader( 60 | ds, 61 | cfg.data.samples_per_gpu, 62 | cfg.data.workers_per_gpu, 63 | # cfg.gpus will be ignored if distributed 64 | len(cfg.gpu_ids), 65 | dist=distributed, 66 | seed=cfg.seed) for ds in dataset 67 | ] 68 | 69 | # put model on gpus 70 | if distributed: 71 | find_unused_parameters = cfg.get('find_unused_parameters', False) 72 | # Sets the `find_unused_parameters` parameter in 73 | # torch.nn.parallel.DistributedDataParallel 74 | model = MMDistributedDataParallel( 75 | model.cuda(), 76 | device_ids=[torch.cuda.current_device()], 77 | broadcast_buffers=False, 78 | find_unused_parameters=find_unused_parameters) 79 | else: 80 | model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), 81 | device_ids=cfg.gpu_ids) 82 | 83 | # build runner 84 | optimizer = build_optimizer(model, cfg.optimizer) 85 | 86 | if 'runner' not in cfg: 87 | cfg.runner = { 88 | 'type': 'EpochBasedRunner', 89 | 'max_epochs': cfg.total_epochs 90 | } 91 | warnings.warn( 92 | 'config is now expected to have a `runner` section, ' 93 | 'please set `runner` in your config.', UserWarning) 94 | else: 95 | if 'total_epochs' in cfg: 96 | assert cfg.total_epochs == cfg.runner.max_epochs 97 | 98 | runner = build_runner(cfg.runner, 99 | default_args=dict(model=model, 100 | optimizer=optimizer, 101 | work_dir=cfg.work_dir, 102 | logger=logger, 103 | meta=meta)) 104 | 105 | # an ugly workaround to make .log and .log.json filenames the same 106 | runner.timestamp = timestamp 107 | 108 | # fp16 setting 109 | fp16_cfg = cfg.get('fp16', None) 110 | if fp16_cfg is not None: 111 | optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config, 112 | **fp16_cfg, 113 | distributed=distributed) 114 | elif distributed and 'type' not in cfg.optimizer_config: 115 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 116 | else: 117 | optimizer_config = cfg.optimizer_config 118 | 119 | # register hooks 120 | runner.register_training_hooks(cfg.lr_config, optimizer_config, 121 | cfg.checkpoint_config, cfg.log_config, 122 | cfg.get('momentum_config', None)) 123 | if distributed: 124 | if isinstance(runner, EpochBasedRunner): 125 | runner.register_hook(DistSamplerSeedHook()) 126 | 127 | # register eval hooks 128 | if validate: 129 | # Support batch_size > 1 in validation 130 | val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1) 131 | if val_samples_per_gpu > 1: 132 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 133 | cfg.data.val.pipeline = replace_ImageToTensor( 134 | cfg.data.val.pipeline) 135 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 136 | val_dataloader = build_dataloader( 137 | val_dataset, 138 | samples_per_gpu=val_samples_per_gpu, 139 | workers_per_gpu=cfg.data.workers_per_gpu, 140 | dist=distributed, 141 | shuffle=False) 142 | eval_cfg = cfg.get('evaluation', {}) 143 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 144 | eval_hook = PointdetDistEvalHook if distributed else PointdetEvalHook 145 | # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the 146 | # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. 147 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg), 148 | priority='LOW') 149 | 150 | # user-defined hooks 151 | if cfg.get('custom_hooks', None): 152 | custom_hooks = cfg.custom_hooks 153 | assert isinstance(custom_hooks, list), \ 154 | f'custom_hooks expect list type, but got {type(custom_hooks)}' 155 | for hook_cfg in cfg.custom_hooks: 156 | assert isinstance(hook_cfg, dict), \ 157 | 'Each item in custom_hooks expects dict type, but got ' \ 158 | f'{type(hook_cfg)}' 159 | hook_cfg = hook_cfg.copy() 160 | priority = hook_cfg.pop('priority', 'NORMAL') 161 | hook = build_from_cfg(hook_cfg, HOOKS) 162 | runner.register_hook(hook, priority=priority) 163 | 164 | if cfg.resume_from: 165 | runner.resume(cfg.resume_from) 166 | elif cfg.load_from: 167 | runner.load_checkpoint(cfg.load_from) 168 | runner.run(data_loaders, cfg.workflow) 169 | 170 | 171 | def parse_args(): 172 | parser = argparse.ArgumentParser(description='Train a detector') 173 | parser.add_argument('config', help='train config file path') 174 | parser.add_argument('--work-dir', help='the dir to save logs and models') 175 | parser.add_argument('--resume-from', 176 | help='the checkpoint file to resume from') 177 | parser.add_argument( 178 | '--no-validate', 179 | action='store_true', 180 | help='whether not to evaluate the checkpoint during training') 181 | group_gpus = parser.add_mutually_exclusive_group() 182 | group_gpus.add_argument('--gpus', 183 | type=int, 184 | help='number of gpus to use ' 185 | '(only applicable to non-distributed training)') 186 | group_gpus.add_argument('--gpu-ids', 187 | type=int, 188 | nargs='+', 189 | help='ids of gpus to use ' 190 | '(only applicable to non-distributed training)') 191 | parser.add_argument('--seed', type=int, default=None, help='random seed') 192 | parser.add_argument('--port', type=int, default=20001, help='random seed') 193 | parser.add_argument( 194 | '--deterministic', 195 | action='store_true', 196 | help='whether to set deterministic options for CUDNN backend.') 197 | parser.add_argument( 198 | '--options', 199 | nargs='+', 200 | action=DictAction, 201 | help='override some settings in the used config, the key-value pair ' 202 | 'in xxx=yyy format will be merged into config file (deprecate), ' 203 | 'change to --cfg-options instead.') 204 | 205 | parser.add_argument( 206 | '--cfg-options', 207 | nargs='+', 208 | action=DictAction, 209 | help='override some settings in the used config, the key-value pair ' 210 | 'in xxx=yyy format will be merged into config file. If the value to ' 211 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 212 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 213 | 'Note that the quotation marks are necessary and that no white space ' 214 | 'is allowed.') 215 | parser.add_argument('--launcher', 216 | choices=['none', 'pytorch', 'slurm', 'mpi'], 217 | default='none', 218 | help='job launcher') 219 | parser.add_argument('--local_rank', type=int, default=0) 220 | args = parser.parse_args() 221 | if 'LOCAL_RANK' not in os.environ: 222 | os.environ['LOCAL_RANK'] = str(args.local_rank) 223 | 224 | if args.options and args.cfg_options: 225 | raise ValueError( 226 | '--options and --cfg-options cannot be both ' 227 | 'specified, --options is deprecated in favor of --cfg-options') 228 | if args.options: 229 | warnings.warn('--options is deprecated in favor of --cfg-options') 230 | args.cfg_options = args.options 231 | 232 | return args 233 | 234 | 235 | def main(): 236 | args = parse_args() 237 | 238 | cfg = Config.fromfile(args.config) 239 | if args.cfg_options is not None: 240 | for k, v in args.cfg_options.items(): 241 | args.cfg_options[k] = eval(v) 242 | cfg.merge_from_dict(args.cfg_options) 243 | # import modules from string list. 244 | if cfg.get('custom_imports', None): 245 | from mmcv.utils import import_modules_from_strings 246 | import_modules_from_strings(**cfg['custom_imports']) 247 | # set cudnn_benchmark 248 | if cfg.get('cudnn_benchmark', False): 249 | torch.backends.cudnn.benchmark = True 250 | 251 | # work_dir is determined in this priority: CLI > segment in file > filename 252 | if args.work_dir is not None: 253 | # update configs according to CLI args if args.work_dir is not None 254 | cfg.work_dir = args.work_dir 255 | elif cfg.get('work_dir', None) is None: 256 | # use config filename as default work_dir if cfg.work_dir is None 257 | cfg.work_dir = osp.join('./work_dirs', 258 | osp.splitext(osp.basename(args.config))[0]) 259 | if args.resume_from is not None: 260 | cfg.resume_from = args.resume_from 261 | if args.gpu_ids is not None: 262 | cfg.gpu_ids = args.gpu_ids 263 | else: 264 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 265 | 266 | # init distributed env first, since logger depends on the dist info. 267 | if args.launcher == 'none': 268 | distributed = False 269 | else: 270 | distributed = True 271 | init_dist(args.launcher, **cfg.dist_params) 272 | # re-set gpu_ids with distributed training mode 273 | _, world_size = get_dist_info() 274 | cfg.gpu_ids = range(world_size) 275 | 276 | # create work_dir 277 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 278 | # dump config 279 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 280 | # init the logger before other steps 281 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 282 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 283 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 284 | 285 | # init the meta dict to record some important information such as 286 | # environment info and seed, which will be logged 287 | meta = dict() 288 | # log env info 289 | env_info_dict = collect_env() 290 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 291 | dash_line = '-' * 60 + '\n' 292 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 293 | dash_line) 294 | meta['env_info'] = env_info 295 | meta['config'] = cfg.pretty_text 296 | # log some basic info 297 | logger.info(f'Distributed training: {distributed}') 298 | logger.info(f'Config:\n{cfg.pretty_text}') 299 | 300 | # set random seeds 301 | if args.seed is not None: 302 | logger.info(f'Set random seed to {args.seed}, ' 303 | f'deterministic: {args.deterministic}') 304 | set_random_seed(args.seed, deterministic=args.deterministic) 305 | cfg.seed = args.seed 306 | meta['seed'] = args.seed 307 | meta['exp_name'] = osp.basename(args.config) 308 | 309 | model = build_detector(cfg.model, 310 | train_cfg=cfg.get('train_cfg'), 311 | test_cfg=cfg.get('test_cfg')) 312 | model.init_weights() 313 | 314 | datasets = [build_dataset(cfg.data.train)] 315 | if len(cfg.workflow) == 2: 316 | val_dataset = copy.deepcopy(cfg.data.val) 317 | val_dataset.pipeline = cfg.data.train.pipeline 318 | datasets.append(build_dataset(val_dataset)) 319 | if cfg.checkpoint_config is not None: 320 | # save mmdet version, config file content and class names in 321 | # checkpoints as meta data 322 | cfg.checkpoint_config.meta = dict(mmdet_version=__version__ + 323 | get_git_hash()[:7], 324 | CLASSES=datasets[0].CLASSES) 325 | # add an attribute for visualization convenience 326 | model.CLASSES = datasets[0].CLASSES 327 | train_detector(model, 328 | datasets, 329 | cfg, 330 | distributed=distributed, 331 | validate=(not args.no_validate), 332 | timestamp=timestamp, 333 | meta=meta) 334 | 335 | 336 | if __name__ == '__main__': 337 | main() 338 | --------------------------------------------------------------------------------