├── .gitignore ├── LICENSE ├── README.md ├── adet ├── __init__.py ├── checkpoint │ ├── __init__.py │ └── adet_checkpoint.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults.py ├── data │ ├── __init__.py │ ├── augmentation.py │ ├── builtin.py │ ├── dataset_mapper.py │ ├── datasets │ │ └── text.py │ └── detection_utils.py ├── evaluation │ ├── __init__.py │ ├── lexicon_procesor.py │ ├── rrc_evaluation_funcs.py │ ├── text_eval_script.py │ └── text_evaluation.py ├── layers │ ├── __init__.py │ ├── csrc │ │ ├── DeformAttn │ │ │ ├── ms_deform_attn.h │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ ├── ms_deform_attn_cpu.h │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── cuda_version.cu │ │ └── vision.cpp │ ├── deformable_transformer.py │ ├── ms_deform_attn.py │ └── pos_encoding.py ├── modeling │ ├── __init__.py │ ├── testr │ │ ├── __init__.py │ │ ├── losses.py │ │ ├── matcher.py │ │ └── models.py │ └── transformer_detector.py └── utils │ ├── __init__.py │ ├── comm.py │ ├── misc.py │ └── visualizer.py ├── configs └── TESTR │ ├── Base-TESTR.yaml │ ├── CTW1500 │ ├── Base-CTW1500-Polygon.yaml │ ├── Base-CTW1500.yaml │ ├── TESTR_R_50.yaml │ └── TESTR_R_50_Polygon.yaml │ ├── ICDAR15 │ ├── Base-ICDAR15-Polygon.yaml │ └── TESTR_R_50_Polygon.yaml │ ├── Pretrain │ ├── Base-Pretrain-Polygon.yaml │ ├── Base-Pretrain.yaml │ ├── TESTR_R_50.yaml │ └── TESTR_R_50_Polygon.yaml │ └── TotalText │ ├── Base-TotalText-Polygon.yaml │ ├── Base-TotalText.yaml │ ├── TESTR_R_50.yaml │ └── TESTR_R_50_Polygon.yaml ├── demo ├── demo.py └── predictor.py ├── figures └── arch.svg ├── setup.py └── tools └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | 7 | *.jpg 8 | *.png 9 | *.txt 10 | 11 | # compilation and distribution 12 | __pycache__ 13 | _ext 14 | *.pyc 15 | *.so 16 | AdelaiDet.egg-info/ 17 | build/ 18 | dist/ 19 | 20 | # pytorch/python/numpy formats 21 | *.pth 22 | *.pkl 23 | *.npy 24 | 25 | # ipython/jupyter notebooks 26 | *.ipynb 27 | **/.ipynb_checkpoints/ 28 | 29 | # Editor temporaries 30 | *.swn 31 | *.swo 32 | *.swp 33 | *~ 34 | 35 | # Pycharm editor settings 36 | .idea 37 | .vscode 38 | .python-version 39 | 40 | # project dirs 41 | /datasets/coco 42 | /datasets/lvis 43 | /datasets/pic 44 | /datasets/ytvos 45 | /models 46 | /demo_outputs 47 | /example_inputs 48 | /debug 49 | /weights 50 | /export 51 | eval.sh 52 | 53 | demo/performance.py 54 | demo/demo2.py 55 | train.sh 56 | benchmark.sh 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TESTR: Text Spotting Transformers 2 | 3 | This repository is the official implementation of the following paper: 4 | 5 | [Text Spotting Transformers](https://openaccess.thecvf.com/content/CVPR2022/html/Zhang_Text_Spotting_Transformers_CVPR_2022_paper.html) 6 | 7 | [Xiang Zhang](https://xzhang.dev), Yongwen Su, [Subarna Tripathi](https://subarnatripathi.github.io), and [Zhuowen Tu](https://pages.ucsd.edu/~ztu/), CVPR 2022 8 | 9 | 10 | 11 | ## Getting Started 12 | We use the following environment in our experiments. It is recommended to install the dependencies via Anaconda. 13 | 14 | + CUDA 11.3 15 | + Python 3.8 16 | + PyTorch 1.10.1 17 | + Official Pre-Built Detectron2 18 | 19 | #### Installation 20 | 21 | Please refer to the **Installation** section of AdelaiDet: [README.md](https://github.com/aim-uofa/AdelaiDet/blob/master/README.md). 22 | 23 | If you have not installed Detectron2, follow the official guide: [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md). 24 | 25 | After that, build this repository with 26 | 27 | ```bash 28 | python setup.py build develop 29 | ``` 30 | 31 | #### Preparing Datasets 32 | 33 | Please download TotalText, CTW1500, MLT, and CurvedSynText150k according to the guide provided by AdelaiDet: [README.md](https://github.com/aim-uofa/AdelaiDet/blob/master/datasets/README.md). 34 | 35 | The ICDAR2015 dataset can be downloaded via this [link](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/EWgEM5BSRjBEua4B_qLrGR0BaombUL8K3d23ldXOb7wUNA?e=7VzH34). 36 | 37 | Extract all the datasets and make sure you organize them as follows: 38 | 39 | ``` 40 | - datasets 41 | | - CTW1500 42 | | | - annotations 43 | | | - ctwtest_text_image 44 | | | - ctwtrain_text_image 45 | | - totaltext (or icdar2015) 46 | | | - test_images 47 | | | - train_images 48 | | | - test.json 49 | | | - train.json 50 | | - mlt2017 (or syntext1, syntext2) 51 | | - annotations 52 | | - images 53 | ``` 54 | 55 | After that, download the [polygonal annotations](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/ES4aqkvamlJAgiPNFJuYkX4BLo-5cDx9TD_6pnMJnVhXpw?e=tu9D8t), along with the [evaluation files](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/Ea5oF7VFoe5NngUoPmLTerQBMdiVUhHcx2pPu3Q5p3hZvg?e=2NJNWh), and extract them under the `datasets` folder. 56 | 57 | #### Visualization Demo 58 | 59 | You can visualize the predictions of the network using the following command: 60 | 61 | ```bash 62 | python demo/demo.py --config-file <config.yaml> --input <image_folder> --output <output_folder> --opts MODEL.WEIGHTS <model_checkpoint.pth> MODEL.TRANSFORMER.INFERENCE_TH_TEST 0.3 63 | ``` 64 | 65 | You may want to adjust `INFERENCE_TH_TEST` to filter out predictions with lower scores. 66 | 67 | #### Training 68 | 69 | You can train from scratch or finetune the model by putting pretrained weights in the `weights` folder.
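For example, a finetuning run for the TotalText polygon model, starting from a pretraining checkpoint placed under `weights/`, could look like the sketch below (the config path exists in this repository; the checkpoint filename is only an illustrative placeholder):

```bash
# Finetune TESTR-Polygon on TotalText from a pretraining checkpoint.
# The checkpoint filename under weights/ is a placeholder, not the actual release name.
python tools/train_net.py \
    --config-file configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml \
    --num-gpus 8 \
    MODEL.WEIGHTS weights/pretrain_testr_r50_polygon.pth
```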
70 | 71 | The general training command is: 72 | 73 | ```bash 74 | python tools/train_net.py --config-file <config.yaml> --num-gpus 8 75 | ``` 76 | 77 | All configuration files can be found in `configs/TESTR`, excluding those named `Base-xxxx.yaml`. 78 | 79 | `TESTR_R_50.yaml` is the config for TESTR-Bezier, while `TESTR_R_50_Polygon.yaml` is for TESTR-Polygon. 80 | 81 | #### Evaluation 82 | 83 | ```bash 84 | python tools/train_net.py --config-file <config.yaml> --eval-only MODEL.WEIGHTS <model_checkpoint.pth> 85 | ``` 86 | 87 | ## Pretrained Models 88 | 89 |
| Dataset | Annotation Type | Lexicon | Det-P | Det-R | Det-F | E2E-P | E2E-R | E2E-F | Link |
| --------- | --------------- | ------- | ----- | ----- | ----- | ----- | ----- | ----- | -------- |
| Pretrain | Bezier | None | 88.87 | 76.47 | 82.20 | 63.58 | 56.92 | 60.06 | OneDrive |
| Pretrain | Polygonal | None | 88.18 | 77.51 | 82.50 | 66.19 | 61.14 | 63.57 | OneDrive |
| TotalText | Bezier | None | 92.83 | 83.65 | 88.00 | 74.26 | 69.05 | 71.56 | OneDrive |
| TotalText | Bezier | Full | - | - | - | 86.42 | 80.35 | 83.28 | |
| TotalText | Polygonal | None | 93.36 | 81.35 | 86.94 | 76.85 | 69.98 | 73.25 | OneDrive |
| TotalText | Polygonal | Full | - | - | - | 88.00 | 80.13 | 83.88 | |
| CTW1500 | Bezier | None | 89.71 | 83.07 | 86.27 | 55.44 | 51.34 | 53.31 | OneDrive |
| CTW1500 | Bezier | Full | - | - | - | 83.05 | 76.90 | 79.85 | |
| CTW1500 | Polygonal | None | 92.04 | 82.63 | 87.08 | 59.14 | 53.09 | 55.95 | OneDrive |
| CTW1500 | Polygonal | Full | - | - | - | 86.16 | 77.34 | 81.51 | |
| ICDAR15 | Polygonal | None | 90.31 | 89.70 | 90.00 | 65.49 | 65.05 | 65.27 | OneDrive |
| ICDAR15 | Polygonal | Strong | - | - | - | 87.11 | 83.29 | 85.16 | |
| ICDAR15 | Polygonal | Weak | - | - | - | 80.36 | 78.38 | 79.36 | |
| ICDAR15 | Polygonal | Generic | - | - | - | 73.82 | 73.33 | 73.57 | |
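As a concrete instance of the Evaluation command above, one of the checkpoints from the table can be downloaded and passed through `MODEL.WEIGHTS`; the local filename below is illustrative:

```bash
# Evaluate the TotalText polygon model with a downloaded checkpoint.
# The path under weights/ is a placeholder for wherever the file was saved.
python tools/train_net.py \
    --config-file configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml \
    --eval-only \
    MODEL.WEIGHTS weights/totaltext_testr_r50_polygon.pth
```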
251 | 252 | The `Lite` models only use the image feature from the last stage of ResNet. 253 | 254 | | Method | Annotation Type | Lexicon | Det-P | Det-R | Det-F | E2E-P | E2E-R | E2E-F | Link | 255 | | ---------------- | --------------- | ------- | ------ | ------ | ------ | ------ | ------ | ------ | ------------------------------------------------------------ | 256 | | Pretrain (Lite) | Polygonal | None | 90.28 | 72.58 | 80.47 | 59.49 | 50.22 | 54.46 | [OneDrive](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/EcG-WKHN7dlNnzUJ5g301goBNknB-_IyADfVoW9q8efIIA?e=8dHW3K) | 257 | | TotalText (Lite) | Polygonal | None | 92.16 | 79.09 | 85.12 | 66.42 | 59.06 | 62.52 | [OneDrive](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/ETL5VCes0eJBuktGqHqGu_wBltwbngIhqmqePIWfaWgGxw?e=hDkana) | 258 | 259 | ## Citation 260 | ``` 261 | @InProceedings{Zhang_2022_CVPR, 262 | author = {Zhang, Xiang and Su, Yongwen and Tripathi, Subarna and Tu, Zhuowen}, 263 | title = {Text Spotting Transformers}, 264 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 265 | month = {June}, 266 | year = {2022}, 267 | pages = {9519-9528} 268 | } 269 | ``` 270 | 271 | ## License 272 | This repository is released under the Apache License 2.0. License can be found in [LICENSE](LICENSE) file. 273 | 274 | 275 | ## Acknowledgement 276 | 277 | Thanks to [AdelaiDet](https://github.com/aim-uofa/AdelaiDet) for a standardized training and inference framework, and [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR) for the implementation of multi-scale deformable cross-attention. 278 | -------------------------------------------------------------------------------- /adet/__init__.py: -------------------------------------------------------------------------------- 1 | from adet import modeling 2 | 3 | __version__ = "0.1.1" 4 | -------------------------------------------------------------------------------- /adet/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .adet_checkpoint import AdetCheckpointer 2 | 3 | __all__ = ["AdetCheckpointer"] 4 | -------------------------------------------------------------------------------- /adet/checkpoint/adet_checkpoint.py: -------------------------------------------------------------------------------- 1 | import pickle, os 2 | from fvcore.common.file_io import PathManager 3 | from detectron2.checkpoint import DetectionCheckpointer 4 | 5 | 6 | class AdetCheckpointer(DetectionCheckpointer): 7 | """ 8 | Same as :class:`DetectronCheckpointer`, but is able to convert models 9 | in AdelaiDet, such as LPF backbone. 
10 | """ 11 | def _load_file(self, filename): 12 | if filename.endswith(".pkl"): 13 | with PathManager.open(filename, "rb") as f: 14 | data = pickle.load(f, encoding="latin1") 15 | if "model" in data and "__author__" in data: 16 | # file is in Detectron2 model zoo format 17 | self.logger.info("Reading a file from '{}'".format(data["__author__"])) 18 | return data 19 | else: 20 | # assume file is from Caffe2 / Detectron1 model zoo 21 | if "blobs" in data: 22 | # Detection models have "blobs", but ImageNet models don't 23 | data = data["blobs"] 24 | data = {k: v for k, v in data.items() if not k.endswith("_momentum")} 25 | if "weight_order" in data: 26 | del data["weight_order"] 27 | return {"model": data, "__author__": "Caffe2", "matching_heuristics": True} 28 | 29 | loaded = super()._load_file(filename) # load native pth checkpoint 30 | if "model" not in loaded: 31 | loaded = {"model": loaded} 32 | 33 | basename = os.path.basename(filename).lower() 34 | if "lpf" in basename or "dla" in basename: 35 | loaded["matching_heuristics"] = True 36 | return loaded 37 | -------------------------------------------------------------------------------- /adet/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import get_cfg 2 | 3 | __all__ = [ 4 | "get_cfg", 5 | ] 6 | -------------------------------------------------------------------------------- /adet/config/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode 2 | 3 | 4 | def get_cfg() -> CfgNode: 5 | """ 6 | Get a copy of the default config. 7 | 8 | Returns: 9 | a detectron2 CfgNode instance. 10 | """ 11 | from .defaults import _C 12 | 13 | return _C.clone() 14 | -------------------------------------------------------------------------------- /adet/config/defaults.py: -------------------------------------------------------------------------------- 1 | from detectron2.config.defaults import _C 2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # Additional Configs 7 | # ---------------------------------------------------------------------------- # 8 | _C.MODEL.MOBILENET = False 9 | _C.MODEL.BACKBONE.ANTI_ALIAS = False 10 | _C.MODEL.RESNETS.DEFORM_INTERVAL = 1 11 | _C.INPUT.HFLIP_TRAIN = True 12 | _C.INPUT.CROP.CROP_INSTANCE = True 13 | 14 | # ---------------------------------------------------------------------------- # 15 | # FCOS Head 16 | # ---------------------------------------------------------------------------- # 17 | _C.MODEL.FCOS = CN() 18 | 19 | # This is the number of foreground classes. 
20 | _C.MODEL.FCOS.NUM_CLASSES = 80 21 | _C.MODEL.FCOS.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] 22 | _C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128] 23 | _C.MODEL.FCOS.PRIOR_PROB = 0.01 24 | _C.MODEL.FCOS.INFERENCE_TH_TRAIN = 0.05 25 | _C.MODEL.FCOS.INFERENCE_TH_TEST = 0.05 26 | _C.MODEL.FCOS.NMS_TH = 0.6 27 | _C.MODEL.FCOS.PRE_NMS_TOPK_TRAIN = 1000 28 | _C.MODEL.FCOS.PRE_NMS_TOPK_TEST = 1000 29 | _C.MODEL.FCOS.POST_NMS_TOPK_TRAIN = 100 30 | _C.MODEL.FCOS.POST_NMS_TOPK_TEST = 100 31 | _C.MODEL.FCOS.TOP_LEVELS = 2 32 | _C.MODEL.FCOS.NORM = "GN" # Support GN or none 33 | _C.MODEL.FCOS.USE_SCALE = True 34 | 35 | # The options for the quality of box prediction 36 | # It can be "ctrness" (as described in FCOS paper) or "iou" 37 | # Using "iou" here generally has ~0.4 better AP on COCO 38 | # Note that for compatibility, we still use the term "ctrness" in the code 39 | _C.MODEL.FCOS.BOX_QUALITY = "ctrness" 40 | 41 | # Multiply centerness before threshold 42 | # This will affect the final performance by about 0.05 AP but save some time 43 | _C.MODEL.FCOS.THRESH_WITH_CTR = False 44 | 45 | # Focal loss parameters 46 | _C.MODEL.FCOS.LOSS_ALPHA = 0.25 47 | _C.MODEL.FCOS.LOSS_GAMMA = 2.0 48 | 49 | # The normalizer of the classification loss 50 | # The normalizer can be "fg" (normalized by the number of the foreground samples), 51 | # "moving_fg" (normalized by the MOVING number of the foreground samples), 52 | # or "all" (normalized by the number of all samples) 53 | _C.MODEL.FCOS.LOSS_NORMALIZER_CLS = "fg" 54 | _C.MODEL.FCOS.LOSS_WEIGHT_CLS = 1.0 55 | 56 | _C.MODEL.FCOS.SIZES_OF_INTEREST = [64, 128, 256, 512] 57 | _C.MODEL.FCOS.USE_RELU = True 58 | _C.MODEL.FCOS.USE_DEFORMABLE = False 59 | 60 | # the number of convolutions used in the cls and bbox tower 61 | _C.MODEL.FCOS.NUM_CLS_CONVS = 4 62 | _C.MODEL.FCOS.NUM_BOX_CONVS = 4 63 | _C.MODEL.FCOS.NUM_SHARE_CONVS = 0 64 | _C.MODEL.FCOS.CENTER_SAMPLE = True 65 | _C.MODEL.FCOS.POS_RADIUS = 1.5 66 | _C.MODEL.FCOS.LOC_LOSS_TYPE = 'giou' 67 | _C.MODEL.FCOS.YIELD_PROPOSAL = False 68 | _C.MODEL.FCOS.YIELD_BOX_FEATURES = False 69 | 70 | # ---------------------------------------------------------------------------- # 71 | # VoVNet backbone 72 | # ---------------------------------------------------------------------------- # 73 | _C.MODEL.VOVNET = CN() 74 | _C.MODEL.VOVNET.CONV_BODY = "V-39-eSE" 75 | _C.MODEL.VOVNET.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] 76 | 77 | # Options: FrozenBN, GN, "SyncBN", "BN" 78 | _C.MODEL.VOVNET.NORM = "FrozenBN" 79 | _C.MODEL.VOVNET.OUT_CHANNELS = 256 80 | _C.MODEL.VOVNET.BACKBONE_OUT_CHANNELS = 256 81 | 82 | # ---------------------------------------------------------------------------- # 83 | # DLA backbone 84 | # ---------------------------------------------------------------------------- # 85 | 86 | _C.MODEL.DLA = CN() 87 | _C.MODEL.DLA.CONV_BODY = "DLA34" 88 | _C.MODEL.DLA.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"] 89 | 90 | # Options: FrozenBN, GN, "SyncBN", "BN" 91 | _C.MODEL.DLA.NORM = "FrozenBN" 92 | 93 | # ---------------------------------------------------------------------------- # 94 | # BAText Options 95 | # ---------------------------------------------------------------------------- # 96 | _C.MODEL.BATEXT = CN() 97 | _C.MODEL.BATEXT.VOC_SIZE = 96 98 | _C.MODEL.BATEXT.NUM_CHARS = 25 99 | _C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32) 100 | _C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"] 101 | _C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625) 102 | _C.MODEL.BATEXT.SAMPLING_RATIO = 1 
103 | _C.MODEL.BATEXT.CONV_DIM = 256 104 | _C.MODEL.BATEXT.NUM_CONV = 2 105 | _C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc" 106 | _C.MODEL.BATEXT.RECOGNIZER = "attn" 107 | _C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8) 108 | _C.MODEL.BATEXT.USE_COORDCONV = False 109 | _C.MODEL.BATEXT.USE_AET = False 110 | _C.MODEL.BATEXT.CUSTOM_DICT = "" # Path to the class file. 111 | 112 | # ---------------------------------------------------------------------------- # 113 | # BlendMask Options 114 | # ---------------------------------------------------------------------------- # 115 | _C.MODEL.BLENDMASK = CN() 116 | _C.MODEL.BLENDMASK.ATTN_SIZE = 14 117 | _C.MODEL.BLENDMASK.TOP_INTERP = "bilinear" 118 | _C.MODEL.BLENDMASK.BOTTOM_RESOLUTION = 56 119 | _C.MODEL.BLENDMASK.POOLER_TYPE = "ROIAlignV2" 120 | _C.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO = 1 121 | _C.MODEL.BLENDMASK.POOLER_SCALES = (0.25,) 122 | _C.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT = 1.0 123 | _C.MODEL.BLENDMASK.VISUALIZE = False 124 | 125 | # ---------------------------------------------------------------------------- # 126 | # Basis Module Options 127 | # ---------------------------------------------------------------------------- # 128 | _C.MODEL.BASIS_MODULE = CN() 129 | _C.MODEL.BASIS_MODULE.NAME = "ProtoNet" 130 | _C.MODEL.BASIS_MODULE.NUM_BASES = 4 131 | _C.MODEL.BASIS_MODULE.LOSS_ON = False 132 | _C.MODEL.BASIS_MODULE.ANN_SET = "coco" 133 | _C.MODEL.BASIS_MODULE.CONVS_DIM = 128 134 | _C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"] 135 | _C.MODEL.BASIS_MODULE.NORM = "SyncBN" 136 | _C.MODEL.BASIS_MODULE.NUM_CONVS = 3 137 | _C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8 138 | _C.MODEL.BASIS_MODULE.NUM_CLASSES = 80 139 | _C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3 140 | 141 | # ---------------------------------------------------------------------------- # 142 | # MEInst Head 143 | # ---------------------------------------------------------------------------- # 144 | _C.MODEL.MEInst = CN() 145 | 146 | # This is the number of foreground classes. 147 | _C.MODEL.MEInst.NUM_CLASSES = 80 148 | _C.MODEL.MEInst.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] 149 | _C.MODEL.MEInst.FPN_STRIDES = [8, 16, 32, 64, 128] 150 | _C.MODEL.MEInst.PRIOR_PROB = 0.01 151 | _C.MODEL.MEInst.INFERENCE_TH_TRAIN = 0.05 152 | _C.MODEL.MEInst.INFERENCE_TH_TEST = 0.05 153 | _C.MODEL.MEInst.NMS_TH = 0.6 154 | _C.MODEL.MEInst.PRE_NMS_TOPK_TRAIN = 1000 155 | _C.MODEL.MEInst.PRE_NMS_TOPK_TEST = 1000 156 | _C.MODEL.MEInst.POST_NMS_TOPK_TRAIN = 100 157 | _C.MODEL.MEInst.POST_NMS_TOPK_TEST = 100 158 | _C.MODEL.MEInst.TOP_LEVELS = 2 159 | _C.MODEL.MEInst.NORM = "GN" # Support GN or none 160 | _C.MODEL.MEInst.USE_SCALE = True 161 | 162 | # Multiply centerness before threshold 163 | # This will affect the final performance by about 0.05 AP but save some time 164 | _C.MODEL.MEInst.THRESH_WITH_CTR = False 165 | 166 | # Focal loss parameters 167 | _C.MODEL.MEInst.LOSS_ALPHA = 0.25 168 | _C.MODEL.MEInst.LOSS_GAMMA = 2.0 169 | _C.MODEL.MEInst.SIZES_OF_INTEREST = [64, 128, 256, 512] 170 | _C.MODEL.MEInst.USE_RELU = True 171 | _C.MODEL.MEInst.USE_DEFORMABLE = False 172 | _C.MODEL.MEInst.LAST_DEFORMABLE = False 173 | _C.MODEL.MEInst.TYPE_DEFORMABLE = "DCNv1" # or DCNv2. 
174 | 175 | # the number of convolutions used in the cls and bbox tower 176 | _C.MODEL.MEInst.NUM_CLS_CONVS = 4 177 | _C.MODEL.MEInst.NUM_BOX_CONVS = 4 178 | _C.MODEL.MEInst.NUM_SHARE_CONVS = 0 179 | _C.MODEL.MEInst.CENTER_SAMPLE = True 180 | _C.MODEL.MEInst.POS_RADIUS = 1.5 181 | _C.MODEL.MEInst.LOC_LOSS_TYPE = 'giou' 182 | 183 | # ---------------------------------------------------------------------------- # 184 | # Mask Encoding 185 | # ---------------------------------------------------------------------------- # 186 | # Whether to use mask branch. 187 | _C.MODEL.MEInst.MASK_ON = True 188 | # IOU overlap ratios [IOU_THRESHOLD] 189 | # Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD) 190 | # Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD) 191 | _C.MODEL.MEInst.IOU_THRESHOLDS = [0.5] 192 | _C.MODEL.MEInst.IOU_LABELS = [0, 1] 193 | # Whether to use class_agnostic or class_specific. 194 | _C.MODEL.MEInst.AGNOSTIC = True 195 | # Some operations in mask encoding. 196 | _C.MODEL.MEInst.WHITEN = True 197 | _C.MODEL.MEInst.SIGMOID = True 198 | 199 | # The number of convolutions used in the mask tower. 200 | _C.MODEL.MEInst.NUM_MASK_CONVS = 4 201 | 202 | # The dim of mask before/after mask encoding. 203 | _C.MODEL.MEInst.DIM_MASK = 60 204 | _C.MODEL.MEInst.MASK_SIZE = 28 205 | # The default path for parameters of mask encoding. 206 | _C.MODEL.MEInst.PATH_COMPONENTS = "datasets/coco/components/" \ 207 | "coco_2017_train_class_agnosticTrue_whitenTrue_sigmoidTrue_60.npz" 208 | # An indicator for encoding parameters loading during training. 209 | _C.MODEL.MEInst.FLAG_PARAMETERS = False 210 | # The loss for mask branch, can be mse now. 211 | _C.MODEL.MEInst.MASK_LOSS_TYPE = "mse" 212 | 213 | # Whether to use gcn in mask prediction. 214 | # Large Kernel Matters -- https://arxiv.org/abs/1703.02719 215 | _C.MODEL.MEInst.USE_GCN_IN_MASK = False 216 | _C.MODEL.MEInst.GCN_KERNEL_SIZE = 9 217 | # Whether to compute loss on original mask (binary mask). 
218 | _C.MODEL.MEInst.LOSS_ON_MASK = False 219 | 220 | # ---------------------------------------------------------------------------- # 221 | # CondInst Options 222 | # ---------------------------------------------------------------------------- # 223 | _C.MODEL.CONDINST = CN() 224 | 225 | # the downsampling ratio of the final instance masks to the input image 226 | _C.MODEL.CONDINST.MASK_OUT_STRIDE = 4 227 | _C.MODEL.CONDINST.BOTTOM_PIXELS_REMOVED = -1 228 | 229 | # if not -1, we only compute the mask loss for MAX_PROPOSALS random proposals PER GPU 230 | _C.MODEL.CONDINST.MAX_PROPOSALS = -1 231 | # if not -1, we only compute the mask loss for top `TOPK_PROPOSALS_PER_IM` proposals 232 | # PER IMAGE in terms of their detection scores 233 | _C.MODEL.CONDINST.TOPK_PROPOSALS_PER_IM = -1 234 | 235 | _C.MODEL.CONDINST.MASK_HEAD = CN() 236 | _C.MODEL.CONDINST.MASK_HEAD.CHANNELS = 8 237 | _C.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS = 3 238 | _C.MODEL.CONDINST.MASK_HEAD.USE_FP16 = False 239 | _C.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS = False 240 | 241 | _C.MODEL.CONDINST.MASK_BRANCH = CN() 242 | _C.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS = 8 243 | _C.MODEL.CONDINST.MASK_BRANCH.IN_FEATURES = ["p3", "p4", "p5"] 244 | _C.MODEL.CONDINST.MASK_BRANCH.CHANNELS = 128 245 | _C.MODEL.CONDINST.MASK_BRANCH.NORM = "BN" 246 | _C.MODEL.CONDINST.MASK_BRANCH.NUM_CONVS = 4 247 | _C.MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON = False 248 | 249 | # The options for BoxInst, which can train the instance segmentation model with box annotations only 250 | # Please refer to the paper https://arxiv.org/abs/2012.02310 251 | _C.MODEL.BOXINST = CN() 252 | # Whether to enable BoxInst 253 | _C.MODEL.BOXINST.ENABLED = False 254 | _C.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED = 10 255 | 256 | _C.MODEL.BOXINST.PAIRWISE = CN() 257 | _C.MODEL.BOXINST.PAIRWISE.SIZE = 3 258 | _C.MODEL.BOXINST.PAIRWISE.DILATION = 2 259 | _C.MODEL.BOXINST.PAIRWISE.WARMUP_ITERS = 10000 260 | _C.MODEL.BOXINST.PAIRWISE.COLOR_THRESH = 0.3 261 | 262 | # ---------------------------------------------------------------------------- # 263 | # TOP Module Options 264 | # ---------------------------------------------------------------------------- # 265 | _C.MODEL.TOP_MODULE = CN() 266 | _C.MODEL.TOP_MODULE.NAME = "conv" 267 | _C.MODEL.TOP_MODULE.DIM = 16 268 | 269 | # ---------------------------------------------------------------------------- # 270 | # BiFPN options 271 | # ---------------------------------------------------------------------------- # 272 | 273 | _C.MODEL.BiFPN = CN() 274 | # Names of the input feature maps to be used by BiFPN 275 | # They must have contiguous power of 2 strides 276 | # e.g., ["res2", "res3", "res4", "res5"] 277 | _C.MODEL.BiFPN.IN_FEATURES = ["res2", "res3", "res4", "res5"] 278 | _C.MODEL.BiFPN.OUT_CHANNELS = 160 279 | _C.MODEL.BiFPN.NUM_REPEATS = 6 280 | 281 | # Options: "" (no norm), "GN" 282 | _C.MODEL.BiFPN.NORM = "" 283 | 284 | # ---------------------------------------------------------------------------- # 285 | # SOLOv2 Options 286 | # ---------------------------------------------------------------------------- # 287 | _C.MODEL.SOLOV2 = CN() 288 | 289 | # Instance hyper-parameters 290 | _C.MODEL.SOLOV2.INSTANCE_IN_FEATURES = ["p2", "p3", "p4", "p5", "p6"] 291 | _C.MODEL.SOLOV2.FPN_INSTANCE_STRIDES = [8, 8, 16, 32, 32] 292 | _C.MODEL.SOLOV2.FPN_SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)) 293 | _C.MODEL.SOLOV2.SIGMA = 0.2 294 | # Channel size for the instance head. 
295 | _C.MODEL.SOLOV2.INSTANCE_IN_CHANNELS = 256 296 | _C.MODEL.SOLOV2.INSTANCE_CHANNELS = 512 297 | # Convolutions to use in the instance head. 298 | _C.MODEL.SOLOV2.NUM_INSTANCE_CONVS = 4 299 | _C.MODEL.SOLOV2.USE_DCN_IN_INSTANCE = False 300 | _C.MODEL.SOLOV2.TYPE_DCN = 'DCN' 301 | _C.MODEL.SOLOV2.NUM_GRIDS = [40, 36, 24, 16, 12] 302 | # Number of foreground classes. 303 | _C.MODEL.SOLOV2.NUM_CLASSES = 80 304 | _C.MODEL.SOLOV2.NUM_KERNELS = 256 305 | _C.MODEL.SOLOV2.NORM = "GN" 306 | _C.MODEL.SOLOV2.USE_COORD_CONV = True 307 | _C.MODEL.SOLOV2.PRIOR_PROB = 0.01 308 | 309 | # Mask hyper-parameters. 310 | # Channel size for the mask tower. 311 | _C.MODEL.SOLOV2.MASK_IN_FEATURES = ["p2", "p3", "p4", "p5"] 312 | _C.MODEL.SOLOV2.MASK_IN_CHANNELS = 256 313 | _C.MODEL.SOLOV2.MASK_CHANNELS = 128 314 | _C.MODEL.SOLOV2.NUM_MASKS = 256 315 | 316 | # Test cfg. 317 | _C.MODEL.SOLOV2.NMS_PRE = 500 318 | _C.MODEL.SOLOV2.SCORE_THR = 0.1 319 | _C.MODEL.SOLOV2.UPDATE_THR = 0.05 320 | _C.MODEL.SOLOV2.MASK_THR = 0.5 321 | _C.MODEL.SOLOV2.MAX_PER_IMG = 100 322 | # NMS type: matrix OR mask. 323 | _C.MODEL.SOLOV2.NMS_TYPE = "matrix" 324 | # Matrix NMS kernel type: gaussian OR linear. 325 | _C.MODEL.SOLOV2.NMS_KERNEL = "gaussian" 326 | _C.MODEL.SOLOV2.NMS_SIGMA = 2 327 | 328 | # Loss cfg. 329 | _C.MODEL.SOLOV2.LOSS = CN() 330 | _C.MODEL.SOLOV2.LOSS.FOCAL_USE_SIGMOID = True 331 | _C.MODEL.SOLOV2.LOSS.FOCAL_ALPHA = 0.25 332 | _C.MODEL.SOLOV2.LOSS.FOCAL_GAMMA = 2.0 333 | _C.MODEL.SOLOV2.LOSS.FOCAL_WEIGHT = 1.0 334 | _C.MODEL.SOLOV2.LOSS.DICE_WEIGHT = 3.0 335 | 336 | 337 | # ---------------------------------------------------------------------------- # 338 | # (Deformable) Transformer Options 339 | # ---------------------------------------------------------------------------- # 340 | _C.MODEL.TRANSFORMER = CN() 341 | _C.MODEL.TRANSFORMER.USE_POLYGON = False 342 | _C.MODEL.TRANSFORMER.ENABLED = False 343 | _C.MODEL.TRANSFORMER.INFERENCE_TH_TEST = 0.45 344 | _C.MODEL.TRANSFORMER.VOC_SIZE = 96 345 | _C.MODEL.TRANSFORMER.NUM_CHARS = 25 346 | _C.MODEL.TRANSFORMER.AUX_LOSS = True 347 | _C.MODEL.TRANSFORMER.ENC_LAYERS = 6 348 | _C.MODEL.TRANSFORMER.DEC_LAYERS = 6 349 | _C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 1024 350 | _C.MODEL.TRANSFORMER.HIDDEN_DIM = 256 351 | _C.MODEL.TRANSFORMER.DROPOUT = 0.1 352 | _C.MODEL.TRANSFORMER.NHEADS = 8 353 | _C.MODEL.TRANSFORMER.NUM_QUERIES = 100 354 | _C.MODEL.TRANSFORMER.ENC_N_POINTS = 4 355 | _C.MODEL.TRANSFORMER.DEC_N_POINTS = 4 356 | _C.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE = 6.283185307179586 # 2 PI 357 | _C.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS = 4 358 | _C.MODEL.TRANSFORMER.NUM_CTRL_POINTS = 8 359 | 360 | _C.MODEL.TRANSFORMER.LOSS = CN() 361 | _C.MODEL.TRANSFORMER.LOSS.AUX_LOSS = True 362 | _C.MODEL.TRANSFORMER.LOSS.POINT_CLASS_WEIGHT = 2.0 363 | _C.MODEL.TRANSFORMER.LOSS.POINT_COORD_WEIGHT = 5.0 364 | _C.MODEL.TRANSFORMER.LOSS.POINT_TEXT_WEIGHT = 2.0 365 | _C.MODEL.TRANSFORMER.LOSS.BOX_CLASS_WEIGHT = 2.0 366 | _C.MODEL.TRANSFORMER.LOSS.BOX_COORD_WEIGHT = 5.0 367 | _C.MODEL.TRANSFORMER.LOSS.BOX_GIOU_WEIGHT = 2.0 368 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_ALPHA = 0.25 369 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_GAMMA = 2.0 370 | 371 | 372 | _C.SOLVER.OPTIMIZER = "ADAMW" 373 | _C.SOLVER.LR_BACKBONE = 1e-5 374 | _C.SOLVER.LR_BACKBONE_NAMES = [] 375 | _C.SOLVER.LR_LINEAR_PROJ_NAMES = [] 376 | _C.SOLVER.LR_LINEAR_PROJ_MULT = 0.1 377 | 378 | _C.TEST.USE_LEXICON = False 379 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 380 | # 1 - Full lexicon (for totaltext/ctw1500) 381 | 
_C.TEST.LEXICON_TYPE = 1 382 | _C.TEST.WEIGHTED_EDIT_DIST = False 383 | -------------------------------------------------------------------------------- /adet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import builtin # ensure the builtin datasets are registered 2 | from .dataset_mapper import DatasetMapperWithBasis 3 | 4 | 5 | __all__ = ["DatasetMapperWithBasis"] 6 | -------------------------------------------------------------------------------- /adet/data/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from fvcore.transforms import transform as T 5 | 6 | from detectron2.data.transforms import RandomCrop, StandardAugInput 7 | from detectron2.structures import BoxMode 8 | 9 | 10 | def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True): 11 | """ 12 | Generate a CropTransform so that the cropping region contains 13 | the center of the given instance. 14 | 15 | Args: 16 | crop_size (tuple): h, w in pixels 17 | image_size (tuple): h, w 18 | instance (dict): an annotation dict of one instance, in Detectron2's 19 | dataset format. 20 | """ 21 | bbox = random.choice(instances) 22 | crop_size = np.asarray(crop_size, dtype=np.int32) 23 | center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5 24 | assert ( 25 | image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1] 26 | ), "The annotation bounding box is outside of the image!" 27 | assert ( 28 | image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1] 29 | ), "Crop size is larger than image size!" 30 | 31 | min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0) 32 | max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0) 33 | max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32)) 34 | 35 | y0 = np.random.randint(min_yx[0], max_yx[0] + 1) 36 | x0 = np.random.randint(min_yx[1], max_yx[1] + 1) 37 | 38 | # if some instance is cropped extend the box 39 | if not crop_box: 40 | num_modifications = 0 41 | modified = True 42 | 43 | # convert crop_size to float 44 | crop_size = crop_size.astype(np.float32) 45 | while modified: 46 | modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances) 47 | num_modifications += 1 48 | if num_modifications > 100: 49 | raise ValueError( 50 | "Cannot finished cropping adjustment within 100 tries (#instances {}).".format( 51 | len(instances) 52 | ) 53 | ) 54 | return T.CropTransform(0, 0, image_size[1], image_size[0]) 55 | 56 | return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0]))) 57 | 58 | 59 | def adjust_crop(x0, y0, crop_size, instances, eps=1e-3): 60 | modified = False 61 | 62 | x1 = x0 + crop_size[1] 63 | y1 = y0 + crop_size[0] 64 | 65 | for bbox in instances: 66 | 67 | if bbox[0] < x0 - eps and bbox[2] > x0 + eps: 68 | crop_size[1] += x0 - bbox[0] 69 | x0 = bbox[0] 70 | modified = True 71 | 72 | if bbox[0] < x1 - eps and bbox[2] > x1 + eps: 73 | crop_size[1] += bbox[2] - x1 74 | x1 = bbox[2] 75 | modified = True 76 | 77 | if bbox[1] < y0 - eps and bbox[3] > y0 + eps: 78 | crop_size[0] += y0 - bbox[1] 79 | y0 = bbox[1] 80 | modified = True 81 | 82 | if bbox[1] < y1 - eps and bbox[3] > y1 + eps: 83 | crop_size[0] += bbox[3] - y1 84 | y1 = bbox[3] 85 | modified = True 86 | 87 | return modified, x0, y0, crop_size 88 | 89 | 90 | class RandomCropWithInstance(RandomCrop): 91 | """ Instance-aware cropping. 
92 | """ 93 | 94 | def __init__(self, crop_type, crop_size, crop_instance=True): 95 | """ 96 | Args: 97 | crop_instance (bool): if False, extend cropping boxes to avoid cropping instances 98 | """ 99 | super().__init__(crop_type, crop_size) 100 | self.crop_instance = crop_instance 101 | self.input_args = ("image", "boxes") 102 | 103 | def get_transform(self, img, boxes): 104 | image_size = img.shape[:2] 105 | crop_size = self.get_crop_size(image_size) 106 | return gen_crop_transform_with_instance( 107 | crop_size, image_size, boxes, crop_box=self.crop_instance 108 | ) 109 | -------------------------------------------------------------------------------- /adet/data/builtin.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from detectron2.data.datasets.register_coco import register_coco_instances 4 | from detectron2.data.datasets.builtin_meta import _get_builtin_metadata 5 | 6 | from .datasets.text import register_text_instances 7 | 8 | # register plane reconstruction 9 | 10 | _PREDEFINED_SPLITS_PIC = { 11 | "pic_person_train": ("pic/image/train", "pic/annotations/train_person.json"), 12 | "pic_person_val": ("pic/image/val", "pic/annotations/val_person.json"), 13 | } 14 | 15 | metadata_pic = { 16 | "thing_classes": ["person"] 17 | } 18 | 19 | _PREDEFINED_SPLITS_TEXT = { 20 | # datasets with bezier annotations 21 | "totaltext_train": ("totaltext/train_images", "totaltext/train.json"), 22 | "totaltext_val": ("totaltext/test_images", "totaltext/test.json"), 23 | "ctw1500_word_train": ("CTW1500/ctwtrain_text_image", "CTW1500/annotations/train_ctw1500_maxlen100_v2.json"), 24 | "ctw1500_word_test": ("CTW1500/ctwtest_text_image","CTW1500/annotations/test_ctw1500_maxlen100.json"), 25 | "syntext1_train": ("syntext1/images", "syntext1/annotations/train.json"), 26 | "syntext2_train": ("syntext2/images", "syntext2/annotations/train.json"), 27 | "mltbezier_word_train": ("mlt2017/images","mlt2017/annotations/train.json"), 28 | "rects_train": ("ReCTS/ReCTS_train_images", "ReCTS/annotations/rects_train.json"), 29 | "rects_val": ("ReCTS/ReCTS_val_images", "ReCTS/annotations/rects_val.json"), 30 | "rects_test": ("ReCTS/ReCTS_test_images", "ReCTS/annotations/rects_test.json"), 31 | "art_train": ("ArT/rename_artimg_train", "ArT/annotations/abcnet_art_train.json"), 32 | "lsvt_train": ("LSVT/rename_lsvtimg_train", "LSVT/annotations/abcnet_lsvt_train.json"), 33 | "chnsyn_train": ("ChnSyn/syn_130k_images", "ChnSyn/annotations/chn_syntext.json"), 34 | # datasets with polygon annotations 35 | "totaltext_poly_train": ("totaltext/train_images", "totaltext/train_poly.json"), 36 | "totaltext_poly_val": ("totaltext/test_images", "totaltext/test_poly.json"), 37 | "ctw1500_word_poly_train": ("CTW1500/ctwtrain_text_image", "CTW1500/annotations/train_poly.json"), 38 | "ctw1500_word_poly_test": ("CTW1500/ctwtest_text_image","CTW1500/annotations/test_poly.json"), 39 | "syntext1_poly_train": ("syntext1/images", "syntext1/annotations/train_poly.json"), 40 | "syntext2_poly_train": ("syntext2/images", "syntext2/annotations/train_poly.json"), 41 | "mltbezier_word_poly_train": ("mlt2017/images","mlt2017/annotations/train_poly.json"), 42 | "icdar2015_train": ("icdar2015/train_images", "icdar2015/train_poly.json"), 43 | "icdar2015_test": ("icdar2015/test_images", "icdar2015/test_poly.json"), 44 | "icdar2019_train": ("icdar2019/train_images", "icdar2019/train_poly.json"), 45 | "textocr_train": ("textocr/train_images", "textocr/annotations/train_poly.json"), 46 | "textocr_val": 
("textocr/train_images", "textocr/annotations/val_poly.json"), 47 | } 48 | 49 | metadata_text = { 50 | "thing_classes": ["text"] 51 | } 52 | 53 | 54 | def register_all_coco(root="datasets"): 55 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_PIC.items(): 56 | # Assume pre-defined datasets live in `./datasets`. 57 | register_coco_instances( 58 | key, 59 | metadata_pic, 60 | os.path.join(root, json_file) if "://" not in json_file else json_file, 61 | os.path.join(root, image_root), 62 | ) 63 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_TEXT.items(): 64 | # Assume pre-defined datasets live in `./datasets`. 65 | register_text_instances( 66 | key, 67 | metadata_text, 68 | os.path.join(root, json_file) if "://" not in json_file else json_file, 69 | os.path.join(root, image_root), 70 | ) 71 | 72 | 73 | register_all_coco() 74 | -------------------------------------------------------------------------------- /adet/data/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | from fvcore.common.file_io import PathManager 8 | from PIL import Image 9 | from pycocotools import mask as maskUtils 10 | 11 | from detectron2.data import detection_utils as utils 12 | from detectron2.data import transforms as T 13 | from detectron2.data.dataset_mapper import DatasetMapper 14 | from detectron2.data.detection_utils import SizeMismatchError 15 | from detectron2.structures import BoxMode 16 | 17 | from .augmentation import RandomCropWithInstance 18 | from .detection_utils import (annotations_to_instances, build_augmentation, 19 | transform_instance_annotations) 20 | 21 | """ 22 | This file contains the default mapping that's applied to "dataset dicts". 23 | """ 24 | 25 | __all__ = ["DatasetMapperWithBasis"] 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def segmToRLE(segm, img_size): 31 | h, w = img_size 32 | if type(segm) == list: 33 | # polygon -- a single object might consist of multiple parts 34 | # we merge all parts into one mask rle code 35 | rles = maskUtils.frPyObjects(segm, h, w) 36 | rle = maskUtils.merge(rles) 37 | elif type(segm["counts"]) == list: 38 | # uncompressed RLE 39 | rle = maskUtils.frPyObjects(segm, h, w) 40 | else: 41 | # rle 42 | rle = segm 43 | return rle 44 | 45 | 46 | def segmToMask(segm, img_size): 47 | rle = segmToRLE(segm, img_size) 48 | m = maskUtils.decode(rle) 49 | return m 50 | 51 | 52 | class DatasetMapperWithBasis(DatasetMapper): 53 | """ 54 | This caller enables the default Detectron2 mapper to read an additional basis semantic label 55 | """ 56 | 57 | def __init__(self, cfg, is_train=True): 58 | super().__init__(cfg, is_train) 59 | 60 | # Rebuild augmentations 61 | logger.info( 62 | "Rebuilding the augmentations. The previous augmentations will be overridden." 
63 | ) 64 | self.augmentation = build_augmentation(cfg, is_train) 65 | 66 | if cfg.INPUT.CROP.ENABLED and is_train: 67 | self.augmentation.insert( 68 | 0, 69 | RandomCropWithInstance( 70 | cfg.INPUT.CROP.TYPE, 71 | cfg.INPUT.CROP.SIZE, 72 | cfg.INPUT.CROP.CROP_INSTANCE, 73 | ), 74 | ) 75 | logging.getLogger(__name__).info( 76 | "Cropping used in training: " + str(self.augmentation[0]) 77 | ) 78 | 79 | self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON 80 | self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET 81 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 82 | 83 | if self.boxinst_enabled: 84 | self.use_instance_mask = False 85 | self.recompute_boxes = False 86 | 87 | def __call__(self, dataset_dict): 88 | """ 89 | Args: 90 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 91 | 92 | Returns: 93 | dict: a format that builtin models in detectron2 accept 94 | """ 95 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 96 | # USER: Write your own image loading if it's not from a file 97 | try: 98 | image = utils.read_image( 99 | dataset_dict["file_name"], format=self.image_format 100 | ) 101 | except Exception as e: 102 | print(dataset_dict["file_name"]) 103 | print(e) 104 | raise e 105 | try: 106 | utils.check_image_size(dataset_dict, image) 107 | except SizeMismatchError as e: 108 | expected_wh = (dataset_dict["width"], dataset_dict["height"]) 109 | image_wh = (image.shape[1], image.shape[0]) 110 | if (image_wh[1], image_wh[0]) == expected_wh: 111 | print("transposing image {}".format(dataset_dict["file_name"])) 112 | image = image.transpose(1, 0, 2) 113 | else: 114 | raise e 115 | 116 | # USER: Remove if you don't do semantic/panoptic segmentation. 117 | if "sem_seg_file_name" in dataset_dict: 118 | sem_seg_gt = utils.read_image( 119 | dataset_dict.pop("sem_seg_file_name"), "L" 120 | ).squeeze(2) 121 | else: 122 | sem_seg_gt = None 123 | 124 | boxes = np.asarray( 125 | [ 126 | BoxMode.convert( 127 | instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS 128 | ) 129 | for instance in dataset_dict["annotations"] 130 | ] 131 | ) 132 | aug_input = T.StandardAugInput(image, boxes=boxes, sem_seg=sem_seg_gt) 133 | transforms = aug_input.apply_augmentations(self.augmentation) 134 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 135 | 136 | image_shape = image.shape[:2] # h, w 137 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 138 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 139 | # Therefore it's important to use torch.Tensor. 140 | dataset_dict["image"] = torch.as_tensor( 141 | np.ascontiguousarray(image.transpose(2, 0, 1)) 142 | ) 143 | if sem_seg_gt is not None: 144 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 145 | 146 | # USER: Remove if you don't use pre-computed proposals. 147 | # Most users would not need this feature. 148 | if self.proposal_topk: 149 | utils.transform_proposals( 150 | dataset_dict, 151 | image_shape, 152 | transforms, 153 | proposal_topk=self.proposal_topk, 154 | min_box_size=self.proposal_min_box_size, 155 | ) 156 | 157 | if not self.is_train: 158 | dataset_dict.pop("annotations", None) 159 | dataset_dict.pop("sem_seg_file_name", None) 160 | dataset_dict.pop("pano_seg_file_name", None) 161 | return dataset_dict 162 | 163 | if "annotations" in dataset_dict: 164 | # USER: Modify this if you want to keep them for some reason. 
165 | for anno in dataset_dict["annotations"]: 166 | if not self.use_instance_mask: 167 | anno.pop("segmentation", None) 168 | if not self.use_keypoint: 169 | anno.pop("keypoints", None) 170 | 171 | # USER: Implement additional transformations if you have other types of data 172 | annos = [ 173 | transform_instance_annotations( 174 | obj, 175 | transforms, 176 | image_shape, 177 | keypoint_hflip_indices=self.keypoint_hflip_indices, 178 | ) 179 | for obj in dataset_dict.pop("annotations") 180 | if obj.get("iscrowd", 0) == 0 181 | ] 182 | instances = annotations_to_instances( 183 | annos, image_shape, mask_format=self.instance_mask_format 184 | ) 185 | 186 | # After transforms such as cropping are applied, the bounding box may no longer 187 | # tightly bound the object. As an example, imagine a triangle object 188 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight 189 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 190 | if self.recompute_boxes: 191 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 192 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 193 | 194 | if self.basis_loss_on and self.is_train: 195 | # load basis supervisions 196 | if self.ann_set == "coco": 197 | basis_sem_path = ( 198 | dataset_dict["file_name"] 199 | .replace("train2017", "thing_train2017") 200 | .replace("image/train", "thing_train") 201 | ) 202 | else: 203 | basis_sem_path = ( 204 | dataset_dict["file_name"] 205 | .replace("coco", "lvis") 206 | .replace("train2017", "thing_train") 207 | ) 208 | # change extension to npz 209 | basis_sem_path = osp.splitext(basis_sem_path)[0] + ".npz" 210 | basis_sem_gt = np.load(basis_sem_path)["mask"] 211 | basis_sem_gt = transforms.apply_segmentation(basis_sem_gt) 212 | basis_sem_gt = torch.as_tensor(basis_sem_gt.astype("long")) 213 | dataset_dict["basis_sem"] = basis_sem_gt 214 | return dataset_dict 215 | -------------------------------------------------------------------------------- /adet/data/datasets/text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import contextlib 3 | import io 4 | import logging 5 | import os 6 | from fvcore.common.timer import Timer 7 | from fvcore.common.file_io import PathManager 8 | 9 | from detectron2.structures import BoxMode 10 | 11 | from detectron2.data import DatasetCatalog, MetadataCatalog 12 | 13 | """ 14 | This file contains functions to parse COCO-format text annotations into dicts in "Detectron2 format". 15 | """ 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | __all__ = ["load_text_json", "register_text_instances"] 21 | 22 | 23 | def register_text_instances(name, metadata, json_file, image_root): 24 | """ 25 | Register a dataset in json annotation format for text detection and recognition. 26 | 27 | Args: 28 | name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train". 29 | metadata (dict): extra metadata associated with this dataset. It can be an empty dict. 30 | json_file (str): path to the json instance annotation file. 31 | image_root (str or path-like): directory which contains all the images. 
32 | """ 33 | DatasetCatalog.register(name, lambda: load_text_json(json_file, image_root, name)) 34 | MetadataCatalog.get(name).set( 35 | json_file=json_file, image_root=image_root, evaluator_type="text", **metadata 36 | ) 37 | 38 | 39 | def load_text_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None): 40 | """ 41 | Load a json file with totaltext annotation format. 42 | Currently supports text detection and recognition. 43 | 44 | Args: 45 | json_file (str): full path to the json file in totaltext annotation format. 46 | image_root (str or path-like): the directory where the images in this json file exists. 47 | dataset_name (str): the name of the dataset (e.g., coco_2017_train). 48 | If provided, this function will also put "thing_classes" into 49 | the metadata associated with this dataset. 50 | extra_annotation_keys (list[str]): list of per-annotation keys that should also be 51 | loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints", 52 | "category_id", "segmentation"). The values for these keys will be returned as-is. 53 | For example, the densepose annotations are loaded in this way. 54 | 55 | Returns: 56 | list[dict]: a list of dicts in Detectron2 standard dataset dicts format. (See 57 | `Using Custom Datasets `_ ) 58 | 59 | Notes: 60 | 1. This function does not read the image files. 61 | The results do not have the "image" field. 62 | """ 63 | from pycocotools.coco import COCO 64 | 65 | timer = Timer() 66 | json_file = PathManager.get_local_path(json_file) 67 | with contextlib.redirect_stdout(io.StringIO()): 68 | coco_api = COCO(json_file) 69 | if timer.seconds() > 1: 70 | logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) 71 | 72 | id_map = None 73 | if dataset_name is not None: 74 | meta = MetadataCatalog.get(dataset_name) 75 | cat_ids = sorted(coco_api.getCatIds()) 76 | cats = coco_api.loadCats(cat_ids) 77 | # The categories in a custom json file may not be sorted. 78 | thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] 79 | meta.thing_classes = thing_classes 80 | 81 | # In COCO, certain category ids are artificially removed, 82 | # and by convention they are always ignored. 83 | # We deal with COCO's id issue and translate 84 | # the category ids to contiguous ids in [0, 80). 85 | 86 | # It works by looking at the "categories" field in the json, therefore 87 | # if users' own json also have incontiguous ids, we'll 88 | # apply this mapping as well but print a warning. 89 | if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): 90 | if "coco" not in dataset_name: 91 | logger.warning( 92 | """ 93 | Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. 94 | """ 95 | ) 96 | id_map = {v: i for i, v in enumerate(cat_ids)} 97 | meta.thing_dataset_id_to_contiguous_id = id_map 98 | 99 | # sort indices for reproducible results 100 | img_ids = sorted(coco_api.imgs.keys()) 101 | # imgs is a list of dicts, each looks something like: 102 | # {'license': 4, 103 | # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', 104 | # 'file_name': 'COCO_val2014_000000001268.jpg', 105 | # 'height': 427, 106 | # 'width': 640, 107 | # 'date_captured': '2013-11-17 05:57:24', 108 | # 'id': 1268} 109 | imgs = coco_api.loadImgs(img_ids) 110 | # anns is a list[list[dict]], where each dict is an annotation 111 | # record for an object. The inner list enumerates the objects in an image 112 | # and the outer list enumerates over images. 
Example of anns[0]: 113 | # [{'segmentation': [[192.81, 114 | # 247.09, 115 | # ... 116 | # 219.03, 117 | # 249.06]], 118 | # 'area': 1035.749, 119 | # 'rec': [84, 72, ... 96], 120 | # 'bezier_pts': [169.0, 425.0, ..., ] 121 | # 'iscrowd': 0, 122 | # 'image_id': 1268, 123 | # 'bbox': [192.81, 224.8, 74.73, 33.43], 124 | # 'category_id': 16, 125 | # 'id': 42986}, 126 | # ...] 127 | anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] 128 | 129 | if "minival" not in json_file: 130 | # The popular valminusminival & minival annotations for COCO2014 contain this bug. 131 | # However the ratio of buggy annotations there is tiny and does not affect accuracy. 132 | # Therefore we explicitly white-list them. 133 | ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] 134 | assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format( 135 | json_file 136 | ) 137 | 138 | imgs_anns = list(zip(imgs, anns)) 139 | 140 | logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file)) 141 | 142 | dataset_dicts = [] 143 | 144 | ann_keys = ["iscrowd", "bbox", "rec", "category_id"] + (extra_annotation_keys or []) 145 | 146 | num_instances_without_valid_segmentation = 0 147 | 148 | for (img_dict, anno_dict_list) in imgs_anns: 149 | record = {} 150 | record["file_name"] = os.path.join(image_root, img_dict["file_name"]) 151 | record["height"] = img_dict["height"] 152 | record["width"] = img_dict["width"] 153 | image_id = record["image_id"] = img_dict["id"] 154 | 155 | objs = [] 156 | for anno in anno_dict_list: 157 | # Check that the image_id in this annotation is the same as 158 | # the image_id we're looking at. 159 | # This fails only when the data parsing logic or the annotation file is buggy. 160 | 161 | # The original COCO valminusminival2014 & minival2014 annotation files 162 | # actually contains bugs that, together with certain ways of using COCO API, 163 | # can trigger this assertion. 164 | assert anno["image_id"] == image_id 165 | 166 | assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.' 167 | 168 | obj = {key: anno[key] for key in ann_keys if key in anno} 169 | 170 | segm = anno.get("segmentation", None) 171 | if segm: # either list[list[float]] or dict(RLE) 172 | if not isinstance(segm, dict): 173 | # filter out invalid polygons (< 3 points) 174 | segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] 175 | if len(segm) == 0: 176 | num_instances_without_valid_segmentation += 1 177 | continue # ignore this instance 178 | obj["segmentation"] = segm 179 | 180 | bezierpts = anno.get("bezier_pts", None) 181 | # Bezier Points are the control points for BezierAlign Text recognition (BAText) 182 | if bezierpts: # list[float] 183 | obj["beziers"] = bezierpts 184 | 185 | polypts = anno.get("polys", None) 186 | if polypts: 187 | obj["polygons"] = polypts 188 | 189 | text = anno.get("rec", None) 190 | if text: 191 | obj["text"] = text 192 | 193 | obj["bbox_mode"] = BoxMode.XYWH_ABS 194 | if id_map: 195 | obj["category_id"] = id_map[obj["category_id"]] 196 | objs.append(obj) 197 | record["annotations"] = objs 198 | dataset_dicts.append(record) 199 | 200 | if num_instances_without_valid_segmentation > 0: 201 | logger.warning( 202 | "Filtered out {} instances without valid segmentation. 
" 203 | "There might be issues in your dataset generation process.".format( 204 | num_instances_without_valid_segmentation 205 | ) 206 | ) 207 | return dataset_dicts -------------------------------------------------------------------------------- /adet/data/detection_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from detectron2.data import transforms as T 7 | from detectron2.data.detection_utils import \ 8 | annotations_to_instances as d2_anno_to_inst 9 | from detectron2.data.detection_utils import \ 10 | transform_instance_annotations as d2_transform_inst_anno 11 | 12 | 13 | def transform_instance_annotations( 14 | annotation, transforms, image_size, *, keypoint_hflip_indices=None 15 | ): 16 | 17 | annotation = d2_transform_inst_anno( 18 | annotation, 19 | transforms, 20 | image_size, 21 | keypoint_hflip_indices=keypoint_hflip_indices, 22 | ) 23 | 24 | if "beziers" in annotation: 25 | beziers = transform_ctrl_pnts_annotations(annotation["beziers"], transforms) 26 | annotation["beziers"] = beziers 27 | 28 | if "polygons" in annotation: 29 | polys = transform_ctrl_pnts_annotations(annotation["polygons"], transforms) 30 | annotation["polygons"] = polys 31 | 32 | return annotation 33 | 34 | 35 | def transform_ctrl_pnts_annotations(pnts, transforms): 36 | """ 37 | Transform keypoint annotations of an image. 38 | 39 | Args: 40 | beziers (list[float]): Nx16 float in Detectron2 Dataset format. 41 | transforms (TransformList): 42 | """ 43 | # (N*2,) -> (N, 2) 44 | pnts = np.asarray(pnts, dtype="float64").reshape(-1, 2) 45 | pnts = transforms.apply_coords(pnts).reshape(-1) 46 | 47 | # This assumes that HorizFlipTransform is the only one that does flip 48 | do_hflip = ( 49 | sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1 50 | ) 51 | if do_hflip: 52 | raise ValueError("Flipping text data is not supported (also disencouraged).") 53 | 54 | return pnts 55 | 56 | 57 | def annotations_to_instances(annos, image_size, mask_format="polygon"): 58 | instance = d2_anno_to_inst(annos, image_size, mask_format) 59 | 60 | if not annos: 61 | return instance 62 | 63 | # add attributes 64 | if "beziers" in annos[0]: 65 | beziers = [obj.get("beziers", []) for obj in annos] 66 | instance.beziers = torch.as_tensor(beziers, dtype=torch.float32) 67 | 68 | if "rec" in annos[0]: 69 | text = [obj.get("rec", []) for obj in annos] 70 | instance.text = torch.as_tensor(text, dtype=torch.int32) 71 | 72 | if "polygons" in annos[0]: 73 | polys = [obj.get("polygons", []) for obj in annos] 74 | instance.polygons = torch.as_tensor(polys, dtype=torch.float32) 75 | 76 | return instance 77 | 78 | 79 | def build_augmentation(cfg, is_train): 80 | """ 81 | With option to don't use hflip 82 | 83 | Returns: 84 | list[Augmentation] 85 | """ 86 | if is_train: 87 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 88 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 89 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 90 | else: 91 | min_size = cfg.INPUT.MIN_SIZE_TEST 92 | max_size = cfg.INPUT.MAX_SIZE_TEST 93 | sample_style = "choice" 94 | if sample_style == "range": 95 | assert ( 96 | len(min_size) == 2 97 | ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 98 | 99 | logger = logging.getLogger(__name__) 100 | 101 | augmentation = [] 102 | augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 103 | if is_train: 104 | if cfg.INPUT.HFLIP_TRAIN: 105 | 
augmentation.append(T.RandomFlip()) 106 | logger.info("Augmentations used in training: " + str(augmentation)) 107 | return augmentation 108 | 109 | 110 | build_transform_gen = build_augmentation 111 | """ 112 | Alias for backward-compatibility. 113 | """ 114 | -------------------------------------------------------------------------------- /adet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_evaluation import TextEvaluator 2 | from .text_eval_script import text_eval_main 3 | from . import rrc_evaluation_funcs -------------------------------------------------------------------------------- /adet/evaluation/lexicon_procesor.py: -------------------------------------------------------------------------------- 1 | import editdistance 2 | import numpy as np 3 | from numba import njit 4 | from numba.core import types 5 | from numba.typed import Dict 6 | 7 | @njit 8 | def weighted_edit_distance(word1: str, word2: str, scores: np.ndarray, ct_labels_inv): 9 | m: int = len(word1) 10 | n: int = len(word2) 11 | dp = np.zeros((n+1, m+1), dtype=np.float32) 12 | dp[0, :] = np.arange(m+1) 13 | dp[:, 0] = np.arange(n+1) 14 | for i in range(1, n + 1): ## word2 15 | for j in range(1, m + 1): ## word1 16 | delect_cost = _ed_delete_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## delect a[i] 17 | insert_cost = _ed_insert_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## insert b[j] 18 | if word1[j - 1] != word2[i - 1]: 19 | replace_cost = _ed_replace_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## replace a[i] with b[j] 20 | else: 21 | replace_cost = 0 22 | dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delect_cost, dp[i-1][j-1] + replace_cost) 23 | 24 | return dp[n][m] 25 | 26 | @njit 27 | def _ed_delete_cost(j, i, word1, word2, scores, ct_labels_inv): 28 | ## delete a[i] 29 | return _get_score(scores[j], word1[j], ct_labels_inv) 30 | 31 | @njit 32 | def _ed_insert_cost(i, j, word1, word2, scores, ct_labels_inv): 33 | ## insert b[j] 34 | if i < len(word1) - 1: 35 | return (_get_score(scores[i], word1[i], ct_labels_inv) + _get_score(scores[i+1], word1[i+1], ct_labels_inv))/2 36 | else: 37 | return _get_score(scores[i], word1[i], ct_labels_inv) 38 | 39 | @njit 40 | def _ed_replace_cost(i, j, word1, word2, scores, ct_labels_inv): 41 | ## replace a[i] with b[j] 42 | # if word1 == "eeatpisaababarait".upper(): 43 | # print(scores[c2][i]/scores[c1][i]) 44 | return max(1 - _get_score(scores[i], word2[j], ct_labels_inv)/_get_score(scores[i], word1[i], ct_labels_inv)*5, 0) 45 | 46 | @njit 47 | def _get_score(scores, char, ct_labels_inv): 48 | upper = ct_labels_inv[char.upper()] 49 | lower = ct_labels_inv[char.lower()] 50 | return max(scores[upper], scores[lower]) 51 | 52 | class LexiconMatcher: 53 | def __init__(self, dataset, lexicon_type, use_lexicon, ct_labels, weighted_ed=False): 54 | self.use_lexicon = use_lexicon 55 | self.lexicon_type = lexicon_type 56 | self.dataset = dataset 57 | self.ct_labels_inv = Dict.empty( 58 | key_type=types.string, 59 | value_type=types.int64, 60 | ) 61 | for i, c in enumerate(ct_labels): 62 | self.ct_labels_inv[c] = i 63 | # maps char to index 64 | self.is_full_lex_dataset = "totaltext" in dataset or "ctw1500" in dataset 65 | self._load_lexicon(dataset, lexicon_type) 66 | self.weighted_ed = weighted_ed 67 | 68 | def find_match_word(self, rec_str, img_id=None, scores=None): 69 | if not self.use_lexicon: 70 | return rec_str 71 | rec_str = rec_str.upper() 72 | dist_min = 100 73 | match_word = 
'' 74 | match_dist = 100 75 | 76 | lexicons = self.lexicons if self.lexicon_type != 3 else self.lexicons[img_id] 77 | pairs = self.pairs if self.lexicon_type != 3 else self.pairs[img_id] 78 | 79 | # scores of shape (seq_len, n_symbols) must be provided for weighted editdistance 80 | assert not self.weighted_ed or scores is not None 81 | 82 | for word in lexicons: 83 | word = word.upper() 84 | if self.weighted_ed: 85 | ed = weighted_edit_distance(rec_str, word, scores, self.ct_labels_inv) 86 | else: 87 | ed = editdistance.eval(rec_str, word) 88 | if ed < dist_min: 89 | dist_min = ed 90 | match_word = pairs[word] 91 | match_dist = ed 92 | 93 | if self.is_full_lex_dataset: 94 | # always return matched results for the full lexicon (for totaltext/ctw1500) 95 | return match_word 96 | else: 97 | # filter unmatched words for icdar 98 | return match_word if match_dist < 2.5 or self.lexicon_type == 1 else None 99 | 100 | @staticmethod 101 | def _get_lexicon_path(dataset): 102 | if "icdar2015" in dataset: 103 | g_lexicon_path = "datasets/evaluation/lexicons/ic15/GenericVocabulary_new.txt" 104 | g_pairlist_path = "datasets/evaluation/lexicons/ic15/GenericVocabulary_pair_list.txt" 105 | w_lexicon_path = "datasets/evaluation/lexicons/ic15/ch4_test_vocabulary_new.txt" 106 | w_pairlist_path = "datasets/evaluation/lexicons/ic15/ch4_test_vocabulary_pair_list.txt" 107 | s_lexicon_paths = [ 108 | (str(fid+1), f"datasets/evaluation/lexicons/ic15/new_strong_lexicon/new_voc_img_{fid+1}.txt") for fid in range(500)] 109 | s_pairlist_paths = [ 110 | (str(fid+1), f"datasets/evaluation/lexicons/ic15/new_strong_lexicon/pair_voc_img_{fid+1}.txt") for fid in range(500)] 111 | elif "totaltext" in dataset: 112 | s_lexicon_paths = s_pairlist_paths = None 113 | g_lexicon_path = "datasets/evaluation/lexicons/totaltext/tt_lexicon.txt" 114 | g_pairlist_path = "datasets/evaluation/lexicons/totaltext/tt_pair_list.txt" 115 | w_lexicon_path = "datasets/evaluation/lexicons/totaltext/weak_voc_new.txt" 116 | w_pairlist_path = "datasets/evaluation/lexicons/totaltext/weak_voc_pair_list.txt" 117 | elif "ctw1500" in dataset: 118 | s_lexicon_paths = s_pairlist_paths = w_lexicon_path = w_pairlist_path = None 119 | g_lexicon_path = "datasets/evaluation/lexicons/ctw1500/ctw1500_lexicon.txt" 120 | g_pairlist_path = "datasets/evaluation/lexicons/ctw1500/ctw1500_pair_list.txt" 121 | return g_lexicon_path, g_pairlist_path, w_lexicon_path, w_pairlist_path, s_lexicon_paths, s_pairlist_paths 122 | 123 | def _load_lexicon(self, dataset, lexicon_type): 124 | if not self.use_lexicon: 125 | return 126 | g_lexicon_path, g_pairlist_path, w_lexicon_path, w_pairlist_path, s_lexicon_path, s_pairlist_path = self._get_lexicon_path( 127 | dataset) 128 | if lexicon_type in (1, 2): 129 | # generic/weak lexicon 130 | lexicon_path = g_lexicon_path if lexicon_type == 1 else w_lexicon_path 131 | pairlist_path = g_pairlist_path if lexicon_type == 1 else w_pairlist_path 132 | if lexicon_path is None or pairlist_path is None: 133 | self.use_lexicon = False 134 | return 135 | with open(pairlist_path) as fp: 136 | pairs = dict() 137 | for line in fp.readlines(): 138 | line = line.strip() 139 | if self.is_full_lex_dataset: 140 | # might contain space in key word 141 | split = line.split(' ') 142 | half = len(split) // 2 143 | word = ' '.join(split[:half]).upper() 144 | else: 145 | word = line.split(' ')[0].upper() 146 | word_gt = line[len(word)+1:] 147 | pairs[word] = word_gt 148 | with open(lexicon_path) as fp: 149 | lexicons = [] 150 | for line in fp.readlines(): 151 | 
lexicons.append(line.strip()) 152 | self.lexicons = lexicons 153 | self.pairs = pairs 154 | elif lexicon_type == 3: 155 | # strong lexicon 156 | if s_lexicon_path is None or s_pairlist_path is None: 157 | self.use_lexicon = False 158 | return 159 | lexicons, pairlists = dict(), dict() 160 | for (fid, lexicon_path), (_, pairlist_path) in zip(s_lexicon_path, s_pairlist_path): 161 | with open(lexicon_path) as fp: 162 | lexicon = [] 163 | for line in fp.readlines(): 164 | lexicon.append(line.strip()) 165 | with open(pairlist_path) as fp: 166 | pairs = dict() 167 | for line in fp.readlines(): 168 | line = line.strip() 169 | word = line.split(' ')[0].upper() 170 | word_gt = line[len(word)+1:] 171 | pairs[word] = word_gt 172 | lexicons[fid] = lexicon 173 | pairlists[fid] = pairs 174 | self.lexicons = lexicons 175 | self.pairs = pairlists 176 | -------------------------------------------------------------------------------- /adet/evaluation/text_evaluation.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import io 4 | import itertools 5 | import json 6 | import logging 7 | import numpy as np 8 | import os 9 | import re 10 | import torch 11 | from collections import OrderedDict 12 | from fvcore.common.file_io import PathManager 13 | from pycocotools.coco import COCO 14 | 15 | from detectron2.utils import comm 16 | from detectron2.data import MetadataCatalog 17 | from detectron2.evaluation.evaluator import DatasetEvaluator 18 | 19 | import glob 20 | import shutil 21 | from shapely.geometry import Polygon, LinearRing 22 | from adet.evaluation import text_eval_script 23 | import zipfile 24 | import pickle 25 | 26 | from adet.evaluation.lexicon_procesor import LexiconMatcher 27 | 28 | NULL_CHAR = u'口' 29 | 30 | class TextEvaluator(DatasetEvaluator): 31 | """ 32 | Evaluate text proposals and recognition. 33 | """ 34 | 35 | def __init__(self, dataset_name, cfg, distributed, output_dir=None): 36 | self._tasks = ("polygon", "recognition") 37 | self._distributed = distributed 38 | self._output_dir = output_dir 39 | 40 | self._cpu_device = torch.device("cpu") 41 | self._logger = logging.getLogger(__name__) 42 | 43 | self._metadata = MetadataCatalog.get(dataset_name) 44 | if not hasattr(self._metadata, "json_file"): 45 | raise AttributeError( 46 | f"json_file was not found in MetaDataCatalog for '{dataset_name}'." 
47 | ) 48 | 49 | self.voc_size = cfg.MODEL.BATEXT.VOC_SIZE 50 | self.use_customer_dictionary = cfg.MODEL.BATEXT.CUSTOM_DICT 51 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON 52 | if not self.use_customer_dictionary: 53 | self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~'] 54 | else: 55 | with open(self.use_customer_dictionary, 'rb') as fp: 56 | self.CTLABELS = pickle.load(fp) 57 | self._lexicon_matcher = LexiconMatcher(dataset_name, cfg.TEST.LEXICON_TYPE, cfg.TEST.USE_LEXICON, 58 | self.CTLABELS + [NULL_CHAR], 59 | weighted_ed=cfg.TEST.WEIGHTED_EDIT_DIST) 60 | assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size is not matched dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS)) 61 | 62 | json_file = PathManager.get_local_path(self._metadata.json_file) 63 | with contextlib.redirect_stdout(io.StringIO()): 64 | self._coco_api = COCO(json_file) 65 | 66 | # use dataset_name to decide eval_gt_path 67 | if "totaltext" in dataset_name: 68 | self._text_eval_gt_path = "datasets/evaluation/gt_totaltext.zip" 69 | self._word_spotting = True 70 | elif "ctw1500" in dataset_name: 71 | self._text_eval_gt_path = "datasets/evaluation/gt_ctw1500.zip" 72 | self._word_spotting = False 73 | elif "icdar2015" in dataset_name: 74 | self._text_eval_gt_path = "datasets/evaluation/gt_icdar2015.zip" 75 | self._word_spotting = False 76 | else: 77 | self._text_eval_gt_path = "" 78 | self._text_eval_confidence = cfg.MODEL.FCOS.INFERENCE_TH_TEST 79 | 80 | def reset(self): 81 | self._predictions = [] 82 | 83 | def process(self, inputs, outputs): 84 | for input, output in zip(inputs, outputs): 85 | prediction = {"image_id": input["image_id"]} 86 | 87 | instances = output["instances"].to(self._cpu_device) 88 | prediction["instances"] = self.instances_to_coco_json(instances, input["image_id"]) 89 | self._predictions.append(prediction) 90 | 91 | def to_eval_format(self, file_path, temp_dir="temp_det_results", cf_th=0.5): 92 | def fis_ascii(s): 93 | a = (ord(c) < 128 for c in s) 94 | return all(a) 95 | 96 | def de_ascii(s): 97 | a = [c for c in s if ord(c) < 128] 98 | outa = '' 99 | for i in a: 100 | outa +=i 101 | return outa 102 | 103 | with open(file_path, 'r') as f: 104 | data = json.load(f) 105 | with open('temp_all_det_cors.txt', 'w') as f2: 106 | for ix in range(len(data)): 107 | if data[ix]['score'] > 0.1: 108 | outstr = '{}: '.format(data[ix]['image_id']) 109 | xmin = 1000000 110 | ymin = 1000000 111 | xmax = 0 112 | ymax = 0 113 | for i in range(len(data[ix]['polys'])): 114 | outstr = outstr + str(int(data[ix]['polys'][i][0])) +','+str(int(data[ix]['polys'][i][1])) +',' 115 | ass = de_ascii(data[ix]['rec']) 116 | if len(ass)>=0: # 117 | outstr = outstr + str(round(data[ix]['score'], 3)) +',####'+ass+'\n' 118 | f2.writelines(outstr) 119 | f2.close() 120 | dirn = temp_dir 121 | lsc = [cf_th] 122 | fres = open('temp_all_det_cors.txt', 'r').readlines() 123 | for isc in lsc: 124 | if not os.path.isdir(dirn): 125 | os.mkdir(dirn) 126 | 127 | for line in fres: 128 | line = line.strip() 129 | s = line.split(': ') 130 | filename = '{:07d}.txt'.format(int(s[0])) 131 | outName = os.path.join(dirn, filename) 132 | with open(outName, 'a') 
as fout: 133 | ptr = s[1].strip().split(',####') 134 | score = ptr[0].split(',')[-1] 135 | if float(score) < isc: 136 | continue 137 | cors = ','.join(e for e in ptr[0].split(',')[:-1]) 138 | fout.writelines(cors+',####'+ptr[1]+'\n') 139 | os.remove("temp_all_det_cors.txt") 140 | 141 | def sort_detection(self, temp_dir): 142 | origin_file = temp_dir 143 | output_file = "final_"+temp_dir 144 | 145 | if not os.path.isdir(output_file): 146 | os.mkdir(output_file) 147 | 148 | files = glob.glob(origin_file+'*.txt') 149 | files.sort() 150 | 151 | for i in files: 152 | out = i.replace(origin_file, output_file) 153 | fin = open(i, 'r').readlines() 154 | fout = open(out, 'w') 155 | for iline, line in enumerate(fin): 156 | ptr = line.strip().split(',####') 157 | rec = ptr[1] 158 | cors = ptr[0].split(',') 159 | assert(len(cors) %2 == 0), 'cors invalid.' 160 | pts = [(int(cors[j]), int(cors[j+1])) for j in range(0,len(cors),2)] 161 | try: 162 | pgt = Polygon(pts) 163 | except Exception as e: 164 | print(e) 165 | print('An invalid detection in {} line {} is removed ... '.format(i, iline)) 166 | continue 167 | 168 | if not pgt.is_valid: 169 | print('An invalid detection in {} line {} is removed ... '.format(i, iline)) 170 | continue 171 | 172 | pRing = LinearRing(pts) 173 | if pRing.is_ccw: 174 | pts.reverse() 175 | outstr = '' 176 | for ipt in pts[:-1]: 177 | outstr += (str(int(ipt[0]))+','+ str(int(ipt[1]))+',') 178 | outstr += (str(int(pts[-1][0]))+','+ str(int(pts[-1][1]))) 179 | outstr = outstr+',####' + rec 180 | fout.writelines(outstr+'\n') 181 | fout.close() 182 | os.chdir(output_file) 183 | 184 | def zipdir(path, ziph): 185 | # ziph is zipfile handle 186 | for root, dirs, files in os.walk(path): 187 | for file in files: 188 | ziph.write(os.path.join(root, file)) 189 | 190 | zipf = zipfile.ZipFile('../det.zip', 'w', zipfile.ZIP_DEFLATED) 191 | zipdir('./', zipf) 192 | zipf.close() 193 | os.chdir("../") 194 | # clean temp files 195 | shutil.rmtree(origin_file) 196 | shutil.rmtree(output_file) 197 | return "det.zip" 198 | 199 | def evaluate_with_official_code(self, result_path, gt_path): 200 | return text_eval_script.text_eval_main(det_file=result_path, gt_file=gt_path, is_word_spotting=self._word_spotting) 201 | 202 | def evaluate(self): 203 | if self._distributed: 204 | comm.synchronize() 205 | predictions = comm.gather(self._predictions, dst=0) 206 | predictions = list(itertools.chain(*predictions)) 207 | 208 | if not comm.is_main_process(): 209 | return {} 210 | else: 211 | predictions = self._predictions 212 | 213 | if len(predictions) == 0: 214 | self._logger.warning("[COCOEvaluator] Did not receive valid predictions.") 215 | return {} 216 | 217 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 218 | PathManager.mkdirs(self._output_dir) 219 | 220 | file_path = os.path.join(self._output_dir, "text_results.json") 221 | self._logger.info("Saving results to {}".format(file_path)) 222 | with PathManager.open(file_path, "w") as f: 223 | f.write(json.dumps(coco_results)) 224 | f.flush() 225 | 226 | self._results = OrderedDict() 227 | 228 | if not self._text_eval_gt_path: 229 | return copy.deepcopy(self._results) 230 | # eval text 231 | temp_dir = "temp_det_results/" 232 | self.to_eval_format(file_path, temp_dir, self._text_eval_confidence) 233 | result_path = self.sort_detection(temp_dir) 234 | text_result = self.evaluate_with_official_code(result_path, self._text_eval_gt_path) 235 | os.remove(result_path) 236 | 237 | # parse 238 | template = "(\S+): (\S+): (\S+), 
(\S+): (\S+), (\S+): (\S+)" 239 | for task in ("e2e_method", "det_only_method"): 240 | result = text_result[task] 241 | groups = re.match(template, result).groups() 242 | self._results[groups[0]] = {groups[i*2+1]: float(groups[(i+1)*2]) for i in range(3)} 243 | 244 | return copy.deepcopy(self._results) 245 | 246 | 247 | def instances_to_coco_json(self, instances, img_id): 248 | num_instances = len(instances) 249 | if num_instances == 0: 250 | return [] 251 | 252 | scores = instances.scores.tolist() 253 | if self.use_polygon: 254 | pnts = instances.polygons.numpy() 255 | else: 256 | pnts = instances.beziers.numpy() 257 | recs = instances.recs.numpy() 258 | rec_scores = instances.rec_scores.numpy() 259 | 260 | results = [] 261 | for pnt, rec, score, rec_score in zip(pnts, recs, scores, rec_scores): 262 | # convert beziers to polygons 263 | poly = self.pnt_to_polygon(pnt) 264 | s = self.decode(rec) 265 | word = self._lexicon_matcher.find_match_word(s, img_id=str(img_id), scores=rec_score) 266 | if word is None: 267 | continue 268 | result = { 269 | "image_id": img_id, 270 | "category_id": 1, 271 | "polys": poly, 272 | "rec": word, 273 | "score": score, 274 | } 275 | results.append(result) 276 | return results 277 | 278 | 279 | def pnt_to_polygon(self, ctrl_pnt): 280 | if self.use_polygon: 281 | return ctrl_pnt.reshape(-1, 2).tolist() 282 | else: 283 | u = np.linspace(0, 1, 20) 284 | ctrl_pnt = ctrl_pnt.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4) 285 | points = np.outer((1 - u) ** 3, ctrl_pnt[:, 0]) \ 286 | + np.outer(3 * u * ((1 - u) ** 2), ctrl_pnt[:, 1]) \ 287 | + np.outer(3 * (u ** 2) * (1 - u), ctrl_pnt[:, 2]) \ 288 | + np.outer(u ** 3, ctrl_pnt[:, 3]) 289 | 290 | # convert points to polygon 291 | points = np.concatenate((points[:, :2], points[:, 2:]), axis=0) 292 | return points.tolist() 293 | 294 | def ctc_decode(self, rec): 295 | # ctc decoding 296 | last_char = False 297 | s = '' 298 | for c in rec: 299 | c = int(c) 300 | if c < self.voc_size - 1: 301 | if last_char != c: 302 | if self.voc_size == 96: 303 | s += self.CTLABELS[c] 304 | last_char = c 305 | else: 306 | s += str(chr(self.CTLABELS[c])) 307 | last_char = c 308 | elif c == self.voc_size -1: 309 | s += u'口' 310 | else: 311 | last_char = False 312 | return s 313 | 314 | 315 | def decode(self, rec): 316 | s = '' 317 | for c in rec: 318 | c = int(c) 319 | if c < self.voc_size - 1: 320 | if self.voc_size == 96: 321 | s += self.CTLABELS[c] 322 | else: 323 | s += str(chr(self.CTLABELS[c])) 324 | elif c == self.voc_size -1: 325 | s += NULL_CHAR 326 | 327 | return s 328 | 329 | -------------------------------------------------------------------------------- /adet/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ms_deform_attn import MSDeformAttn 2 | 3 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = 
sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = 
im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /adet/layers/csrc/cuda_version.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace adet { 4 | int get_cudart_version() { 5 | return CUDART_VERSION; 6 | } 7 | } // namespace adet 8 | -------------------------------------------------------------------------------- /adet/layers/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | #include "DeformAttn/ms_deform_attn.h" 3 | 4 | namespace adet { 5 | 6 | #ifdef WITH_CUDA 7 | extern int get_cudart_version(); 8 | #endif 9 | 10 | std::string get_cuda_version() { 11 | #ifdef WITH_CUDA 12 | std::ostringstream oss; 13 | 14 | // copied from 15 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 16 | auto printCudaStyleVersion = [&](int v) { 17 | oss << (v / 1000) << "." << (v / 10 % 100); 18 | if (v % 10 != 0) { 19 | oss << "." << (v % 10); 20 | } 21 | }; 22 | printCudaStyleVersion(get_cudart_version()); 23 | return oss.str(); 24 | #else 25 | return std::string("not available"); 26 | #endif 27 | } 28 | 29 | // similar to 30 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp 31 | std::string get_compiler_version() { 32 | std::ostringstream ss; 33 | #if defined(__GNUC__) 34 | #ifndef __clang__ 35 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } 36 | #endif 37 | #endif 38 | 39 | #if defined(__clang_major__) 40 | { 41 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." 42 | << __clang_patchlevel__; 43 | } 44 | #endif 45 | 46 | #if defined(_MSC_VER) 47 | { ss << "MSVC " << _MSC_FULL_VER; } 48 | #endif 49 | return ss.str(); 50 | } 51 | 52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 53 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 54 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 55 | } 56 | 57 | } // namespace adet 58 | -------------------------------------------------------------------------------- /adet/layers/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | import warnings 9 | import math 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from torch.nn.init import xavier_uniform_, constant_ 15 | from torch.autograd.function import once_differentiable 16 | 17 | from adet import _C 18 | 19 | class _MSDeformAttnFunction(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 22 | ctx.im2col_step = im2col_step 23 | output = _C.ms_deform_attn_forward( 24 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 25 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 26 | return output 27 | 28 | @staticmethod 29 | @once_differentiable 30 | def backward(ctx, grad_output): 31 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 32 | grad_value, grad_sampling_loc, grad_attn_weight = \ 33 | _C.ms_deform_attn_backward( 34 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 35 | 36 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 37 | 38 | 39 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 40 | # for debug and test only, 41 | # need to use cuda version instead 42 | N_, S_, M_, D_ = value.shape 43 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 44 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 45 | sampling_grids = 2 * sampling_locations - 1 46 | sampling_value_list = [] 47 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 48 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 49 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 50 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 51 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 52 | # N_*M_, D_, Lq_, P_ 53 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 54 | mode='bilinear', padding_mode='zeros', align_corners=False) 55 | sampling_value_list.append(sampling_value_l_) 56 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 57 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 58 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 59 | return output.transpose(1, 2).contiguous() 60 | 61 | 62 | def _is_power_of_2(n): 63 | if (not isinstance(n, int)) or (n < 0): 64 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 65 | return (n & (n-1) == 0) and n != 0 66 | 67 | 68 | class MSDeformAttn(nn.Module): 69 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 70 | """ 71 | Multi-Scale Deformable Attention Module 72 | :param d_model hidden dimension 73 | :param n_levels number of feature levels 74 | 
:param n_heads number of attention heads 75 | :param n_points number of sampling points per attention head per feature level 76 | """ 77 | super().__init__() 78 | if d_model % n_heads != 0: 79 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 80 | _d_per_head = d_model // n_heads 81 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 82 | if not _is_power_of_2(_d_per_head): 83 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 84 | "which is more efficient in our CUDA implementation.") 85 | 86 | self.im2col_step = 64 87 | 88 | self.d_model = d_model 89 | self.n_levels = n_levels 90 | self.n_heads = n_heads 91 | self.n_points = n_points 92 | 93 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 94 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 95 | self.value_proj = nn.Linear(d_model, d_model) 96 | self.output_proj = nn.Linear(d_model, d_model) 97 | 98 | self._reset_parameters() 99 | 100 | def _reset_parameters(self): 101 | constant_(self.sampling_offsets.weight.data, 0.) 102 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 103 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 104 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 105 | for i in range(self.n_points): 106 | grid_init[:, :, i, :] *= i + 1 107 | with torch.no_grad(): 108 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 109 | constant_(self.attention_weights.weight.data, 0.) 110 | constant_(self.attention_weights.bias.data, 0.) 111 | xavier_uniform_(self.value_proj.weight.data) 112 | constant_(self.value_proj.bias.data, 0.) 113 | xavier_uniform_(self.output_proj.weight.data) 114 | constant_(self.output_proj.bias.data, 0.) 
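# How the deformable-attention parameters start out: sampling_offsets has zero weights and a structured bias, so before any training head k samples along the direction 2*pi*k/n_heads (scaled so its dominant coordinate is +-1), and the i-th of the n_points offsets is multiplied by (i + 1), placing the initial sampling locations on rings of increasing radius around each reference point. attention_weights is zeroed, so the softmax in forward() initially spreads weight uniformly over all n_levels * n_points sampled values.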
115 | 116 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 117 | """ 118 | :param query (N, Length_{query}, C) 119 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 120 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 121 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 122 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 123 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 124 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 125 | 126 | :return output (N, Length_{query}, C) 127 | """ 128 | N, Len_q, _ = query.shape 129 | N, Len_in, _ = input_flatten.shape 130 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 131 | 132 | value = self.value_proj(input_flatten) 133 | if input_padding_mask is not None: 134 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 135 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 136 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 137 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 138 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 139 | # N, Len_q, n_heads, n_levels, n_points, 2 140 | if reference_points.shape[-1] == 2: 141 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 142 | sampling_locations = reference_points[:, :, None, :, None, :] \ 143 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 144 | elif reference_points.shape[-1] == 4: 145 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 146 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 147 | else: 148 | raise ValueError( 149 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 150 | output = _MSDeformAttnFunction.apply( 151 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 152 | output = self.output_proj(output) 153 | return output 154 | -------------------------------------------------------------------------------- /adet/layers/pos_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | class PositionalEncoding1D(nn.Module): 6 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 7 | """ 8 | :param channels: The last dimension of the tensor you want to apply pos emb to. 9 | """ 10 | super().__init__() 11 | self.channels = num_pos_feats 12 | dim_t = torch.arange(0, self.channels, 2).float() 13 | if scale is not None and normalize is False: 14 | raise ValueError("normalize should be True if scale is passed") 15 | if scale is None: 16 | scale = 2 * np.pi 17 | self.scale = scale 18 | self.normalize = normalize 19 | inv_freq = 1. 
/ (temperature ** (dim_t / self.channels)) 20 | self.register_buffer('inv_freq', inv_freq) 21 | 22 | def forward(self, tensor): 23 | """ 24 | :param tensor: A 2d tensor of size (len, c) 25 | :return: Positional Encoding Matrix of size (len, c) 26 | """ 27 | if tensor.ndim != 2: 28 | raise RuntimeError("The input tensor has to be 2D!") 29 | x, orig_ch = tensor.shape 30 | pos_x = torch.arange( 31 | 1, x + 1, device=tensor.device).type(self.inv_freq.type()) 32 | 33 | if self.normalize: 34 | eps = 1e-6 35 | pos_x = pos_x / (pos_x[-1:] + eps) * self.scale 36 | 37 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 38 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 39 | emb = torch.zeros((x, self.channels), 40 | device=tensor.device).type(tensor.type()) 41 | emb[:, :self.channels] = emb_x 42 | 43 | return emb[:, :orig_ch] 44 | 45 | 46 | class PositionalEncoding2D(nn.Module): 47 | """ 48 | This is a more standard version of the position embedding, very similar to the one 49 | used by the Attention is all you need paper, generalized to work on images. 50 | """ 51 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 52 | super().__init__() 53 | self.num_pos_feats = num_pos_feats 54 | self.temperature = temperature 55 | self.normalize = normalize 56 | if scale is not None and normalize is False: 57 | raise ValueError("normalize should be True if scale is passed") 58 | if scale is None: 59 | scale = 2 * np.pi 60 | self.scale = scale 61 | 62 | def forward(self, tensors): 63 | x = tensors.tensors 64 | mask = tensors.mask 65 | assert mask is not None 66 | not_mask = ~mask 67 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 68 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 69 | if self.normalize: 70 | eps = 1e-6 71 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 72 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 73 | 74 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 75 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) 76 | 77 | pos_x = x_embed[:, :, :, None] / dim_t 78 | pos_y = y_embed[:, :, :, None] / dim_t 79 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 80 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 81 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 82 | return pos 83 | -------------------------------------------------------------------------------- /adet/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from .transformer_detector import TransformerDetector 3 | 4 | _EXCLUDE = {"torch", "ShapeSpec"} 5 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /adet/modeling/testr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlpc-ucsd/TESTR/f369f138e041d1d27348a1f6600e456452001d23/adet/modeling/testr/__init__.py -------------------------------------------------------------------------------- /adet/modeling/testr/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | from adet.utils.misc import accuracy, generalized_box_iou, box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, is_dist_avail_and_initialized 6 | from detectron2.utils.comm import get_world_size 7 | 8 | 9 | def sigmoid_focal_loss(inputs, targets, num_inst, alpha: float = 0.25, gamma: float = 2): 10 | """ 11 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 12 | Args: 13 | inputs: A float tensor of arbitrary shape. 14 | The predictions for each example. 15 | targets: A float tensor with the same shape as inputs. Stores the binary 16 | classification label for each element in inputs 17 | (0 for the negative class and 1 for the positive class). 18 | alpha: (optional) Weighting factor in range (0,1) to balance 19 | positive vs negative examples. Default = -1 (no weighting). 20 | gamma: Exponent of the modulating factor (1 - p_t) to 21 | balance easy vs hard examples. 22 | Returns: 23 | Loss tensor 24 | """ 25 | prob = inputs.sigmoid() 26 | ce_loss = F.binary_cross_entropy_with_logits( 27 | inputs, targets, reduction="none") 28 | p_t = prob * targets + (1 - prob) * (1 - targets) 29 | loss = ce_loss * ((1 - p_t) ** gamma) 30 | 31 | if alpha >= 0: 32 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 33 | loss = alpha_t * loss 34 | 35 | if loss.ndim == 4: 36 | return loss.mean((1, 2)).sum() / num_inst 37 | elif loss.ndim == 3: 38 | return loss.mean(1).sum() / num_inst 39 | else: 40 | raise NotImplementedError(f"Unsupported dim {loss.ndim}") 41 | 42 | 43 | class SetCriterion(nn.Module): 44 | """ This class computes the loss for TESTR. 45 | The process happens in two steps: 46 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 47 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 48 | """ 49 | 50 | def __init__(self, num_classes, enc_matcher, dec_matcher, weight_dict, enc_losses, dec_losses, num_ctrl_points, focal_alpha=0.25, focal_gamma=2.0): 51 | """ Create the criterion. 52 | Parameters: 53 | num_classes: number of object categories, omitting the special no-object category 54 | matcher: module able to compute a matching between targets and proposals 55 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 56 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
57 | focal_alpha: alpha in Focal Loss 58 | """ 59 | super().__init__() 60 | self.num_classes = num_classes 61 | self.enc_matcher = enc_matcher 62 | self.dec_matcher = dec_matcher 63 | self.weight_dict = weight_dict 64 | self.enc_losses = enc_losses 65 | self.dec_losses = dec_losses 66 | self.focal_alpha = focal_alpha 67 | self.focal_gamma = focal_gamma 68 | self.num_ctrl_points = num_ctrl_points 69 | 70 | def loss_labels(self, outputs, targets, indices, num_inst, log=False): 71 | """Classification loss (NLL) 72 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 73 | """ 74 | assert 'pred_logits' in outputs 75 | src_logits = outputs['pred_logits'] 76 | 77 | idx = self._get_src_permutation_idx(indices) 78 | 79 | target_classes = torch.full(src_logits.shape[:-1], self.num_classes, 80 | dtype=torch.int64, device=src_logits.device) 81 | target_classes_o = torch.cat([t["labels"][J] 82 | for t, (_, J) in zip(targets, indices)]) 83 | if len(target_classes_o.shape) < len(target_classes[idx].shape): 84 | target_classes_o = target_classes_o[..., None] 85 | target_classes[idx] = target_classes_o 86 | 87 | shape = list(src_logits.shape) 88 | shape[-1] += 1 89 | target_classes_onehot = torch.zeros(shape, 90 | dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) 91 | target_classes_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1) 92 | target_classes_onehot = target_classes_onehot[..., :-1] 93 | loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_inst, 94 | alpha=self.focal_alpha, gamma=self.focal_gamma) * src_logits.shape[1] 95 | losses = {'loss_ce': loss_ce} 96 | 97 | if log: 98 | # TODO this should probably be a separate loss, not hacked in this one here 99 | losses['class_error'] = 100 - \ 100 | accuracy(src_logits[idx], target_classes_o)[0] 101 | return losses 102 | 103 | @torch.no_grad() 104 | def loss_cardinality(self, outputs, targets, indices, num_inst): 105 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 106 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 107 | """ 108 | pred_logits = outputs['pred_logits'] 109 | device = pred_logits.device 110 | tgt_lengths = torch.as_tensor( 111 | [len(v["labels"]) for v in targets], device=device) 112 | card_pred = (pred_logits.mean(-2).argmax(-1) == 0).sum(1) 113 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 114 | losses = {'cardinality_error': card_err} 115 | return losses 116 | 117 | def loss_boxes(self, outputs, targets, indices, num_inst): 118 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 119 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 120 | The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. 
121 | """ 122 | assert 'pred_boxes' in outputs 123 | idx = self._get_src_permutation_idx(indices) 124 | src_boxes = outputs['pred_boxes'][idx] 125 | target_boxes = torch.cat([t['boxes'][i] 126 | for t, (_, i) in zip(targets, indices)], dim=0) 127 | 128 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 129 | 130 | losses = {} 131 | losses['loss_bbox'] = loss_bbox.sum() / num_inst 132 | 133 | loss_giou = 1 - torch.diag(generalized_box_iou( 134 | box_cxcywh_to_xyxy(src_boxes), 135 | box_cxcywh_to_xyxy(target_boxes))) 136 | losses['loss_giou'] = loss_giou.sum() / num_inst 137 | return losses 138 | 139 | def loss_texts(self, outputs, targets, indices, num_inst): 140 | assert 'pred_texts' in outputs 141 | idx = self._get_src_permutation_idx(indices) 142 | src_texts = outputs['pred_texts'][idx] 143 | target_ctrl_points = torch.cat([t['texts'][i] for t, (_, i) in zip(targets, indices)], dim=0) 144 | return {'loss_texts': F.cross_entropy(src_texts.transpose(1, 2), target_ctrl_points.long())} 145 | 146 | 147 | def loss_ctrl_points(self, outputs, targets, indices, num_inst): 148 | """Compute the losses related to the keypoint coordinates, the L1 regression loss 149 | """ 150 | assert 'pred_ctrl_points' in outputs 151 | idx = self._get_src_permutation_idx(indices) 152 | src_ctrl_points = outputs['pred_ctrl_points'][idx] 153 | target_ctrl_points = torch.cat([t['ctrl_points'][i] for t, (_, i) in zip(targets, indices)], dim=0) 154 | 155 | loss_ctrl_points = F.l1_loss(src_ctrl_points, target_ctrl_points, reduction='sum') 156 | 157 | losses = {'loss_ctrl_points': loss_ctrl_points / num_inst} 158 | return losses 159 | 160 | @staticmethod 161 | def _get_src_permutation_idx(indices): 162 | # permute predictions following indices 163 | batch_idx = torch.cat([torch.full_like(src, i) 164 | for i, (src, _) in enumerate(indices)]) 165 | src_idx = torch.cat([src for (src, _) in indices]) 166 | return batch_idx, src_idx 167 | 168 | @staticmethod 169 | def _get_tgt_permutation_idx(indices): 170 | # permute targets following indices 171 | batch_idx = torch.cat([torch.full_like(tgt, i) 172 | for i, (_, tgt) in enumerate(indices)]) 173 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 174 | return batch_idx, tgt_idx 175 | 176 | def get_loss(self, loss, outputs, targets, indices, num_inst, **kwargs): 177 | loss_map = { 178 | 'labels': self.loss_labels, 179 | 'cardinality': self.loss_cardinality, 180 | 'ctrl_points': self.loss_ctrl_points, 181 | 'boxes': self.loss_boxes, 182 | 'texts': self.loss_texts, 183 | } 184 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 185 | return loss_map[loss](outputs, targets, indices, num_inst, **kwargs) 186 | 187 | def forward(self, outputs, targets): 188 | """ This performs the loss computation. 189 | Parameters: 190 | outputs: dict of tensors, see the output specification of the model for the format 191 | targets: list of dicts, such that len(targets) == batch_size. 
192 | The expected keys in each dict depends on the losses applied, see each loss' doc 193 | """ 194 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'} 195 | 196 | # Retrieve the matching between the outputs of the last layer and the targets 197 | indices = self.dec_matcher(outputs_without_aux, targets) 198 | 199 | # Compute the average number of target boxes accross all nodes, for normalization purposes 200 | num_inst = sum(len(t['ctrl_points']) for t in targets) 201 | num_inst = torch.as_tensor( 202 | [num_inst], dtype=torch.float, device=next(iter(outputs.values())).device) 203 | if is_dist_avail_and_initialized(): 204 | torch.distributed.all_reduce(num_inst) 205 | num_inst = torch.clamp(num_inst / get_world_size(), min=1).item() 206 | 207 | # Compute all the requested losses 208 | losses = {} 209 | for loss in self.dec_losses: 210 | kwargs = {} 211 | losses.update(self.get_loss(loss, outputs, targets, 212 | indices, num_inst, **kwargs)) 213 | 214 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 215 | if 'aux_outputs' in outputs: 216 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 217 | indices = self.dec_matcher(aux_outputs, targets) 218 | for loss in self.dec_losses: 219 | kwargs = {} 220 | if loss == 'labels': 221 | # Logging is enabled only for the last layer 222 | kwargs['log'] = False 223 | l_dict = self.get_loss( 224 | loss, aux_outputs, targets, indices, num_inst, **kwargs) 225 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 226 | losses.update(l_dict) 227 | 228 | if 'enc_outputs' in outputs: 229 | enc_outputs = outputs['enc_outputs'] 230 | indices = self.enc_matcher(enc_outputs, targets) 231 | for loss in self.enc_losses: 232 | kwargs = {} 233 | if loss == 'labels': 234 | kwargs['log'] = False 235 | l_dict = self.get_loss( 236 | loss, enc_outputs, targets, indices, num_inst, **kwargs) 237 | l_dict = {k + f'_enc': v for k, v in l_dict.items()} 238 | losses.update(l_dict) 239 | 240 | return losses -------------------------------------------------------------------------------- /adet/modeling/testr/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modules to compute the matching cost and solve the corresponding LSAP. 3 | """ 4 | import torch 5 | from scipy.optimize import linear_sum_assignment 6 | from torch import nn 7 | from adet.utils.misc import box_cxcywh_to_xyxy, generalized_box_iou 8 | 9 | 10 | class CtrlPointHungarianMatcher(nn.Module): 11 | """This class computes an assignment between the targets and the predictions of the network 12 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 13 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 14 | while the others are un-matched (and thus treated as non-objects). 
15 | """ 16 | 17 | def __init__(self, 18 | class_weight: float = 1, 19 | coord_weight: float = 1, 20 | focal_alpha: float = 0.25, 21 | focal_gamma: float = 2.0): 22 | """Creates the matcher 23 | Params: 24 | class_weight: This is the relative weight of the classification error in the matching cost 25 | coord_weight: This is the relative weight of the L1 error of the keypoint coordinates in the matching cost 26 | """ 27 | super().__init__() 28 | self.class_weight = class_weight 29 | self.coord_weight = coord_weight 30 | self.alpha = focal_alpha 31 | self.gamma = focal_gamma 32 | assert class_weight != 0 or coord_weight != 0, "all costs cant be 0" 33 | 34 | def forward(self, outputs, targets): 35 | """ Performs the matching 36 | Params: 37 | outputs: This is a dict that contains at least these entries: 38 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 39 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 40 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 41 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 42 | objects in the target) containing the class labels 43 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 44 | Returns: 45 | A list of size batch_size, containing tuples of (index_i, index_j) where: 46 | - index_i is the indices of the selected predictions (in order) 47 | - index_j is the indices of the corresponding selected targets (in order) 48 | For each batch element, it holds: 49 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 50 | """ 51 | with torch.no_grad(): 52 | bs, num_queries = outputs["pred_logits"].shape[:2] 53 | 54 | # We flatten to compute the cost matrices in a batch 55 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 56 | # [batch_size, n_queries, n_points, 2] --> [batch_size * num_queries, n_points * 2] 57 | out_pts = outputs["pred_ctrl_points"].flatten(0, 1).flatten(-2) 58 | 59 | # Also concat the target labels and boxes 60 | tgt_pts = torch.cat([v["ctrl_points"] for v in targets]).flatten(-2) 61 | neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \ 62 | (-(1 - out_prob + 1e-8).log()) 63 | pos_cost_class = self.alpha * \ 64 | ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log()) 65 | # FIXME: hack here for label ID 0 66 | cost_class = (pos_cost_class[..., 0] - neg_cost_class[..., 0]).mean(-1, keepdims=True) 67 | 68 | cost_kpts = torch.cdist(out_pts, tgt_pts, p=1) 69 | 70 | C = self.class_weight * cost_class + self.coord_weight * cost_kpts 71 | C = C.view(bs, num_queries, -1).cpu() 72 | 73 | sizes = [len(v["ctrl_points"]) for v in targets] 74 | indices = [linear_sum_assignment( 75 | c[i]) for i, c in enumerate(C.split(sizes, -1))] 76 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 77 | 78 | 79 | class BoxHungarianMatcher(nn.Module): 80 | """This class computes an assignment between the targets and the predictions of the network 81 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 82 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 83 | while the others are un-matched (and thus treated as non-objects). 
84 | """ 85 | 86 | def __init__(self, 87 | class_weight: float = 1, 88 | coord_weight: float = 1, 89 | giou_weight: float = 1, 90 | focal_alpha: float = 0.25, 91 | focal_gamma: float = 2.0): 92 | """Creates the matcher 93 | Params: 94 | class_weight: This is the relative weight of the classification error in the matching cost 95 | coord_weight: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 96 | giou_weight: This is the relative weight of the giou loss of the bounding box in the matching cost 97 | """ 98 | super().__init__() 99 | self.class_weight = class_weight 100 | self.coord_weight = coord_weight 101 | self.giou_weight = giou_weight 102 | self.alpha = focal_alpha 103 | self.gamma = focal_gamma 104 | assert class_weight != 0 or coord_weight != 0 or giou_weight != 0, "all costs cant be 0" 105 | 106 | def forward(self, outputs, targets): 107 | """ Performs the matching 108 | Params: 109 | outputs: This is a dict that contains at least these entries: 110 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 111 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 112 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 113 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 114 | objects in the target) containing the class labels 115 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 116 | Returns: 117 | A list of size batch_size, containing tuples of (index_i, index_j) where: 118 | - index_i is the indices of the selected predictions (in order) 119 | - index_j is the indices of the corresponding selected targets (in order) 120 | For each batch element, it holds: 121 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 122 | """ 123 | with torch.no_grad(): 124 | bs, num_queries = outputs["pred_logits"].shape[:2] 125 | 126 | # We flatten to compute the cost matrices in a batch 127 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 128 | out_bbox = outputs["pred_boxes"].flatten( 129 | 0, 1) # [batch_size * num_queries, 4] 130 | 131 | # Also concat the target labels and boxes 132 | tgt_ids = torch.cat([v["labels"] for v in targets]) 133 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 134 | 135 | # Compute the classification cost. 
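            # A focal-style matching cost (Deformable-DETR-style matchers use the same
            # form): for a query with predicted probability p for a target class,
            #   pos_cost = alpha * (1 - p)**gamma * (-log(p))        # cost of matching the query
            #   neg_cost = (1 - alpha) * p**gamma * (-log(1 - p))    # cost of leaving it unmatched
            # and each cost-matrix entry is pos_cost - neg_cost, gathered at the
            # ground-truth class ids (tgt_ids) below.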
136 | neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \ 137 | (-(1 - out_prob + 1e-8).log()) 138 | pos_cost_class = self.alpha * \ 139 | ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log()) 140 | cost_class = pos_cost_class[:, tgt_ids] - \ 141 | neg_cost_class[:, tgt_ids] 142 | 143 | # Compute the L1 cost between boxes 144 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 145 | 146 | # Compute the giou cost betwen boxes 147 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), 148 | box_cxcywh_to_xyxy(tgt_bbox)) 149 | 150 | # Final cost matrix 151 | C = self.coord_weight * cost_bbox + self.class_weight * \ 152 | cost_class + self.giou_weight * cost_giou 153 | C = C.view(bs, num_queries, -1).cpu() 154 | 155 | sizes = [len(v["boxes"]) for v in targets] 156 | indices = [linear_sum_assignment( 157 | c[i]) for i, c in enumerate(C.split(sizes, -1))] 158 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 159 | 160 | 161 | def build_matcher(cfg): 162 | cfg = cfg.MODEL.TRANSFORMER.LOSS 163 | return BoxHungarianMatcher(class_weight=cfg.BOX_CLASS_WEIGHT, 164 | coord_weight=cfg.BOX_COORD_WEIGHT, 165 | giou_weight=cfg.BOX_GIOU_WEIGHT, 166 | focal_alpha=cfg.FOCAL_ALPHA, 167 | focal_gamma=cfg.FOCAL_GAMMA), \ 168 | CtrlPointHungarianMatcher(class_weight=cfg.POINT_CLASS_WEIGHT, 169 | coord_weight=cfg.POINT_COORD_WEIGHT, 170 | focal_alpha=cfg.FOCAL_ALPHA, 171 | focal_gamma=cfg.FOCAL_GAMMA) -------------------------------------------------------------------------------- /adet/modeling/testr/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from adet.layers.deformable_transformer import DeformableTransformer 7 | 8 | from adet.layers.pos_encoding import PositionalEncoding1D 9 | from adet.utils.misc import NestedTensor, inverse_sigmoid_offset, nested_tensor_from_tensor_list, sigmoid_offset 10 | 11 | class MLP(nn.Module): 12 | """ Very simple multi-layer perceptron (also called FFN)""" 13 | 14 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 15 | super().__init__() 16 | self.num_layers = num_layers 17 | h = [hidden_dim] * (num_layers - 1) 18 | self.layers = nn.ModuleList(nn.Linear(n, k) 19 | for n, k in zip([input_dim] + h, h + [output_dim])) 20 | 21 | def forward(self, x): 22 | for i, layer in enumerate(self.layers): 23 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 24 | return x 25 | 26 | class TESTR(nn.Module): 27 | """ 28 | Same as :class:`detectron2.modeling.ProposalNetwork`. 29 | Use one stage detector and a second stage for instance-wise prediction. 
30 | """ 31 | def __init__(self, cfg, backbone): 32 | super().__init__() 33 | self.device = torch.device(cfg.MODEL.DEVICE) 34 | 35 | self.backbone = backbone 36 | 37 | # fmt: off 38 | self.d_model = cfg.MODEL.TRANSFORMER.HIDDEN_DIM 39 | self.nhead = cfg.MODEL.TRANSFORMER.NHEADS 40 | self.num_encoder_layers = cfg.MODEL.TRANSFORMER.ENC_LAYERS 41 | self.num_decoder_layers = cfg.MODEL.TRANSFORMER.DEC_LAYERS 42 | self.dim_feedforward = cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD 43 | self.dropout = cfg.MODEL.TRANSFORMER.DROPOUT 44 | self.activation = "relu" 45 | self.return_intermediate_dec = True 46 | self.num_feature_levels = cfg.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS 47 | self.dec_n_points = cfg.MODEL.TRANSFORMER.ENC_N_POINTS 48 | self.enc_n_points = cfg.MODEL.TRANSFORMER.DEC_N_POINTS 49 | self.num_proposals = cfg.MODEL.TRANSFORMER.NUM_QUERIES 50 | self.pos_embed_scale = cfg.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE 51 | self.num_ctrl_points = cfg.MODEL.TRANSFORMER.NUM_CTRL_POINTS 52 | self.num_classes = 1 53 | self.max_text_len = cfg.MODEL.TRANSFORMER.NUM_CHARS 54 | self.voc_size = cfg.MODEL.TRANSFORMER.VOC_SIZE 55 | self.sigmoid_offset = not cfg.MODEL.TRANSFORMER.USE_POLYGON 56 | 57 | self.text_pos_embed = PositionalEncoding1D(self.d_model, normalize=True, scale=self.pos_embed_scale) 58 | # fmt: on 59 | 60 | self.transformer = DeformableTransformer( 61 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.num_encoder_layers, 62 | num_decoder_layers=self.num_decoder_layers, dim_feedforward=self.dim_feedforward, 63 | dropout=self.dropout, activation=self.activation, return_intermediate_dec=self.return_intermediate_dec, 64 | num_feature_levels=self.num_feature_levels, dec_n_points=self.dec_n_points, 65 | enc_n_points=self.enc_n_points, num_proposals=self.num_proposals, 66 | ) 67 | self.ctrl_point_class = nn.Linear(self.d_model, self.num_classes) 68 | self.ctrl_point_coord = MLP(self.d_model, self.d_model, 2, 3) 69 | self.bbox_coord = MLP(self.d_model, self.d_model, 4, 3) 70 | self.bbox_class = nn.Linear(self.d_model, self.num_classes) 71 | self.text_class = nn.Linear(self.d_model, self.voc_size + 1) 72 | 73 | # shared prior between instances (objects) 74 | self.ctrl_point_embed = nn.Embedding(self.num_ctrl_points, self.d_model) 75 | self.text_embed = nn.Embedding(self.max_text_len, self.d_model) 76 | 77 | 78 | if self.num_feature_levels > 1: 79 | strides = [8, 16, 32] 80 | num_channels = [512, 1024, 2048] 81 | num_backbone_outs = len(strides) 82 | input_proj_list = [] 83 | for _ in range(num_backbone_outs): 84 | in_channels = num_channels[_] 85 | input_proj_list.append(nn.Sequential( 86 | nn.Conv2d(in_channels, self.d_model, kernel_size=1), 87 | nn.GroupNorm(32, self.d_model), 88 | )) 89 | for _ in range(self.num_feature_levels - num_backbone_outs): 90 | input_proj_list.append(nn.Sequential( 91 | nn.Conv2d(in_channels, self.d_model, 92 | kernel_size=3, stride=2, padding=1), 93 | nn.GroupNorm(32, self.d_model), 94 | )) 95 | in_channels = self.d_model 96 | self.input_proj = nn.ModuleList(input_proj_list) 97 | else: 98 | strides = [32] 99 | num_channels = [2048] 100 | self.input_proj = nn.ModuleList([ 101 | nn.Sequential( 102 | nn.Conv2d( 103 | num_channels[0], self.d_model, kernel_size=1), 104 | nn.GroupNorm(32, self.d_model), 105 | )]) 106 | self.aux_loss = cfg.MODEL.TRANSFORMER.AUX_LOSS 107 | 108 | prior_prob = 0.01 109 | bias_value = -np.log((1 - prior_prob) / prior_prob) 110 | self.ctrl_point_class.bias.data = torch.ones(self.num_classes) * bias_value 111 | self.bbox_class.bias.data = 
torch.ones(self.num_classes) * bias_value 112 | nn.init.constant_(self.ctrl_point_coord.layers[-1].weight.data, 0) 113 | nn.init.constant_(self.ctrl_point_coord.layers[-1].bias.data, 0) 114 | for proj in self.input_proj: 115 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 116 | nn.init.constant_(proj[0].bias, 0) 117 | 118 | num_pred = self.num_decoder_layers 119 | self.ctrl_point_class = nn.ModuleList( 120 | [self.ctrl_point_class for _ in range(num_pred)]) 121 | self.ctrl_point_coord = nn.ModuleList( 122 | [self.ctrl_point_coord for _ in range(num_pred)]) 123 | self.transformer.decoder.bbox_embed = None 124 | 125 | nn.init.constant_(self.bbox_coord.layers[-1].bias.data[2:], 0.0) 126 | self.transformer.bbox_class_embed = self.bbox_class 127 | self.transformer.bbox_embed = self.bbox_coord 128 | 129 | self.to(self.device) 130 | 131 | 132 | def forward(self, samples: NestedTensor): 133 | """ The forward expects a NestedTensor, which consists of: 134 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 135 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 136 | It returns a dict with the following elements: 137 | - "pred_logits": the classification logits (including no-object) for all queries. 138 | Shape= [batch_size x num_queries x (num_classes + 1)] 139 | - "pred_keypoints": The normalized keypoint coordinates for all queries, represented as 140 | (x, y). These values are normalized in [0, 1], 141 | relative to the size of each individual image (disregarding possible padding). 142 | See PostProcess for information on how to retrieve the unnormalized bounding box. 143 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 144 | dictionnaries containing the two above keys for each decoder layer. 
145 | """ 146 | if isinstance(samples, (list, torch.Tensor)): 147 | samples = nested_tensor_from_tensor_list(samples) 148 | features, pos = self.backbone(samples) 149 | 150 | if self.num_feature_levels == 1: 151 | features = [features[-1]] 152 | pos = [pos[-1]] 153 | 154 | srcs = [] 155 | masks = [] 156 | for l, feat in enumerate(features): 157 | src, mask = feat.decompose() 158 | srcs.append(self.input_proj[l](src)) 159 | masks.append(mask) 160 | assert mask is not None 161 | if self.num_feature_levels > len(srcs): 162 | _len_srcs = len(srcs) 163 | for l in range(_len_srcs, self.num_feature_levels): 164 | if l == _len_srcs: 165 | src = self.input_proj[l](features[-1].tensors) 166 | else: 167 | src = self.input_proj[l](srcs[-1]) 168 | m = masks[0] 169 | mask = F.interpolate( 170 | m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 171 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 172 | srcs.append(src) 173 | masks.append(mask) 174 | pos.append(pos_l) 175 | 176 | # n_points, embed_dim --> n_objects, n_points, embed_dim 177 | ctrl_point_embed = self.ctrl_point_embed.weight[None, ...].repeat(self.num_proposals, 1, 1) 178 | text_pos_embed = self.text_pos_embed(self.text_embed.weight)[None, ...].repeat(self.num_proposals, 1, 1) 179 | text_embed = self.text_embed.weight[None, ...].repeat(self.num_proposals, 1, 1) 180 | 181 | hs, hs_text, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer( 182 | srcs, masks, pos, ctrl_point_embed, text_embed, text_pos_embed, text_mask=None) 183 | 184 | outputs_classes = [] 185 | outputs_coords = [] 186 | outputs_texts = [] 187 | for lvl in range(hs.shape[0]): 188 | if lvl == 0: 189 | reference = init_reference 190 | else: 191 | reference = inter_references[lvl - 1] 192 | reference = inverse_sigmoid_offset(reference, offset=self.sigmoid_offset) 193 | outputs_class = self.ctrl_point_class[lvl](hs[lvl]) 194 | tmp = self.ctrl_point_coord[lvl](hs[lvl]) 195 | if reference.shape[-1] == 2: 196 | tmp += reference[:, :, None, :] 197 | else: 198 | assert reference.shape[-1] == 4 199 | tmp += reference[:, :, None, :2] 200 | outputs_texts.append(self.text_class(hs_text[lvl])) 201 | outputs_coord = sigmoid_offset(tmp, offset=self.sigmoid_offset) 202 | outputs_classes.append(outputs_class) 203 | outputs_coords.append(outputs_coord) 204 | outputs_class = torch.stack(outputs_classes) 205 | outputs_coord = torch.stack(outputs_coords) 206 | outputs_text = torch.stack(outputs_texts) 207 | 208 | out = {'pred_logits': outputs_class[-1], 209 | 'pred_ctrl_points': outputs_coord[-1], 210 | 'pred_texts': outputs_text[-1]} 211 | if self.aux_loss: 212 | out['aux_outputs'] = self._set_aux_loss( 213 | outputs_class, outputs_coord, outputs_text) 214 | 215 | enc_outputs_coord = enc_outputs_coord_unact.sigmoid() 216 | out['enc_outputs'] = { 217 | 'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} 218 | return out 219 | 220 | @torch.jit.unused 221 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_text): 222 | # this is a workaround to make torchscript happy, as torchscript 223 | # doesn't support dictionary with non-homogeneous values, such 224 | # as a dict having both a Tensor and a list. 
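        # Entry i of this list holds the predictions of decoder layer i; the final
        # layer is omitted because its outputs are already returned at the top level
        # as 'pred_logits', 'pred_ctrl_points' and 'pred_texts'.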
225 | return [{'pred_logits': a, 'pred_ctrl_points': b, 'pred_texts': c} 226 | for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_text[:-1])] -------------------------------------------------------------------------------- /adet/modeling/transformer_detector.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 8 | from detectron2.modeling import build_backbone 9 | from detectron2.modeling.postprocessing import detector_postprocess as d2_postprocesss 10 | from detectron2.structures import ImageList, Instances 11 | 12 | from adet.layers.pos_encoding import PositionalEncoding2D 13 | from adet.modeling.testr.losses import SetCriterion 14 | from adet.modeling.testr.matcher import build_matcher 15 | from adet.modeling.testr.models import TESTR 16 | from adet.utils.misc import NestedTensor, box_xyxy_to_cxcywh 17 | 18 | 19 | class Joiner(nn.Sequential): 20 | def __init__(self, backbone, position_embedding): 21 | super().__init__(backbone, position_embedding) 22 | 23 | def forward(self, tensor_list: NestedTensor): 24 | xs = self[0](tensor_list) 25 | out: List[NestedTensor] = [] 26 | pos = [] 27 | for _, x in xs.items(): 28 | out.append(x) 29 | # position encoding 30 | pos.append(self[1](x).to(x.tensors.dtype)) 31 | 32 | return out, pos 33 | 34 | class MaskedBackbone(nn.Module): 35 | """ This is a thin wrapper around D2's backbone to provide padding masking""" 36 | def __init__(self, cfg): 37 | super().__init__() 38 | self.backbone = build_backbone(cfg) 39 | backbone_shape = self.backbone.output_shape() 40 | self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] 41 | self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels 42 | 43 | def forward(self, images): 44 | features = self.backbone(images.tensor) 45 | masks = self.mask_out_padding( 46 | [features_per_level.shape for features_per_level in features.values()], 47 | images.image_sizes, 48 | images.tensor.device, 49 | ) 50 | assert len(features) == len(masks) 51 | for i, k in enumerate(features.keys()): 52 | features[k] = NestedTensor(features[k], masks[i]) 53 | return features 54 | 55 | def mask_out_padding(self, feature_shapes, image_sizes, device): 56 | masks = [] 57 | assert len(feature_shapes) == len(self.feature_strides) 58 | for idx, shape in enumerate(feature_shapes): 59 | N, _, H, W = shape 60 | masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) 61 | for img_idx, (h, w) in enumerate(image_sizes): 62 | masks_per_feature_level[ 63 | img_idx, 64 | : int(np.ceil(float(h) / self.feature_strides[idx])), 65 | : int(np.ceil(float(w) / self.feature_strides[idx])), 66 | ] = 0 67 | masks.append(masks_per_feature_level) 68 | return masks 69 | 70 | 71 | def detector_postprocess(results, output_height, output_width, mask_threshold=0.5): 72 | """ 73 | In addition to the post processing of detectron2, we add scalign for 74 | bezier control points. 
75 | """ 76 | scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0]) 77 | # results = d2_postprocesss(results, output_height, output_width, mask_threshold) 78 | 79 | # scale bezier points 80 | if results.has("beziers"): 81 | beziers = results.beziers 82 | # scale and clip in place 83 | h, w = results.image_size 84 | beziers[:, 0].clamp_(min=0, max=w) 85 | beziers[:, 1].clamp_(min=0, max=h) 86 | beziers[:, 6].clamp_(min=0, max=w) 87 | beziers[:, 7].clamp_(min=0, max=h) 88 | beziers[:, 8].clamp_(min=0, max=w) 89 | beziers[:, 9].clamp_(min=0, max=h) 90 | beziers[:, 14].clamp_(min=0, max=w) 91 | beziers[:, 15].clamp_(min=0, max=h) 92 | beziers[:, 0::2] *= scale_x 93 | beziers[:, 1::2] *= scale_y 94 | 95 | if results.has("polygons"): 96 | polygons = results.polygons 97 | polygons[:, 0::2] *= scale_x 98 | polygons[:, 1::2] *= scale_y 99 | 100 | return results 101 | 102 | 103 | @META_ARCH_REGISTRY.register() 104 | class TransformerDetector(nn.Module): 105 | """ 106 | Same as :class:`detectron2.modeling.ProposalNetwork`. 107 | Use one stage detector and a second stage for instance-wise prediction. 108 | """ 109 | def __init__(self, cfg): 110 | super().__init__() 111 | self.device = torch.device(cfg.MODEL.DEVICE) 112 | 113 | d2_backbone = MaskedBackbone(cfg) 114 | N_steps = cfg.MODEL.TRANSFORMER.HIDDEN_DIM // 2 115 | self.test_score_threshold = cfg.MODEL.TRANSFORMER.INFERENCE_TH_TEST 116 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON 117 | backbone = Joiner(d2_backbone, PositionalEncoding2D(N_steps, normalize=True)) 118 | backbone.num_channels = d2_backbone.num_channels 119 | self.testr = TESTR(cfg, backbone) 120 | 121 | box_matcher, point_matcher = build_matcher(cfg) 122 | 123 | loss_cfg = cfg.MODEL.TRANSFORMER.LOSS 124 | weight_dict = {'loss_ce': loss_cfg.POINT_CLASS_WEIGHT, 'loss_ctrl_points': loss_cfg.POINT_COORD_WEIGHT, 'loss_texts': loss_cfg.POINT_TEXT_WEIGHT} 125 | enc_weight_dict = {'loss_bbox': loss_cfg.BOX_COORD_WEIGHT, 'loss_giou': loss_cfg.BOX_GIOU_WEIGHT, 'loss_ce': loss_cfg.BOX_CLASS_WEIGHT} 126 | if loss_cfg.AUX_LOSS: 127 | aux_weight_dict = {} 128 | # decoder aux loss 129 | for i in range(cfg.MODEL.TRANSFORMER.DEC_LAYERS - 1): 130 | aux_weight_dict.update( 131 | {k + f'_{i}': v for k, v in weight_dict.items()}) 132 | # encoder aux loss 133 | aux_weight_dict.update( 134 | {k + f'_enc': v for k, v in enc_weight_dict.items()}) 135 | weight_dict.update(aux_weight_dict) 136 | 137 | enc_losses = ['labels', 'boxes'] 138 | dec_losses = ['labels', 'ctrl_points', 'texts'] 139 | 140 | self.criterion = SetCriterion(self.testr.num_classes, box_matcher, point_matcher, 141 | weight_dict, enc_losses, dec_losses, self.testr.num_ctrl_points, 142 | focal_alpha=loss_cfg.FOCAL_ALPHA, focal_gamma=loss_cfg.FOCAL_GAMMA) 143 | 144 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 145 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 146 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 147 | self.to(self.device) 148 | 149 | def preprocess_image(self, batched_inputs): 150 | """ 151 | Normalize, pad and batch the input images. 152 | """ 153 | images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] 154 | images = ImageList.from_tensors(images) 155 | return images 156 | 157 | def forward(self, batched_inputs): 158 | """ 159 | Args: 160 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 161 | Each item in the list contains the inputs for one image. 
162 | For now, each item in the list is a dict that contains: 163 | 164 | * image: Tensor, image in (C, H, W) format. 165 | * instances (optional): groundtruth :class:`Instances` 166 | * proposals (optional): :class:`Instances`, precomputed proposals. 167 | 168 | Other information that's included in the original dicts, such as: 169 | 170 | * "height", "width" (int): the output resolution of the model, used in inference. 171 | See :meth:`postprocess` for details. 172 | 173 | Returns: 174 | list[dict]: 175 | Each dict is the output for one input image. 176 | The dict contains one key "instances" whose value is a :class:`Instances`. 177 | The :class:`Instances` object has the following keys: 178 | "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" 179 | """ 180 | images = self.preprocess_image(batched_inputs) 181 | output = self.testr(images) 182 | 183 | if self.training: 184 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 185 | targets = self.prepare_targets(gt_instances) 186 | loss_dict = self.criterion(output, targets) 187 | weight_dict = self.criterion.weight_dict 188 | for k in loss_dict.keys(): 189 | if k in weight_dict: 190 | loss_dict[k] *= weight_dict[k] 191 | return loss_dict 192 | else: 193 | ctrl_point_cls = output["pred_logits"] 194 | ctrl_point_coord = output["pred_ctrl_points"] 195 | text_pred = output["pred_texts"] 196 | results = self.inference(ctrl_point_cls, ctrl_point_coord, text_pred, images.image_sizes) 197 | processed_results = [] 198 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): 199 | height = input_per_image.get("height", image_size[0]) 200 | width = input_per_image.get("width", image_size[1]) 201 | r = detector_postprocess(results_per_image, height, width) 202 | processed_results.append({"instances": r}) 203 | return processed_results 204 | 205 | def prepare_targets(self, targets): 206 | new_targets = [] 207 | for targets_per_image in targets: 208 | h, w = targets_per_image.image_size 209 | image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) 210 | gt_classes = targets_per_image.gt_classes 211 | gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy 212 | gt_boxes = box_xyxy_to_cxcywh(gt_boxes) 213 | raw_ctrl_points = targets_per_image.polygons if self.use_polygon else targets_per_image.beziers 214 | gt_ctrl_points = raw_ctrl_points.reshape(-1, self.testr.num_ctrl_points, 2) / torch.as_tensor([w, h], dtype=torch.float, device=self.device)[None, None, :] 215 | gt_text = targets_per_image.text 216 | new_targets.append({"labels": gt_classes, "boxes": gt_boxes, "ctrl_points": gt_ctrl_points, "texts": gt_text}) 217 | return new_targets 218 | 219 | def inference(self, ctrl_point_cls, ctrl_point_coord, text_pred, image_sizes): 220 | assert len(ctrl_point_cls) == len(image_sizes) 221 | results = [] 222 | 223 | text_pred = torch.softmax(text_pred, dim=-1) 224 | prob = ctrl_point_cls.mean(-2).sigmoid() 225 | scores, labels = prob.max(-1) 226 | 227 | for scores_per_image, labels_per_image, ctrl_point_per_image, text_per_image, image_size in zip( 228 | scores, labels, ctrl_point_coord, text_pred, image_sizes 229 | ): 230 | selector = scores_per_image >= self.test_score_threshold 231 | scores_per_image = scores_per_image[selector] 232 | labels_per_image = labels_per_image[selector] 233 | ctrl_point_per_image = ctrl_point_per_image[selector] 234 | text_per_image = text_per_image[selector] 235 | result = Instances(image_size) 236 | 
result.scores = scores_per_image 237 | result.pred_classes = labels_per_image 238 | result.rec_scores = text_per_image 239 | ctrl_point_per_image[..., 0] *= image_size[1] 240 | ctrl_point_per_image[..., 1] *= image_size[0] 241 | if self.use_polygon: 242 | result.polygons = ctrl_point_per_image.flatten(1) 243 | else: 244 | result.beziers = ctrl_point_per_image.flatten(1) 245 | _, topi = text_per_image.topk(1) 246 | result.recs = topi.squeeze(-1) 247 | results.append(result) 248 | return results 249 | -------------------------------------------------------------------------------- /adet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlpc-ucsd/TESTR/f369f138e041d1d27348a1f6600e456452001d23/adet/utils/__init__.py -------------------------------------------------------------------------------- /adet/utils/comm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.distributed as dist 4 | 5 | from detectron2.utils.comm import get_world_size 6 | 7 | 8 | def reduce_sum(tensor): 9 | world_size = get_world_size() 10 | if world_size < 2: 11 | return tensor 12 | tensor = tensor.clone() 13 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 14 | return tensor 15 | 16 | 17 | def reduce_mean(tensor): 18 | num_gpus = get_world_size() 19 | total = reduce_sum(tensor) 20 | return total.float() / num_gpus 21 | 22 | 23 | def aligned_bilinear(tensor, factor): 24 | assert tensor.dim() == 4 25 | assert factor >= 1 26 | assert int(factor) == factor 27 | 28 | if factor == 1: 29 | return tensor 30 | 31 | h, w = tensor.size()[2:] 32 | tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate") 33 | oh = factor * h + 1 34 | ow = factor * w + 1 35 | tensor = F.interpolate( 36 | tensor, size=(oh, ow), 37 | mode='bilinear', 38 | align_corners=True 39 | ) 40 | tensor = F.pad( 41 | tensor, pad=(factor // 2, 0, factor // 2, 0), 42 | mode="replicate" 43 | ) 44 | 45 | return tensor[:, :, :oh - 1, :ow - 1] 46 | 47 | 48 | def compute_locations(h, w, stride, device): 49 | shifts_x = torch.arange( 50 | 0, w * stride, step=stride, 51 | dtype=torch.float32, device=device 52 | ) 53 | shifts_y = torch.arange( 54 | 0, h * stride, step=stride, 55 | dtype=torch.float32, device=device 56 | ) 57 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 58 | shift_x = shift_x.reshape(-1) 59 | shift_y = shift_y.reshape(-1) 60 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 61 | return locations 62 | 63 | 64 | def compute_ious(pred, target): 65 | """ 66 | Args: 67 | pred: Nx4 predicted bounding boxes 68 | target: Nx4 target bounding boxes 69 | Both are in the form of FCOS prediction (l, t, r, b) 70 | """ 71 | pred_left = pred[:, 0] 72 | pred_top = pred[:, 1] 73 | pred_right = pred[:, 2] 74 | pred_bottom = pred[:, 3] 75 | 76 | target_left = target[:, 0] 77 | target_top = target[:, 1] 78 | target_right = target[:, 2] 79 | target_bottom = target[:, 3] 80 | 81 | target_aera = (target_left + target_right) * \ 82 | (target_top + target_bottom) 83 | pred_aera = (pred_left + pred_right) * \ 84 | (pred_top + pred_bottom) 85 | 86 | w_intersect = torch.min(pred_left, target_left) + \ 87 | torch.min(pred_right, target_right) 88 | h_intersect = torch.min(pred_bottom, target_bottom) + \ 89 | torch.min(pred_top, target_top) 90 | 91 | g_w_intersect = torch.max(pred_left, target_left) + \ 92 | torch.max(pred_right, target_right) 93 | g_h_intersect = 
torch.max(pred_bottom, target_bottom) + \ 94 | torch.max(pred_top, target_top) 95 | ac_uion = g_w_intersect * g_h_intersect 96 | 97 | area_intersect = w_intersect * h_intersect 98 | area_union = target_aera + pred_aera - area_intersect 99 | 100 | ious = (area_intersect + 1.0) / (area_union + 1.0) 101 | gious = ious - (ac_uion - area_union) / ac_uion 102 | 103 | return ious, gious 104 | -------------------------------------------------------------------------------- /adet/utils/misc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import torch 3 | from torch.functional import Tensor 4 | from torchvision.ops.boxes import box_area 5 | import torch.distributed as dist 6 | 7 | 8 | def is_dist_avail_and_initialized(): 9 | if not dist.is_available(): 10 | return False 11 | if not dist.is_initialized(): 12 | return False 13 | return True 14 | 15 | 16 | @torch.no_grad() 17 | def accuracy(output, target, topk=(1,)): 18 | """Computes the precision@k for the specified values of k""" 19 | if target.numel() == 0: 20 | return [torch.zeros([], device=output.device)] 21 | if target.ndim == 2: 22 | assert output.ndim == 3 23 | output = output.mean(1) 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, -1) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].view(-1).float().sum(0) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | 38 | def box_cxcywh_to_xyxy(x): 39 | x_c, y_c, w, h = x.unbind(-1) 40 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 41 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 42 | return torch.stack(b, dim=-1) 43 | 44 | 45 | def box_xyxy_to_cxcywh(x): 46 | x0, y0, x1, y1 = x.unbind(-1) 47 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 48 | (x1 - x0), (y1 - y0)] 49 | return torch.stack(b, dim=-1) 50 | 51 | 52 | # modified from torchvision to also return the union 53 | def box_iou(boxes1, boxes2): 54 | area1 = box_area(boxes1) 55 | area2 = box_area(boxes2) 56 | 57 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 58 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 59 | 60 | wh = (rb - lt).clamp(min=0) # [N,M,2] 61 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 62 | 63 | union = area1[:, None] + area2 - inter 64 | 65 | iou = inter / union 66 | return iou, union 67 | 68 | 69 | def generalized_box_iou(boxes1, boxes2): 70 | """ 71 | Generalized IoU from https://giou.stanford.edu/ 72 | The boxes should be in [x0, y0, x1, y1] format 73 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 74 | and M = len(boxes2) 75 | """ 76 | # degenerate boxes gives inf / nan results 77 | # so do an early check 78 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 79 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 80 | iou, union = box_iou(boxes1, boxes2) 81 | 82 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 83 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 84 | 85 | wh = (rb - lt).clamp(min=0) # [N,M,2] 86 | area = wh[:, :, 0] * wh[:, :, 1] 87 | 88 | return iou - (area - union) / area 89 | 90 | 91 | def masks_to_boxes(masks): 92 | """Compute the bounding boxes around the provided masks 93 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
94 | Returns a [N, 4] tensors, with the boxes in xyxy format 95 | """ 96 | if masks.numel() == 0: 97 | return torch.zeros((0, 4), device=masks.device) 98 | 99 | h, w = masks.shape[-2:] 100 | 101 | y = torch.arange(0, h, dtype=torch.float) 102 | x = torch.arange(0, w, dtype=torch.float) 103 | y, x = torch.meshgrid(y, x) 104 | 105 | x_mask = (masks * x.unsqueeze(0)) 106 | x_max = x_mask.flatten(1).max(-1)[0] 107 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 108 | 109 | y_mask = (masks * y.unsqueeze(0)) 110 | y_max = y_mask.flatten(1).max(-1)[0] 111 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 112 | 113 | return torch.stack([x_min, y_min, x_max, y_max], 1) 114 | 115 | def inverse_sigmoid(x, eps=1e-5): 116 | x = x.clamp(min=0, max=1) 117 | x1 = x.clamp(min=eps) 118 | x2 = (1 - x).clamp(min=eps) 119 | return torch.log(x1/x2) 120 | 121 | def sigmoid_offset(x, offset=True): 122 | # modified sigmoid for range [-0.5, 1.5] 123 | if offset: 124 | return x.sigmoid() * 2 - 0.5 125 | else: 126 | return x.sigmoid() 127 | 128 | def inverse_sigmoid_offset(x, eps=1e-5, offset=True): 129 | if offset: 130 | x = (x + 0.5) / 2.0 131 | return inverse_sigmoid(x, eps) 132 | 133 | def _max_by_axis(the_list): 134 | # type: (List[List[int]]) -> List[int] 135 | maxes = the_list[0] 136 | for sublist in the_list[1:]: 137 | for index, item in enumerate(sublist): 138 | maxes[index] = max(maxes[index], item) 139 | return maxes 140 | 141 | 142 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 143 | # TODO make this more general 144 | if tensor_list[0].ndim == 3: 145 | # TODO make it support different-sized images 146 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 147 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 148 | batch_shape = [len(tensor_list)] + max_size 149 | b, c, h, w = batch_shape 150 | dtype = tensor_list[0].dtype 151 | device = tensor_list[0].device 152 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 153 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 154 | for img, pad_img, m in zip(tensor_list, tensor, mask): 155 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 156 | m[: img.shape[1], :img.shape[2]] = False 157 | else: 158 | raise ValueError('not supported') 159 | return NestedTensor(tensor, mask) 160 | 161 | 162 | class NestedTensor(object): 163 | def __init__(self, tensors, mask: Optional[Tensor]): 164 | self.tensors = tensors 165 | self.mask = mask 166 | 167 | def to(self, device): 168 | # type: (Device) -> NestedTensor # noqa 169 | cast_tensor = self.tensors.to(device) 170 | mask = self.mask 171 | if mask is not None: 172 | assert mask is not None 173 | cast_mask = mask.to(device) 174 | else: 175 | cast_mask = None 176 | return NestedTensor(cast_tensor, cast_mask) 177 | 178 | def decompose(self): 179 | return self.tensors, self.mask 180 | 181 | def __repr__(self): 182 | return str(self.tensors) 183 | -------------------------------------------------------------------------------- /adet/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from detectron2.utils.visualizer import Visualizer 4 | import matplotlib.colors as mplc 5 | import matplotlib.font_manager as mfm 6 | 7 | class TextVisualizer(Visualizer): 8 | def __init__(self, image, metadata, instance_mode, cfg): 9 | Visualizer.__init__(self, image, metadata, instance_mode=instance_mode) 
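        # Vocabulary used to decode recognition results: with the default 96-way
        # vocabulary, indices 0-94 map to the printable ASCII characters listed
        # below and the last index (voc_size - 1) is rendered as the unknown
        # symbol '口'; a custom dictionary can instead be supplied through
        # MODEL.BATEXT.CUSTOM_DICT.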
10 | self.voc_size = cfg.MODEL.BATEXT.VOC_SIZE 11 | self.use_customer_dictionary = cfg.MODEL.BATEXT.CUSTOM_DICT 12 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON 13 | if not self.use_customer_dictionary: 14 | self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~'] 15 | else: 16 | with open(self.use_customer_dictionary, 'rb') as fp: 17 | self.CTLABELS = pickle.load(fp) 18 | assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size is not matched dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS)) 19 | 20 | def draw_instance_predictions(self, predictions): 21 | if self.use_polygon: 22 | ctrl_pnts = predictions.polygons.numpy() 23 | else: 24 | ctrl_pnts = predictions.beziers.numpy() 25 | scores = predictions.scores.tolist() 26 | recs = predictions.recs 27 | 28 | self.overlay_instances(ctrl_pnts, recs, scores) 29 | 30 | return self.output 31 | 32 | def _ctrl_pnt_to_poly(self, pnt): 33 | if self.use_polygon: 34 | points = pnt.reshape(-1, 2) 35 | else: 36 | # bezier to polygon 37 | u = np.linspace(0, 1, 20) 38 | pnt = pnt.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4) 39 | points = np.outer((1 - u) ** 3, pnt[:, 0]) \ 40 | + np.outer(3 * u * ((1 - u) ** 2), pnt[:, 1]) \ 41 | + np.outer(3 * (u ** 2) * (1 - u), pnt[:, 2]) \ 42 | + np.outer(u ** 3, pnt[:, 3]) 43 | points = np.concatenate((points[:, :2], points[:, 2:]), axis=0) 44 | 45 | return points 46 | 47 | def _decode_recognition(self, rec): 48 | s = '' 49 | for c in rec: 50 | c = int(c) 51 | if c < self.voc_size - 1: 52 | if self.voc_size == 96: 53 | s += self.CTLABELS[c] 54 | else: 55 | s += str(chr(self.CTLABELS[c])) 56 | elif c == self.voc_size -1: 57 | s += u'口' 58 | return s 59 | 60 | def _ctc_decode_recognition(self, rec): 61 | # ctc decoding 62 | last_char = False 63 | s = '' 64 | for c in rec: 65 | c = int(c) 66 | if c < self.voc_size - 1: 67 | if last_char != c: 68 | if self.voc_size == 96: 69 | s += self.CTLABELS[c] 70 | last_char = c 71 | else: 72 | s += str(chr(self.CTLABELS[c])) 73 | last_char = c 74 | elif c == self.voc_size -1: 75 | s += u'口' 76 | else: 77 | last_char = False 78 | return s 79 | 80 | def overlay_instances(self, ctrl_pnts, recs, scores, alpha=0.5): 81 | color = (0.1, 0.2, 0.5) 82 | 83 | for ctrl_pnt, rec, score in zip(ctrl_pnts, recs, scores): 84 | polygon = self._ctrl_pnt_to_poly(ctrl_pnt) 85 | self.draw_polygon(polygon, color, alpha=alpha) 86 | 87 | # draw text in the top left corner 88 | text = self._decode_recognition(rec) 89 | text = "{:.3f}: {}".format(score, text) 90 | lighter_color = self._change_color_brightness(color, brightness_factor=0.7) 91 | text_pos = polygon[0] 92 | horiz_align = "left" 93 | font_size = self._default_font_size 94 | 95 | self.draw_text( 96 | text, 97 | text_pos, 98 | color=lighter_color, 99 | horizontal_alignment=horiz_align, 100 | font_size=font_size, 101 | draw_chinese=False if self.voc_size == 96 else True 102 | ) 103 | 104 | 105 | def draw_text( 106 | self, 107 | text, 108 | position, 109 | *, 110 | font_size=None, 111 | color="g", 112 | horizontal_alignment="center", 113 | rotation=0, 114 | draw_chinese=False 115 | ): 116 | """ 117 | Args: 118 | text (str): class label 119 | position (tuple): 
a tuple of the x and y coordinates to place text on image. 120 | font_size (int, optional): font of the text. If not provided, a font size 121 | proportional to the image width is calculated and used. 122 | color: color of the text. Refer to `matplotlib.colors` for full list 123 | of formats that are accepted. 124 | horizontal_alignment (str): see `matplotlib.text.Text` 125 | rotation: rotation angle in degrees CCW 126 | Returns: 127 | output (VisImage): image object with text drawn. 128 | """ 129 | if not font_size: 130 | font_size = self._default_font_size 131 | 132 | # since the text background is dark, we don't want the text to be dark 133 | color = np.maximum(list(mplc.to_rgb(color)), 0.2) 134 | color[np.argmax(color)] = max(0.8, np.max(color)) 135 | 136 | x, y = position 137 | if draw_chinese: 138 | font_path = "./simsun.ttc" 139 | prop = mfm.FontProperties(fname=font_path) 140 | self.output.ax.text( 141 | x, 142 | y, 143 | text, 144 | size=font_size * self.output.scale, 145 | family="sans-serif", 146 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, 147 | verticalalignment="top", 148 | horizontalalignment=horizontal_alignment, 149 | color=color, 150 | zorder=10, 151 | rotation=rotation, 152 | fontproperties=prop 153 | ) 154 | else: 155 | self.output.ax.text( 156 | x, 157 | y, 158 | text, 159 | size=font_size * self.output.scale, 160 | family="sans-serif", 161 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, 162 | verticalalignment="top", 163 | horizontalalignment=horizontal_alignment, 164 | color=color, 165 | zorder=10, 166 | rotation=rotation, 167 | ) 168 | return self.output -------------------------------------------------------------------------------- /configs/TESTR/Base-TESTR.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TransformerDetector" 3 | MASK_ON: False 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | BACKBONE: 7 | NAME: "build_resnet_backbone" 8 | RESNETS: 9 | STRIDE_IN_1X1: False 10 | OUT_FEATURES: ["res3", "res4", "res5"] 11 | TRANSFORMER: 12 | ENABLED: True 13 | INFERENCE_TH_TEST: 0.45 14 | SOLVER: 15 | WEIGHT_DECAY: 1e-4 16 | OPTIMIZER: "ADAMW" 17 | LR_BACKBONE_NAMES: ['backbone.0'] 18 | LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets'] 19 | LR_LINEAR_PROJ_MULT: 0.1 20 | CLIP_GRADIENTS: 21 | ENABLED: True 22 | CLIP_TYPE: "full_model" 23 | CLIP_VALUE: 0.1 24 | NORM_TYPE: 2.0 25 | INPUT: 26 | HFLIP_TRAIN: False 27 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896) 28 | MAX_SIZE_TRAIN: 1600 29 | MIN_SIZE_TEST: 1600 30 | MAX_SIZE_TEST: 1824 31 | CROP: 32 | ENABLED: True 33 | CROP_INSTANCE: False 34 | SIZE: [0.1, 0.1] 35 | FORMAT: "RGB" -------------------------------------------------------------------------------- /configs/TESTR/CTW1500/Base-CTW1500-Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | MODEL: 3 | TRANSFORMER: 4 | INFERENCE_TH_TEST: 0.7 5 | NUM_CHARS: 100 6 | USE_POLYGON: True 7 | NUM_CTRL_POINTS: 16 8 | LOSS: 9 | POINT_TEXT_WEIGHT: 4.0 10 | DATASETS: 11 | TRAIN: ("ctw1500_word_poly_train",) 12 | TEST: ("ctw1500_word_poly_test",) 13 | INPUT: 14 | MIN_SIZE_TEST: 1000 15 | MAX_SIZE_TRAIN: 1333 16 | TEST: 17 | USE_LEXICON: False 18 | LEXICON_TYPE: 1 19 | 20 | -------------------------------------------------------------------------------- 
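The YAML files above and below compose through the `_BASE_` key: a dataset-specific config such as Base-CTW1500-Polygon.yaml first pulls in Base-TESTR.yaml and then overrides individual fields (e.g. INFERENCE_TH_TEST, NUM_CTRL_POINTS, DATASETS). The snippet below is only a minimal, self-contained sketch of that resolution order, not the detectron2/yacs config machinery the repo actually relies on; `load_cfg_with_base` and `deep_update` are illustrative names.

import os
import yaml  # assumption: PyYAML is available


def deep_update(base, child):
    """Recursively overwrite entries of `base` with entries of `child`."""
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base


def load_cfg_with_base(path):
    """Load a YAML config, resolving its `_BASE_` file first (child values win)."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_rel = cfg.pop("_BASE_", None)
    if base_rel is None:
        return cfg
    base_cfg = load_cfg_with_base(os.path.join(os.path.dirname(path), base_rel))
    return deep_update(base_cfg, cfg)


# e.g. load_cfg_with_base("configs/TESTR/CTW1500/Base-CTW1500-Polygon.yaml")
# yields a dict in which MODEL.TRANSFORMER.INFERENCE_TH_TEST == 0.7 overrides
# the 0.45 set in Base-TESTR.yaml.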
/configs/TESTR/CTW1500/Base-CTW1500.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | MODEL: 3 | TRANSFORMER: 4 | INFERENCE_TH_TEST: 0.6 5 | NUM_CHARS: 100 6 | DATASETS: 7 | TRAIN: ("ctw1500_word_train",) 8 | TEST: ("ctw1500_word_test",) 9 | INPUT: 10 | MIN_SIZE_TEST: 1000 11 | MAX_SIZE_TRAIN: 1333 12 | TEST: 13 | USE_LEXICON: False 14 | LEXICON_TYPE: 1 15 | 16 | -------------------------------------------------------------------------------- /configs/TESTR/CTW1500/TESTR_R_50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-CTW1500.yaml" 2 | MODEL: 3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | STEPS: (300000,) 19 | IMS_PER_BATCH: 8 20 | BASE_LR: 2e-5 21 | LR_BACKBONE: 2e-6 22 | WARMUP_ITERS: 0 23 | MAX_ITER: 200000 24 | CHECKPOINT_PERIOD: 10000 25 | TEST: 26 | EVAL_PERIOD: 10000 27 | OUTPUT_DIR: "output/TESTR/ctw1500/TESTR_R_50" 28 | -------------------------------------------------------------------------------- /configs/TESTR/CTW1500/TESTR_R_50_Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-CTW1500-Polygon.yaml" 2 | MODEL: 3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | STEPS: (300000,) 19 | IMS_PER_BATCH: 8 20 | BASE_LR: 1e-5 21 | LR_BACKBONE: 1e-6 22 | WARMUP_ITERS: 0 23 | MAX_ITER: 200000 24 | CHECKPOINT_PERIOD: 10000 25 | TEST: 26 | EVAL_PERIOD: 10000 27 | OUTPUT_DIR: "output/TESTR/ctw1500/TESTR_R_50_Polygon" 28 | -------------------------------------------------------------------------------- /configs/TESTR/ICDAR15/Base-ICDAR15-Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | DATASETS: 3 | TRAIN: ("icdar2015_train",) 4 | TEST: ("icdar2015_test",) 5 | MODEL: 6 | TRANSFORMER: 7 | USE_POLYGON: True 8 | NUM_CTRL_POINTS: 16 9 | LOSS: 10 | POINT_TEXT_WEIGHT: 4.0 11 | TEST: 12 | USE_LEXICON: True 13 | LEXICON_TYPE: 3 14 | WEIGHTED_EDIT_DIST: True 15 | INPUT: 16 | MIN_SIZE_TRAIN: (800, 832, 864, 896, 1000, 1200, 1400) 17 | MAX_SIZE_TRAIN: 2333 18 | MIN_SIZE_TEST: 1440 19 | MAX_SIZE_TEST: 4000 20 | -------------------------------------------------------------------------------- /configs/TESTR/ICDAR15/TESTR_R_50_Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-ICDAR15-Polygon.yaml" 2 | MODEL: 3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | INFERENCE_TH_TEST: 0.3 9 | ENC_LAYERS: 6 10 | DEC_LAYERS: 6 11 | DIM_FEEDFORWARD: 1024 12 | HIDDEN_DIM: 256 13 | DROPOUT: 0.1 14 | NHEADS: 8 15 | NUM_QUERIES: 100 16 | ENC_N_POINTS: 4 17 | DEC_N_POINTS: 4 18 | SOLVER: 19 | IMS_PER_BATCH: 8 20 | BASE_LR: 1e-5 21 | LR_BACKBONE: 1e-6 22 | WARMUP_ITERS: 0 23 | STEPS: (200000,) 24 | MAX_ITER: 20000 25 | CHECKPOINT_PERIOD: 1000 26 | TEST: 27 | 
EVAL_PERIOD: 1000 28 | OUTPUT_DIR: "output/TESTR/icdar15/TESTR_R_50_Polygon" 29 | -------------------------------------------------------------------------------- /configs/TESTR/Pretrain/Base-Pretrain-Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | DATASETS: 3 | TRAIN: ("mltbezier_word_poly_train", "totaltext_poly_train", "syntext1_poly_train", "syntext2_poly_train",) 4 | TEST: ("totaltext_poly_val",) 5 | MODEL: 6 | TRANSFORMER: 7 | USE_POLYGON: True 8 | NUM_CTRL_POINTS: 16 9 | LOSS: 10 | POINT_TEXT_WEIGHT: 4.0 11 | -------------------------------------------------------------------------------- /configs/TESTR/Pretrain/Base-Pretrain.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | DATASETS: 3 | TRAIN: ("mltbezier_word_train", "totaltext_train", "syntext1_train", "syntext2_train",) 4 | TEST: ("totaltext_val",) 5 | -------------------------------------------------------------------------------- /configs/TESTR/Pretrain/TESTR_R_50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-Pretrain.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | IMS_PER_BATCH: 8 19 | BASE_LR: 2e-4 20 | LR_BACKBONE: 2e-5 21 | WARMUP_ITERS: 0 22 | STEPS: (340000,) 23 | MAX_ITER: 440000 24 | CHECKPOINT_PERIOD: 10000 25 | TEST: 26 | EVAL_PERIOD: 10000 27 | OUTPUT_DIR: "output/TESTR/pretrain/TESTR_R_50" 28 | -------------------------------------------------------------------------------- /configs/TESTR/Pretrain/TESTR_R_50_Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-Pretrain-Polygon.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | IMS_PER_BATCH: 8 19 | BASE_LR: 1e-4 20 | LR_BACKBONE: 1e-5 21 | WARMUP_ITERS: 0 22 | STEPS: (340000,) 23 | MAX_ITER: 440000 24 | CHECKPOINT_PERIOD: 10000 25 | TEST: 26 | EVAL_PERIOD: 10000 27 | OUTPUT_DIR: "output/TESTR/pretrain/TESTR_R_50_Polygon" 28 | -------------------------------------------------------------------------------- /configs/TESTR/TotalText/Base-TotalText-Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | DATASETS: 3 | TRAIN: ("totaltext_poly_train",) 4 | TEST: ("totaltext_poly_val",) 5 | MODEL: 6 | TRANSFORMER: 7 | USE_POLYGON: True 8 | NUM_CTRL_POINTS: 16 9 | LOSS: 10 | POINT_TEXT_WEIGHT: 4.0 11 | TEST: 12 | USE_LEXICON: False 13 | LEXICON_TYPE: 1 -------------------------------------------------------------------------------- /configs/TESTR/TotalText/Base-TotalText.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-TESTR.yaml" 2 | DATASETS: 3 | TRAIN: ("totaltext_train",) 4 | TEST: ("totaltext_val",) 5 | TEST: 6 | USE_LEXICON: False 7 | LEXICON_TYPE: 1 8 | 
-------------------------------------------------------------------------------- /configs/TESTR/TotalText/TESTR_R_50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-TotalText.yaml" 2 | MODEL: 3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | IMS_PER_BATCH: 8 19 | BASE_LR: 2e-5 20 | LR_BACKBONE: 2e-6 21 | WARMUP_ITERS: 0 22 | STEPS: (200000,) 23 | MAX_ITER: 20000 24 | CHECKPOINT_PERIOD: 1000 25 | TEST: 26 | EVAL_PERIOD: 1000 27 | OUTPUT_DIR: "output/TESTR/totaltext/TESTR_R_50" 28 | -------------------------------------------------------------------------------- /configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-TotalText-Polygon.yaml" 2 | MODEL: 3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth" 4 | RESNETS: 5 | DEPTH: 50 6 | TRANSFORMER: 7 | NUM_FEATURE_LEVELS: 4 8 | ENC_LAYERS: 6 9 | DEC_LAYERS: 6 10 | DIM_FEEDFORWARD: 1024 11 | HIDDEN_DIM: 256 12 | DROPOUT: 0.1 13 | NHEADS: 8 14 | NUM_QUERIES: 100 15 | ENC_N_POINTS: 4 16 | DEC_N_POINTS: 4 17 | SOLVER: 18 | IMS_PER_BATCH: 8 19 | BASE_LR: 1e-5 20 | LR_BACKBONE: 1e-6 21 | WARMUP_ITERS: 0 22 | STEPS: (200000,) 23 | MAX_ITER: 20000 24 | CHECKPOINT_PERIOD: 1000 25 | TEST: 26 | EVAL_PERIOD: 1000 27 | OUTPUT_DIR: "output/TESTR/totaltext/TESTR_R_50_Polygon" 28 | -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import argparse 3 | import glob 4 | import multiprocessing as mp 5 | import os 6 | import time 7 | import cv2 8 | import tqdm 9 | 10 | from detectron2.data.detection_utils import read_image 11 | from detectron2.utils.logger import setup_logger 12 | 13 | from predictor import VisualizationDemo 14 | from adet.config import get_cfg 15 | 16 | # constants 17 | WINDOW_NAME = "COCO detections" 18 | 19 | 20 | def setup_cfg(args): 21 | # load config from file and command-line arguments 22 | cfg = get_cfg() 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | # Set score_threshold for builtin models 26 | cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold 27 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold 28 | cfg.MODEL.FCOS.INFERENCE_TH_TEST = args.confidence_threshold 29 | cfg.MODEL.MEInst.INFERENCE_TH_TEST = args.confidence_threshold 30 | cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold 31 | cfg.freeze() 32 | return cfg 33 | 34 | 35 | def get_parser(): 36 | parser = argparse.ArgumentParser(description="Detectron2 Demo") 37 | parser.add_argument( 38 | "--config-file", 39 | default="configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_inference_acc_test.yaml", 40 | metavar="FILE", 41 | help="path to config file", 42 | ) 43 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") 44 | parser.add_argument("--video-input", help="Path to video file.") 45 | parser.add_argument("--input", nargs="+", help="A list of space separated input images") 46 | parser.add_argument( 47 | "--output", 48 | help="A file or directory to save output visualizations. " 49 | "If not given, will show output in an OpenCV window.", 50 | ) 51 | 52 | parser.add_argument( 53 | "--confidence-threshold", 54 | type=float, 55 | default=0.3, 56 | help="Minimum score for instance predictions to be shown", 57 | ) 58 | parser.add_argument( 59 | "--opts", 60 | help="Modify config options using the command-line 'KEY VALUE' pairs", 61 | default=[], 62 | nargs=argparse.REMAINDER, 63 | ) 64 | return parser 65 | 66 | 67 | if __name__ == "__main__": 68 | mp.set_start_method("spawn", force=True) 69 | args = get_parser().parse_args() 70 | logger = setup_logger() 71 | logger.info("Arguments: " + str(args)) 72 | 73 | cfg = setup_cfg(args) 74 | 75 | demo = VisualizationDemo(cfg) 76 | 77 | if args.input: 78 | if os.path.isdir(args.input[0]): 79 | args.input = [os.path.join(args.input[0], fname) for fname in os.listdir(args.input[0])] 80 | elif len(args.input) == 1: 81 | args.input = glob.glob(os.path.expanduser(args.input[0])) 82 | assert args.input, "The input path(s) was not found" 83 | for path in tqdm.tqdm(args.input, disable=not args.output): 84 | # use PIL, to be consistent with evaluation 85 | img = read_image(path, format="BGR") 86 | start_time = time.time() 87 | predictions, visualized_output = demo.run_on_image(img) 88 | logger.info( 89 | "{}: detected {} instances in {:.2f}s".format( 90 | path, len(predictions["instances"]), time.time() - start_time 91 | ) 92 | ) 93 | 94 | if args.output: 95 | if os.path.isdir(args.output): 96 | assert os.path.isdir(args.output), args.output 97 | out_filename = os.path.join(args.output, os.path.basename(path)) 98 | else: 99 | assert len(args.input) == 1, "Please specify a directory with args.output" 100 | out_filename = args.output 101 | visualized_output.save(out_filename) 102 | else: 103 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 
104 | if cv2.waitKey(0) == 27: 105 | break # esc to quit 106 | elif args.webcam: 107 | assert args.input is None, "Cannot have both --input and --webcam!" 108 | cam = cv2.VideoCapture(0) 109 | for vis in tqdm.tqdm(demo.run_on_video(cam)): 110 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 111 | cv2.imshow(WINDOW_NAME, vis) 112 | if cv2.waitKey(1) == 27: 113 | break # esc to quit 114 | cv2.destroyAllWindows() 115 | elif args.video_input: 116 | video = cv2.VideoCapture(args.video_input) 117 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 118 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 119 | frames_per_second = video.get(cv2.CAP_PROP_FPS) 120 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 121 | basename = os.path.basename(args.video_input) 122 | 123 | if args.output: 124 | if os.path.isdir(args.output): 125 | output_fname = os.path.join(args.output, basename) 126 | output_fname = os.path.splitext(output_fname)[0] + ".mkv" 127 | else: 128 | output_fname = args.output 129 | assert not os.path.isfile(output_fname), output_fname 130 | output_file = cv2.VideoWriter( 131 | filename=output_fname, 132 | # some installation of opencv may not support x264 (due to its license), 133 | # you can try other format (e.g. MPEG) 134 | fourcc=cv2.VideoWriter_fourcc(*"x264"), 135 | fps=float(frames_per_second), 136 | frameSize=(width, height), 137 | isColor=True, 138 | ) 139 | assert os.path.isfile(args.video_input) 140 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): 141 | if args.output: 142 | output_file.write(vis_frame) 143 | else: 144 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL) 145 | cv2.imshow(basename, vis_frame) 146 | if cv2.waitKey(1) == 27: 147 | break # esc to quit 148 | video.release() 149 | if args.output: 150 | output_file.release() 151 | else: 152 | cv2.destroyAllWindows() 153 | -------------------------------------------------------------------------------- /demo/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | import atexit 4 | import bisect 5 | import multiprocessing as mp 6 | from collections import deque 7 | import cv2 8 | import torch 9 | import matplotlib.pyplot as plt 10 | 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.engine.defaults import DefaultPredictor 13 | from detectron2.utils.video_visualizer import VideoVisualizer 14 | from detectron2.utils.visualizer import ColorMode, Visualizer 15 | 16 | from adet.utils.visualizer import TextVisualizer 17 | 18 | 19 | class VisualizationDemo(object): 20 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): 21 | """ 22 | Args: 23 | cfg (CfgNode): 24 | instance_mode (ColorMode): 25 | parallel (bool): whether to run the model in different processes from visualization. 26 | Useful since the visualization logic can be slow. 
27 | """ 28 | self.metadata = MetadataCatalog.get( 29 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" 30 | ) 31 | self.cfg = cfg 32 | self.cpu_device = torch.device("cpu") 33 | self.instance_mode = instance_mode 34 | self.vis_text = cfg.MODEL.ROI_HEADS.NAME == "TextHead" or cfg.MODEL.TRANSFORMER.ENABLED 35 | 36 | self.parallel = parallel 37 | if parallel: 38 | num_gpu = torch.cuda.device_count() 39 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) 40 | else: 41 | self.predictor = DefaultPredictor(cfg) 42 | 43 | def run_on_image(self, image): 44 | """ 45 | Args: 46 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 47 | This is the format used by OpenCV. 48 | 49 | Returns: 50 | predictions (dict): the output of the model. 51 | vis_output (VisImage): the visualized image output. 52 | """ 53 | vis_output = None 54 | predictions = self.predictor(image) 55 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 56 | image = image[:, :, ::-1] 57 | if self.vis_text: 58 | visualizer = TextVisualizer(image, self.metadata, instance_mode=self.instance_mode, cfg=self.cfg) 59 | else: 60 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) 61 | 62 | if "bases" in predictions: 63 | self.vis_bases(predictions["bases"]) 64 | if "panoptic_seg" in predictions: 65 | panoptic_seg, segments_info = predictions["panoptic_seg"] 66 | vis_output = visualizer.draw_panoptic_seg_predictions( 67 | panoptic_seg.to(self.cpu_device), segments_info 68 | ) 69 | else: 70 | if "sem_seg" in predictions: 71 | vis_output = visualizer.draw_sem_seg( 72 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)) 73 | if "instances" in predictions: 74 | instances = predictions["instances"].to(self.cpu_device) 75 | vis_output = visualizer.draw_instance_predictions(predictions=instances) 76 | 77 | return predictions, vis_output 78 | 79 | def _frame_from_video(self, video): 80 | while video.isOpened(): 81 | success, frame = video.read() 82 | if success: 83 | yield frame 84 | else: 85 | break 86 | 87 | def vis_bases(self, bases): 88 | basis_colors = [[2, 200, 255], [107, 220, 255], [30, 200, 255], [60, 220, 255]] 89 | bases = bases[0].squeeze() 90 | bases = (bases / 8).tanh().cpu().numpy() 91 | num_bases = len(bases) 92 | fig, axes = plt.subplots(nrows=num_bases // 2, ncols=2) 93 | for i, basis in enumerate(bases): 94 | basis = (basis + 1) / 2 95 | basis = basis / basis.max() 96 | basis_viz = np.zeros((basis.shape[0], basis.shape[1], 3), dtype=np.uint8) 97 | basis_viz[:, :, 0] = basis_colors[i][0] 98 | basis_viz[:, :, 1] = basis_colors[i][1] 99 | basis_viz[:, :, 2] = np.uint8(basis * 255) 100 | basis_viz = cv2.cvtColor(basis_viz, cv2.COLOR_HSV2RGB) 101 | axes[i // 2][i % 2].imshow(basis_viz) 102 | plt.show() 103 | 104 | def run_on_video(self, video): 105 | """ 106 | Visualizes predictions on frames of the input video. 107 | 108 | Args: 109 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be 110 | either a webcam or a video file. 111 | 112 | Yields: 113 | ndarray: BGR visualizations of each video frame. 
114 | """ 115 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) 116 | 117 | def process_predictions(frame, predictions): 118 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 119 | if "panoptic_seg" in predictions: 120 | panoptic_seg, segments_info = predictions["panoptic_seg"] 121 | vis_frame = video_visualizer.draw_panoptic_seg_predictions( 122 | frame, panoptic_seg.to(self.cpu_device), segments_info 123 | ) 124 | elif "instances" in predictions: 125 | predictions = predictions["instances"].to(self.cpu_device) 126 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) 127 | elif "sem_seg" in predictions: 128 | vis_frame = video_visualizer.draw_sem_seg( 129 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 130 | ) 131 | 132 | # Converts Matplotlib RGB format to OpenCV BGR format 133 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) 134 | return vis_frame 135 | 136 | frame_gen = self._frame_from_video(video) 137 | if self.parallel: 138 | buffer_size = self.predictor.default_buffer_size 139 | 140 | frame_data = deque() 141 | 142 | for cnt, frame in enumerate(frame_gen): 143 | frame_data.append(frame) 144 | self.predictor.put(frame) 145 | 146 | if cnt >= buffer_size: 147 | frame = frame_data.popleft() 148 | predictions = self.predictor.get() 149 | yield process_predictions(frame, predictions) 150 | 151 | while len(frame_data): 152 | frame = frame_data.popleft() 153 | predictions = self.predictor.get() 154 | yield process_predictions(frame, predictions) 155 | else: 156 | for frame in frame_gen: 157 | yield process_predictions(frame, self.predictor(frame)) 158 | 159 | 160 | class AsyncPredictor: 161 | """ 162 | A predictor that runs the model asynchronously, possibly on >1 GPUs. 163 | Because rendering the visualization takes a considerable amount of time, 164 | this helps improve throughput when rendering videos.
165 | """ 166 | 167 | class _StopToken: 168 | pass 169 | 170 | class _PredictWorker(mp.Process): 171 | def __init__(self, cfg, task_queue, result_queue): 172 | self.cfg = cfg 173 | self.task_queue = task_queue 174 | self.result_queue = result_queue 175 | super().__init__() 176 | 177 | def run(self): 178 | predictor = DefaultPredictor(self.cfg) 179 | 180 | while True: 181 | task = self.task_queue.get() 182 | if isinstance(task, AsyncPredictor._StopToken): 183 | break 184 | idx, data = task 185 | result = predictor(data) 186 | self.result_queue.put((idx, result)) 187 | 188 | def __init__(self, cfg, num_gpus: int = 1): 189 | """ 190 | Args: 191 | cfg (CfgNode): 192 | num_gpus (int): if 0, will run on CPU 193 | """ 194 | num_workers = max(num_gpus, 1) 195 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 196 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 197 | self.procs = [] 198 | for gpuid in range(max(num_gpus, 1)): 199 | cfg = cfg.clone() 200 | cfg.defrost() 201 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 202 | self.procs.append( 203 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 204 | ) 205 | 206 | self.put_idx = 0 207 | self.get_idx = 0 208 | self.result_rank = [] 209 | self.result_data = [] 210 | 211 | for p in self.procs: 212 | p.start() 213 | atexit.register(self.shutdown) 214 | 215 | def put(self, image): 216 | self.put_idx += 1 217 | self.task_queue.put((self.put_idx, image)) 218 | 219 | def get(self): 220 | self.get_idx += 1 # the index needed for this request 221 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 222 | res = self.result_data[0] 223 | del self.result_data[0], self.result_rank[0] 224 | return res 225 | 226 | while True: 227 | # make sure the results are returned in the correct order 228 | idx, res = self.result_queue.get() 229 | if idx == self.get_idx: 230 | return res 231 | insert = bisect.bisect(self.result_rank, idx) 232 | self.result_rank.insert(insert, idx) 233 | self.result_data.insert(insert, res) 234 | 235 | def __len__(self): 236 | return self.put_idx - self.get_idx 237 | 238 | def __call__(self, image): 239 | self.put(image) 240 | return self.get() 241 | 242 | def shutdown(self): 243 | for _ in self.procs: 244 | self.task_queue.put(AsyncPredictor._StopToken()) 245 | 246 | @property 247 | def default_buffer_size(self): 248 | return len(self.procs) * 5 249 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import glob 5 | import os 6 | from setuptools import find_packages, setup 7 | import torch 8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 9 | 10 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 11 | assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3" 12 | 13 | 14 | def get_extensions(): 15 | this_dir = os.path.dirname(os.path.abspath(__file__)) 16 | extensions_dir = os.path.join(this_dir, "adet", "layers", "csrc") 17 | 18 | main_source = os.path.join(extensions_dir, "vision.cpp") 19 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 20 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 21 | os.path.join(extensions_dir, "*.cu") 22 | ) 23 | 24 | sources = [main_source] + sources 25 | 26 | extension = CppExtension 27 | 28 | extra_compile_args = {"cxx": []} 29 | define_macros = [] 30 | 31 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 32 | extension = CUDAExtension 33 | sources += source_cuda 34 | define_macros += [("WITH_CUDA", None)] 35 | extra_compile_args["nvcc"] = [ 36 | "-DCUDA_HAS_FP16=1", 37 | "-D__CUDA_NO_HALF_OPERATORS__", 38 | "-D__CUDA_NO_HALF_CONVERSIONS__", 39 | "-D__CUDA_NO_HALF2_OPERATORS__", 40 | ] 41 | 42 | if torch_ver < [1, 7]: 43 | # supported by https://github.com/pytorch/pytorch/pull/43931 44 | CC = os.environ.get("CC", None) 45 | if CC is not None: 46 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 47 | 48 | sources = [os.path.join(extensions_dir, s) for s in sources] 49 | 50 | include_dirs = [extensions_dir] 51 | 52 | ext_modules = [ 53 | extension( 54 | "adet._C", 55 | sources, 56 | include_dirs=include_dirs, 57 | define_macros=define_macros, 58 | extra_compile_args=extra_compile_args, 59 | ) 60 | ] 61 | 62 | return ext_modules 63 | 64 | 65 | setup( 66 | name="AdelaiDet", 67 | version="0.2.0", 68 | author="Adelaide Intelligent Machines", 69 | url="https://github.com/stanstarks/AdelaiDet", 70 | description="AdelaiDet is AIM's research " 71 | "platform for instance-level detection tasks based on Detectron2.", 72 | packages=find_packages(exclude=("configs", "tests")), 73 | python_requires=">=3.6", 74 | install_requires=[ 75 | "termcolor>=1.1", 76 | "Pillow>=6.0", 77 | "yacs>=0.1.6", 78 | "tabulate", 79 | "cloudpickle", 80 | "matplotlib", 81 | "tqdm>4.29.0", 82 | "tensorboard", 83 | "rapidfuzz", 84 | "Polygon3", 85 | "shapely", 86 | "scikit-image", 87 | "editdistance", 88 | "opencv-python", 89 | "numba", 90 | ], 91 | extras_require={"all": ["psutil"]}, 92 | ext_modules=get_extensions(), 93 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 94 | ) 95 | -------------------------------------------------------------------------------- /tools/train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detection Training Script. 4 | 5 | This scripts reads a given config file and runs the training or evaluation. 6 | It is an entry point that is made to train standard models in detectron2. 7 | 8 | In order to let one script support training of many models, 9 | this script contains logic that are specific to these built-in models and therefore 10 | may not be suitable for your own project. 11 | For example, your research project perhaps only needs a single "evaluator". 
12 | 13 | Therefore, we recommend you to use detectron2 as an library and take 14 | this file as an example of how to use the library. 15 | You may want to write your own script with your datasets and other customizations. 16 | """ 17 | 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from typing import Any, Dict, List, Set 22 | import torch 23 | import itertools 24 | from torch.nn.parallel import DistributedDataParallel 25 | 26 | import detectron2.utils.comm as comm 27 | from detectron2.data import MetadataCatalog, build_detection_train_loader 28 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 29 | from detectron2.utils.events import EventStorage 30 | from detectron2.evaluation import ( 31 | COCOEvaluator, 32 | COCOPanopticEvaluator, 33 | DatasetEvaluators, 34 | LVISEvaluator, 35 | PascalVOCDetectionEvaluator, 36 | SemSegEvaluator, 37 | verify_results, 38 | ) 39 | from detectron2.solver.build import maybe_add_gradient_clipping 40 | from detectron2.modeling import GeneralizedRCNNWithTTA 41 | from detectron2.utils.logger import setup_logger 42 | 43 | from adet.data.dataset_mapper import DatasetMapperWithBasis 44 | from adet.config import get_cfg 45 | from adet.checkpoint import AdetCheckpointer 46 | from adet.evaluation import TextEvaluator 47 | 48 | 49 | class Trainer(DefaultTrainer): 50 | """ 51 | This is the same Trainer except that we rewrite the 52 | `build_train_loader`/`resume_or_load` method. 53 | """ 54 | def build_hooks(self): 55 | """ 56 | Replace `DetectionCheckpointer` with `AdetCheckpointer`. 57 | 58 | Build a list of default hooks, including timing, evaluation, 59 | checkpointing, lr scheduling, precise BN, writing events. 60 | """ 61 | ret = super().build_hooks() 62 | for i in range(len(ret)): 63 | if isinstance(ret[i], hooks.PeriodicCheckpointer): 64 | self.checkpointer = AdetCheckpointer( 65 | self.model, 66 | self.cfg.OUTPUT_DIR, 67 | optimizer=self.optimizer, 68 | scheduler=self.scheduler, 69 | ) 70 | ret[i] = hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD) 71 | return ret 72 | 73 | def resume_or_load(self, resume=True): 74 | checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) 75 | if resume and self.checkpointer.has_checkpoint(): 76 | self.start_iter = checkpoint.get("iteration", -1) + 1 77 | 78 | def train_loop(self, start_iter: int, max_iter: int): 79 | """ 80 | Args: 81 | start_iter, max_iter (int): See docs above 82 | """ 83 | logger = logging.getLogger("adet.trainer") 84 | logger.info("Starting training from iteration {}".format(start_iter)) 85 | 86 | self.iter = self.start_iter = start_iter 87 | self.max_iter = max_iter 88 | 89 | with EventStorage(start_iter) as self.storage: 90 | self.before_train() 91 | for self.iter in range(start_iter, max_iter): 92 | self.before_step() 93 | self.run_step() 94 | self.after_step() 95 | self.after_train() 96 | 97 | def train(self): 98 | """ 99 | Run training. 100 | 101 | Returns: 102 | OrderedDict of results, if evaluation is enabled. Otherwise None. 
103 | """ 104 | self.train_loop(self.start_iter, self.max_iter) 105 | if hasattr(self, "_last_eval_results") and comm.is_main_process(): 106 | verify_results(self.cfg, self._last_eval_results) 107 | return self._last_eval_results 108 | 109 | @classmethod 110 | def build_train_loader(cls, cfg): 111 | """ 112 | Returns: 113 | iterable 114 | 115 | It calls :func:`detectron2.data.build_detection_train_loader` with a customized 116 | DatasetMapper, which adds categorical labels as a semantic mask. 117 | """ 118 | mapper = DatasetMapperWithBasis(cfg, True) 119 | return build_detection_train_loader(cfg, mapper=mapper) 120 | 121 | @classmethod 122 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 123 | """ 124 | Create evaluator(s) for a given dataset. 125 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 126 | For your own dataset, you can simply create an evaluator manually in your 127 | script and do not have to worry about the hacky if-else logic here. 128 | """ 129 | if output_folder is None: 130 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 131 | evaluator_list = [] 132 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 133 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 134 | evaluator_list.append( 135 | SemSegEvaluator( 136 | dataset_name, 137 | distributed=True, 138 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 139 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 140 | output_dir=output_folder, 141 | ) 142 | ) 143 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 144 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 145 | if evaluator_type == "coco_panoptic_seg": 146 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 147 | if evaluator_type == "pascal_voc": 148 | return PascalVOCDetectionEvaluator(dataset_name) 149 | if evaluator_type == "lvis": 150 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 151 | if evaluator_type == "text": 152 | return TextEvaluator(dataset_name, cfg, True, output_folder) 153 | if len(evaluator_list) == 0: 154 | raise NotImplementedError( 155 | "no Evaluator for the dataset {} with the type {}".format( 156 | dataset_name, evaluator_type 157 | ) 158 | ) 159 | if len(evaluator_list) == 1: 160 | return evaluator_list[0] 161 | return DatasetEvaluators(evaluator_list) 162 | 163 | @classmethod 164 | def test_with_TTA(cls, cfg, model): 165 | logger = logging.getLogger("adet.trainer") 166 | # In the end of training, run an evaluation with TTA 167 | # Only support some R-CNN models. 
168 | logger.info("Running inference with test-time augmentation ...") 169 | model = GeneralizedRCNNWithTTA(cfg, model) 170 | evaluators = [ 171 | cls.build_evaluator( 172 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 173 | ) 174 | for name in cfg.DATASETS.TEST 175 | ] 176 | res = cls.test(cfg, model, evaluators) 177 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 178 | return res 179 | 180 | @classmethod 181 | def build_optimizer(cls, cfg, model): 182 | def match_name_keywords(n, name_keywords): 183 | out = False 184 | for b in name_keywords: 185 | if b in n: 186 | out = True 187 | break 188 | return out 189 | 190 | params: List[Dict[str, Any]] = [] 191 | memo: Set[torch.nn.parameter.Parameter] = set() 192 | for key, value in model.named_parameters(recurse=True): 193 | if not value.requires_grad: 194 | continue 195 | # Avoid duplicating parameters 196 | if value in memo: 197 | continue 198 | memo.add(value) 199 | lr = cfg.SOLVER.BASE_LR 200 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 201 | 202 | if match_name_keywords(key, cfg.SOLVER.LR_BACKBONE_NAMES): 203 | lr = cfg.SOLVER.LR_BACKBONE 204 | elif match_name_keywords(key, cfg.SOLVER.LR_LINEAR_PROJ_NAMES): 205 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_LINEAR_PROJ_MULT 206 | 207 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 208 | 209 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 210 | # detectron2 doesn't have full model gradient clipping now 211 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 212 | enable = ( 213 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 214 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 215 | and clip_norm_val > 0.0 216 | ) 217 | 218 | class FullModelGradientClippingOptimizer(optim): 219 | def step(self, closure=None): 220 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 221 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 222 | super().step(closure=closure) 223 | 224 | return FullModelGradientClippingOptimizer if enable else optim 225 | 226 | optimizer_type = cfg.SOLVER.OPTIMIZER 227 | if optimizer_type == "SGD": 228 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 229 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 230 | ) 231 | elif optimizer_type == "ADAMW": 232 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 233 | params, cfg.SOLVER.BASE_LR 234 | ) 235 | else: 236 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 237 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 238 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 239 | return optimizer 240 | 241 | 242 | def setup(args): 243 | """ 244 | Create configs and perform basic setups. 
245 | """ 246 | cfg = get_cfg() 247 | cfg.merge_from_file(args.config_file) 248 | cfg.merge_from_list(args.opts) 249 | cfg.freeze() 250 | default_setup(cfg, args) 251 | 252 | rank = comm.get_rank() 253 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="adet") 254 | 255 | return cfg 256 | 257 | 258 | def main(args): 259 | cfg = setup(args) 260 | 261 | if args.eval_only: 262 | model = Trainer.build_model(cfg) 263 | AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 264 | cfg.MODEL.WEIGHTS, resume=args.resume 265 | ) 266 | res = Trainer.test(cfg, model) # d2 defaults.py 267 | if comm.is_main_process(): 268 | verify_results(cfg, res) 269 | if cfg.TEST.AUG.ENABLED: 270 | res.update(Trainer.test_with_TTA(cfg, model)) 271 | return res 272 | 273 | """ 274 | If you'd like to do anything fancier than the standard training logic, 275 | consider writing your own training loop or subclassing the trainer. 276 | """ 277 | trainer = Trainer(cfg) 278 | trainer.resume_or_load(resume=args.resume) 279 | if cfg.TEST.AUG.ENABLED: 280 | trainer.register_hooks( 281 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 282 | ) 283 | return trainer.train() 284 | 285 | 286 | if __name__ == "__main__": 287 | args = default_argument_parser().parse_args() 288 | print("Command Line Args:", args) 289 | launch( 290 | main, 291 | args.num_gpus, 292 | num_machines=args.num_machines, 293 | machine_rank=args.machine_rank, 294 | dist_url=args.dist_url, 295 | args=(args,), 296 | ) 297 | --------------------------------------------------------------------------------
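For reference, the eval-only path of `main()` above can also be exercised without the distributed `launch` wrapper. The sketch below is a minimal single-process driver mirroring that branch; it assumes it is run from the repository root so that `Trainer` can be imported from tools/train_net.py, and it reuses a config and weight path already listed in this repository (the weight file is the pretraining checkpoint referenced by that config, so any numbers it produces are only illustrative). It deliberately skips `default_setup` and logger configuration.

import sys

from adet.config import get_cfg
from adet.checkpoint import AdetCheckpointer

sys.path.insert(0, "tools")      # assumption: executed from the repository root
from train_net import Trainer    # the Trainer subclass defined above


def evaluate(config_file="configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml",
             weights="weights/TESTR/pretrain_testr_R_50_polygon.pth"):
    # Compose the config the same way setup() does, minus default_setup/logging.
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(["MODEL.WEIGHTS", weights])
    cfg.freeze()

    # Same steps as the --eval-only branch of main().
    model = Trainer.build_model(cfg)
    AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
    return Trainer.test(cfg, model)


if __name__ == "__main__":
    print(evaluate())

Multi-GPU training and TTA evaluation should still go through `launch` and `default_argument_parser`, as shown in the `__main__` block of tools/train_net.py above.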